4 Answers

Contributed 1816 experience points · received 4+ likes
You can use np.vectorize to map the dictionary values onto the array:
In [6]: b_dict = { 1:10., 2:20., 3:30 }
In [7]: a = np.full( 10, 2 )
In [8]: np.vectorize(b_dict.get)(a)
Out[8]: array([20., 20., 20., 20., 20., 20., 20., 20., 20., 20.])
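One caveat with the call above: b_dict.get returns None for a key that is not in the dictionary, which does not fit into a float result. A minimal sketch of one way around this, assuming NaN is an acceptable placeholder for missing keys (the name lookup and the NaN default are illustrative choices, not part of the original answer):

```python
import numpy as np

b_dict = {1: 10., 2: 20., 3: 30.}
a = np.array([1, 2, 4])  # 4 is deliberately not a key in b_dict

# Assumption: missing keys should map to NaN.
# otypes fixes the output dtype instead of letting np.vectorize infer it.
lookup = np.vectorize(lambda k: b_dict.get(k, np.nan), otypes=[float])
b = lookup(a)
print(b)
```

Note that np.vectorize is still a Python-level loop under the hood, so this is a convenience rather than a true vectorized operation.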

Contributed 1876 experience points · received 5+ likes
Another way to solve the problem:
from operator import itemgetter
np.array(itemgetter(*a)(b_dict))
Output:
[20., 20., 20., 20., 20., 20., 20., 20., 20., 20.]
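A small edge case worth knowing: itemgetter called with a single key returns a bare value instead of a tuple, so a one-element key array would yield a 0-d array. The sketch below wraps the call in np.atleast_1d to keep the result one-dimensional; the helper name lookup is hypothetical, and an empty key array is not handled because itemgetter needs at least one argument.

```python
import numpy as np
from operator import itemgetter

b_dict = {1: 10., 2: 20., 3: 30.}

def lookup(a):
    """Hypothetical helper: map keys in `a` through b_dict, always returning a 1-D array."""
    # itemgetter(k)(d) returns a scalar, itemgetter(k1, k2, ...)(d) returns a tuple;
    # np.atleast_1d makes both cases come out as a 1-D array.
    return np.atleast_1d(np.array(itemgetter(*a)(b_dict)))

one = lookup(np.full(1, 2))    # single key
many = lookup(np.full(10, 2))  # several keys
print(one.shape, many.shape)
```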
Comparison:
import numpy as np
from operator import itemgetter

# @kmundnic solution
def m1(a):
    def get_b(x):
        b_dict = {1: 10., 2: 20., 3: 30.}
        return b_dict[x]
    # np.float was removed in NumPy 1.24; np.float64 is the equivalent dtype
    return np.fromiter(map(get_b, a), dtype=np.float64)

# @bigbounty solution
def m2(a):
    b_dict = {1: 10., 2: 20., 3: 30.}
    return np.vectorize(b_dict.get)(a)

# @Ehsan solution
def m3(a):
    b_dict = {1: 10., 2: 20., 3: 30.}
    return np.array(itemgetter(*a)(b_dict))

# @Sun Bear solution
def m4(a):
    def get_b(a):
        b_dict = {1: 10., 2: 20., 3: 30.}
        return b_dict[a]
    return np.array([get_b(i) for i in a])

# Benchmark inputs: key arrays of increasing length
in_ = [np.full(n, 2) for n in [10, 100, 1000, 10000]]
For a small dictionary, m2 appears to be the fastest on large inputs, while m3 is the fastest on small inputs.
For a larger dictionary:
b_dict = dict(zip(np.arange(100),np.arange(100)))
in_ = [np.full(n,50) for n in [10,100,1000,10000]]
m3 is the fastest method. You can choose based on the size of your dictionary and the size of your key array.

Contributed 1793 experience points · received 6+ likes
How about using map and np.fromiter?
def get_b( a ):
    b_dict = { 1:10., 2:20., 3:30. }
    return b_dict[ a ]
a = np.full( 10, 2 )
b = np.fromiter(map(get_b, a), dtype=np.float64)
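A possible refinement of the same idea, offered as a sketch rather than a measured improvement: np.fromiter accepts a count argument, so when the number of keys is known the output array can be allocated once up front. Here the dictionary is defined once at module level and its __getitem__ method is passed to map directly, which is an assumption that differs slightly from the get_b helper in the answer.

```python
import numpy as np

b_dict = {1: 10., 2: 20., 3: 30.}
a = np.full(10, 2)

# count=len(a) tells np.fromiter how many items to expect,
# so it can allocate the result once instead of growing it.
b = np.fromiter(map(b_dict.__getitem__, a), dtype=np.float64, count=len(a))
print(b)
```

As with b_dict[ a ] in get_b, a key that is missing from the dictionary raises KeyError.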
Edit 1: a small timing comparison:
%timeit np.array( [get_b(i) for i in a] )
5.58 μs ± 123 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.fromiter(map(get_b, a), dtype=np.float64)
5.77 μs ± 177 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.vectorize(b_dict.get)(a)
12.9 μs ± 76.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Edit 2: it seems that example was too small:
a = np.full( 1000, 2 )
%timeit np.array( [get_b(i) for i in a] )
415 μs ± 9.13 μs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.fromiter(map(get_b, a), dtype=np.float64)
383 μs ± 2.5 μs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.vectorize(b_dict.get)(a)
68.6 μs ± 625 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

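Contributed 1111 experience points · received 0 likes
Does b_dict have to be a dictionary? If you have an array instead, e.g. ref = np.array([0, 10, 20, 30]), you can select values quickly by index: ref[a]. When working with numpy, I try to avoid dicts as much as possible.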
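If the data already lives in a dictionary, a lookup array like ref can be built from it once and then reused. The following is a minimal sketch, assuming the keys are small non-negative integers (otherwise the array would be huge or the indexing invalid); filling unused slots with NaN is an illustrative choice.

```python
import numpy as np

b_dict = {1: 10., 2: 20., 3: 30.}

# Assumption: keys are small non-negative integers, so they can serve as indices.
keys = np.fromiter(b_dict.keys(), dtype=np.int64)
vals = np.fromiter(b_dict.values(), dtype=np.float64)

# Slots with no matching key stay NaN (here: index 0).
ref = np.full(keys.max() + 1, np.nan)
ref[keys] = vals

a = np.full(10, 2)
b = ref[a]  # plain NumPy integer indexing, no Python-level loop
print(b)
```

The one-time cost of building ref is paid back when the same mapping is applied to many or large key arrays.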
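I found that NumPy indexing is anywhere from several times to several orders of magnitude faster than going through a Python dict. Below is a script that runs such a comparison.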
import numpy as np
from operator import itemgetter
import timeit
import matplotlib.pyplot as plt
# @kmundnic solution
def m1(a):
    def get_b(x):
        b = {1: 10., 2: 20., 3: 30.}
        #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
        return b[x]
    # np.float was removed in NumPy 1.24; np.float64 is the equivalent dtype
    return np.fromiter(map(get_b, a), dtype=np.float64)

# @bigbounty solution
def m2(a):
    b = {1: 10., 2: 20., 3: 30.}
    #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
    return np.vectorize(b.get)(a)

# @Ehsan solution
def m3(a):
    b = {1: 10., 2: 20., 3: 30.}
    #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
    return np.array(itemgetter(*a)(b))

# @Sun Bear solution
def m4(a):
    def get_b(a):
        b = {1: 10., 2: 20., 3: 30.}
        #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
        return b[a]
    return np.array([get_b(i) for i in a])

# @hpaulj solution
def m5(a):
    # Index 0 is a placeholder so that index k holds the value for key k,
    # matching the dictionaries above (keys start at 1).
    b = np.array([0, 10, 20, 30])
    #b = np.arange(0,1001,10)
    return b[a]

sizes=[10,100,1000,10000]
pm1 = []
pm2 = []
pm3 = []
pm4 = []
pm5 = []
for size in sizes:
    a = np.full( size, 2 )
    pm1.append( timeit.timeit( 'm1(a)', number=1000, globals=globals() ) )
    pm2.append( timeit.timeit( 'm2(a)', number=1000, globals=globals() ) )
    pm3.append( timeit.timeit( 'm3(a)', number=1000, globals=globals() ) )
    pm4.append( timeit.timeit( 'm4(a)', number=1000, globals=globals() ) )
    pm5.append( timeit.timeit( 'm5(a)', number=1000, globals=globals() ) )
print( 'm1 slower than m5 by :',np.array(pm1) / np.array(pm5) )
print( 'm2 slower than m5 by :',np.array(pm2) / np.array(pm5) )
print( 'm3 slower than m5 by :',np.array(pm3) / np.array(pm5) )
print( 'm4 slower than m5 by :',np.array(pm4) / np.array(pm5) )
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot( sizes, pm1, label='m1' )
ax.plot( sizes, pm2, label='m2' )
ax.plot( sizes, pm3, label='m3' )
ax.plot( sizes, pm4, label='m4' )
ax.plot( sizes, pm5, label='m5' )
ax.grid( which='both' )
ax.set_xscale('log')
ax.set_yscale('log')
ax.legend()
ax.get_xaxis().set_label_text( label='len(a)', fontweight='bold' )
ax.get_yaxis().set_label_text( label='Runtime (sec)', fontweight='bold' )
plt.show()
Results:
len(b) = 3:
m1 slower than m5 by : [  4.22462367  29.79407905  85.03454097 339.2915358 ]
m2 slower than m5 by : [  8.64220685  11.57175871  13.76761749  46.1940683 ]
m3 slower than m5 by : [  3.25785432  21.63131578  54.71305704 220.15777696 ]
m4 slower than m5 by : [  4.60710166  30.93616607  91.8936744  371.00398273 ]
len(b) = 100:
m1 slower than m5 by : [  218.98603678  1976.50128737  9697.76615006 17742.79151719 ]
m2 slower than m5 by : [  41.76535891  53.85600913 109.35129345 164.13075291 ]
m3 slower than m5 by : [  24.82715462  36.77830986  87.56253196 141.04493237 ]
m4 slower than m5 by : [  222.04184193  2001.72120836  9775.22464369 18431.00155305 ]