2 回答

TA貢獻1831條經驗 獲得超9個贊
最簡單的方法是使用 numpy 的 masked_arrays 根據 allowed_categories 來屏蔽權重,然后查找argmax:
np.ma.masked_where(~np.isin(answers_category,categories_allowed1),answers_weight).argmax()
#2
另一種使用掩碼的方法(假設最大權重是唯一的):
mask = np.isin(answers_category, categories_allowed1)
np.argwhere(answers_weight==answers_weight[mask].max())[0,0]
#2

TA貢獻1795條經驗 獲得超7個贊
我也使用面膜解決了這個問題
inds = np.arange(res.shape[0])
# a mask is an array [False True False False True False]
mask = np.all(res[:,1][:,None] != categories_allowed1,axis=1)
allowed_inds = inds[mask]
# max_ind is not yet the real answer because the not allowed values are not taken into account
max_ind = np.argmax(res[:,0][mask])
real_ind = allowed_inds[max_ind]
添加回答
舉報