2 回答

TA貢獻1824條經驗 獲得超8個贊
我只是想從文檔中指出以下內容np.vectorize:
提供該vectorize功能主要是為了方便,而不是為了性能。該實現本質上是一個 for 循環。
所以,實際上,你并沒有在這里使用 NumPy 的矢量化能力。使用 NumPy 的布爾數組索引和np.where,您可以重寫您的函數,這樣您就有了“真正的”矢量化。
這是我這邊的一個想法。我不得不承認,實際的代碼看起來很丑陋,但是通過預先計算布爾數組,我們可以最大限度地減少處理時間和內存使用量。
def f_vec(x, c=0.7):
# Initialize output array of same size and type as input array
out = np.zeros_like(x)
# Pre-calculate boolean arrays to prevent multiple calculation in following steps
x_gtq_0 = (x >= 0)
x_lt_0 = (x < 0)
x_gt_c = (x > c)
x_ltq_2c = (x <= 2 * c)
x_gt_2c = (x > 2 * c)
abs_x = np.abs(x)
abs_x_gt_c = abs_x > c
abs_x_ltq_2c = abs_x <= 2 * c
abs_x_gt_2c = (abs_x > 2 * c)
# Re-writing if-else blocks as operations on before calculated boolean arrays
out[np.where(x_gtq_0 & x_gt_c & x_ltq_2c)] = x[np.where(x_gtq_0 & x_gt_c & x_ltq_2c)] - c
out[np.where(x_gtq_0 & x_gt_2c)] = c
out[np.where(x_lt_0 & abs_x_gt_c & abs_x_ltq_2c)] = c - abs_x[np.where(x_lt_0 & abs_x_gt_c & abs_x_ltq_2c)]
out[np.where(x_lt_0 & abs_x_gt_2c)] = -c
return out
我添加了以下小型測試功能來進行一些比較:
def test(x):
print(x.shape)
vfunc = np.vectorize(f)
tic = time.perf_counter()
res_func = vfunc(x, c=0.7)
print(time.perf_counter() - tic)
tic = time.perf_counter()
res_vec = f_vec(x, c=0.7)
print(time.perf_counter() - tic)
print('Differences: ', np.count_nonzero(np.abs(res_func - res_vec) > 10e-9), '\n')
test((np.random.rand(10) - 0.5) * 4)
test((np.random.rand(1000, 1000) - 0.5) * 4)
test((np.random.rand(1920, 1280, 3) - 0.5) * 4)
這些是結果:
(10,)
0.0001590869999999467
7.954300000001524e-05
Differences: 0
(1000, 1000)
1.53853834
0.0843256779999999
Differences: 0
(1920, 1280, 3)
10.974010127
0.7489308680000004
Differences: 0
np.vectorize因此,在性能方面,對于較大的輸入,與實際矢量化方法之間的差異是巨大的。不過,如果np.vectorize解決方案足以滿足您的輸入,并且您不想花太多精力重新編寫代碼,請堅持下去!正如我所說,我只是想表明,矢量化不僅僅如此。
希望有幫助!
----------------------------------------
System information
----------------------------------------
Platform: Windows-10-10.0.16299-SP0
Python: 3.8.1
NumPy: 1.18.1
----------------------------------------

TA貢獻1878條經驗 獲得超4個贊
此功能適用于您的功能。試試看:
vfunc = np.vectorize(f) vfunc(a, c=0.7)
如果您仍然遇到錯誤 - 請使用輸入數據示例發布它們
添加回答
舉報