3 回答

TA貢獻1804條經驗 獲得超8個贊
numba
有了numba它可以優化這兩個場景。從語法上講,您只需要構造一個帶有簡單for循環的函數:
from numba import njit
@njit
def get_first_index_nb(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
idx = get_first_index_nb(A, 0.9)
Numba通過JIT(“及時”)編譯代碼并利用CPU級別的優化來提高性能。一個常規的 for無環路@njit裝飾通常會慢比你已經嘗試了在條件滿足后期的情況下的方法。
對于Pandas數值系列df['data'],您可以簡單地將NumPy表示提供給JIT編譯的函數:
idx = get_first_index_nb(df['data'].values, 0.9)
概括
由于numba允許將函數用作參數,并且假設傳遞的函數也可以JIT編譯,則可以找到一種方法來計算第n個索引,其中滿足任意條件的條件func。
@njit
def get_nth_index_count(A, func, count):
c = 0
for i in range(len(A)):
if func(A[i]):
c += 1
if c == count:
return i
return -1
@njit
def func(val):
return val > 0.9
# get index of 3rd value where func evaluates to True
idx = get_nth_index_count(arr, func, 3)
對于第三個最后的值,可以喂相反,arr[::-1]和否定的結果len(arr) - 1,則- 1需要考慮0索引。
績效基準
# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999
@njit
def get_first_index_nb(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
def get_first_index_np(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
%timeit get_first_index_nb(arr, m) # 375 ns
%timeit get_first_index_np(arr, m) # 2.71 μs
%timeit next(iter(np.where(arr > m)[0]), -1) # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1) # 2.5 μs
%timeit get_first_index_nb(arr, n) # 204 μs
%timeit get_first_index_np(arr, n) # 44.8 ms
%timeit next(iter(np.where(arr > n)[0]), -1) # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1) # 39.2 ms

TA貢獻1911條經驗 獲得超7個贊
我也想做類似的事情,發現這個問題中提出的解決方案并沒有真正幫助我。特別是,numba對我來說,解決方案比問題本身中介紹的更常規的方法慢得多。我有一個times_all列表,通常為數萬個元素的數量級,并且想要找到第一個元素的索引times_all大于a 的索引time_event。而且我有數千個time_event。我的解決方案是將其times_all分成例如100個元素的塊,首先確定time_event屬于哪個時間段,保留該時間段的第一個元素的索引,然后找到該時間段中的哪個索引,然后將兩個索引相加。這是最少的代碼。對我來說,它的運行速度比本頁中的其他解決方案快幾個數量級。
def event_time_2_index(time_event, times_all, STEPS=100):
import numpy as np
time_indices_jumps = np.arange(0, len(times_all), STEPS)
time_list_jumps = [times_all[idx] for idx in time_indices_jumps]
time_list_jumps_idx = next((idx for idx, val in enumerate(time_list_jumps)\
if val > time_event), -1)
index_in_jumps = time_indices_jumps[time_list_jumps_idx-1]
times_cropped = times_all[index_in_jumps:]
event_index_rel = next((idx for idx, val in enumerate(times_cropped) \
if val > time_event), -1)
event_index = event_index_rel + index_in_jumps
return event_index
添加回答
舉報