1 回答

TA貢獻1827條經驗 獲得超9個贊
受到 的啟發this post
,我們可以利用np.searchsorted
-
def find_closest(a, v):
? ? sidx = v.argsort()
? ? v_s = v[sidx]
? ? idx = np.searchsorted(v_s, a)
? ? idx[idx==len(v)] = len(v)-1
? ? idx0 = (idx-1).clip(min=0)
? ??
? ? m = np.abs(a-v_s[idx]) >= np.abs(v_s[idx0]-a)
? ? m[idx==0] = 0
? ? idx[m] -= 1
? ? out = sidx[idx]
? ? return out
更多性能。numexpr在大型數據集上進行提升:
import numexpr as ne
def find_closest_v2(a, v):
? ? sidx = v.argsort()
? ? v_s = v[sidx]
? ? idx = np.searchsorted(v_s, a)
? ? idx[idx==len(v)] = len(v)-1
? ? idx0 = (idx-1).clip(min=0)
? ??
? ? p1 = v_s[idx]
? ? p2 = v_s[idx0]
? ? m = ne.evaluate('(idx!=0) & (abs(a-p1) >= abs(p2-a))', {'p1':p1, 'p2':p2, 'idx':idx})
? ? idx[m] -= 1
? ? out = sidx[idx]
? ? return out
時間安排
設置 :
N,M = 500,100000
a = np.random.rand(N,M)
v = np.random.rand(N)
In [22]: %timeit find_closest_v2(a, v)
4.35 s ± 21.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [23]: %timeit find_closest(a, v)
4.69 s ± 173 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
添加回答
舉報