1 Answer

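This is a rewrite of the code that makes it better suited to numba.jit. It is not exactly a vectorized solution, but on this benchmark I measured a speedup of roughly 230x.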
import numpy as np
from numba import jit
from scipy import spatial

@jit
def D_from_cost(cost, D):
    # operates on D inplace
    ns, nt = cost.shape
    for i in range(ns):
        for j in range(nt):
            D[i+1, j+1] = cost[i,j]+min(D[i, j+1], D[i+1, j], D[i, j])
            # avoiding the list creation inside min enables better jit performance
            # D[i+1, j+1] = cost[i,j]+min([D[i, j+1], D[i+1, j], D[i, j]])

@jit
def get_d(D, matchidx):
    ns = D.shape[0] - 1
    nt = D.shape[1] - 1
    d = D[ns,nt]

    # the path starts at the last element of both sequences
    matchidx[0,0] = ns - 1
    matchidx[0,1] = nt - 1
    i = ns
    j = nt
    for k in range(1, ns+nt+3):
        # idx is the position of the smallest of the three neighbours
        # (0: up, 1: left, 2: diagonal); ties go to the lowest index
        idx = 0
        if not (D[i-1,j] <= D[i,j-1] and D[i-1,j] <= D[i-1,j-1]):
            if D[i,j-1] <= D[i-1,j-1]:
                idx = 1
            else:
                idx = 2

        if idx == 0 and i > 1 and j > 0:
            # matchidx.append([i-2, j-1])
            matchidx[k,0] = i - 2
            matchidx[k,1] = j - 1
            i -= 1
        elif idx == 1 and i > 0 and j > 1:
            # matchidx.append([i-1, j-2])
            matchidx[k,0] = i-1
            matchidx[k,1] = j-2
            j -= 1
        elif idx == 2 and i > 1 and j > 1:
            # matchidx.append([i-2, j-2])
            matchidx[k,0] = i-2
            matchidx[k,1] = j-2
            i -= 1
            j -= 1
        else:
            break

    # only the first k rows of matchidx were filled in
    return d, matchidx[:k]

def seqdist2(seq1, seq2):
    ns = len(seq1)
    nt = len(seq2)

    cost = spatial.distance_matrix(seq1, seq2)

    # initialize and update D
    D = np.full((ns+1, nt+1), np.inf)
    D[0, 0] = 0
    D_from_cost(cost, D)

    # np.int was removed in NumPy 1.24; use an explicit integer dtype
    matchidx = np.zeros((ns+nt+2,2), dtype=np.int64)
    d, matchidx = get_d(D, matchidx)
    return d, matchidx[::-1].tolist()

assert seqdist2(seq1, seq2) == seqdist(seq1, seq2)

%timeit seqdist2(seq1, seq2) # 1000 loops, best of 3: 365 μs per loop
%timeit seqdist(seq1, seq2)  # 10 loops, best of 3: 86.1 ms per loop
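To sanity-check the compiled version on a small input, the recurrence used in D_from_cost can also be written in plain Python. The following is only a minimal sketch: dtw_table is a hypothetical helper name, it works on nested lists instead of numpy arrays, and the absolute difference of scalars stands in for the Euclidean distance that spatial.distance_matrix computes.

```python
import math

def dtw_table(cost):
    """Fill the accumulated-cost table D from a cost matrix (list of lists).

    Same recurrence as D_from_cost, but without numba or numpy.
    """
    ns, nt = len(cost), len(cost[0])
    # Row 0 and column 0 are padding: inf everywhere except D[0][0] = 0.
    D = [[math.inf] * (nt + 1) for _ in range(ns + 1)]
    D[0][0] = 0.0
    for i in range(ns):
        for j in range(nt):
            # min over three scalars, not over a list
            D[i + 1][j + 1] = cost[i][j] + min(D[i][j + 1], D[i + 1][j], D[i][j])
    return D

# Two short 1-D sequences; |a - b| is the illustrative pointwise cost.
seq_a = (1.0, 3.0, 4.0)
seq_b = (1.0, 2.0, 4.0)
cost = [[abs(a - b) for b in seq_b] for a in seq_a]

D = dtw_table(cost)
print(D[-1][-1])  # total alignment cost, the value get_d returns as d
```

The bottom-right entry of the table is the distance; everything else in get_d is just backtracking through the same table to recover the matched index pairs.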
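Here are some of the changes:
cost is computed with spatial.distance_matrix.
The definition of idx is replaced by a bunch of ugly if statements, which makes the compiled code faster.
min([D[i, j+1], D[i+1, j], D[i, j]]) is replaced by min(D[i, j+1], D[i+1, j], D[i, j]), i.e. instead of taking the minimum of a list we take the minimum of three values. This gives a surprisingly large speedup under jit.
matchidx is preallocated as a numpy array and truncated to the correct size just before it is returned.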
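The if statements that set idx can be checked in isolation. The sketch below assumes the original code chose idx with an argmin-style lookup over the three neighbours, with ties going to the first one; the question's code is not shown here, so that is an assumption. pick_idx and first_argmin are hypothetical names used only for this comparison.

```python
from itertools import product
import math

def pick_idx(up, left, diag):
    """Branch-based selection, mirroring the if statements in get_d."""
    idx = 0
    if not (up <= left and up <= diag):
        if left <= diag:
            idx = 1
        else:
            idx = 2
    return idx

def first_argmin(up, left, diag):
    """List-based selection: index of the first minimum (assumed original behaviour)."""
    values = [up, left, diag]
    return values.index(min(values))

# Compare both selections over every triple drawn from a few values,
# including inf, which fills the padding row and column of D.
candidates = [0.0, 1.0, 2.5, math.inf]
mismatches = [
    triple
    for triple in product(candidates, repeat=3)
    if pick_idx(*triple) != first_argmin(*triple)
]
print(len(mismatches))
```

Under that assumption the two selections agree on every triple tried, so the branches trade readability for code that builds no temporary list inside the loop.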