我有兩個張量,張量 a 和張量 b。我想獲取張量 b 中值的所有索引。例如。a = torch.Tensor([1,2,2,3,4,4,4,5])b = torch.Tensor([1,2,4])1, 2, 4我想要張量 a的索引。我可以通過以下代碼來做到這一點。a = torch.Tensor([1,2,2,3,4,4,4,5])b = torch.Tensor([1,2,4])mask = torch.zeros(a.shape).type(torch.bool)print(mask)for e in b: mask = mask + (a == e) print(mask)如果沒有 ,我該怎么做for?
2 回答

繁花不似錦
由于 PyTorch
TA貢獻1851條經驗 獲得超4個贊
由于 PyTorch1.10
和isin()
(isinf()
以及許多其他 numpy 等效項)也可用,因此您可以簡單地執行以下操作:
torch.isin(a,?b)
這會給你:
Out[4]:?tensor([?True,??True,??True,?False,??True,??True,??True,?False])
舊答案:
這是你想要的嗎?:
np.in1d(a.numpy(),?b.numpy())
將導致:
array([?True,??True,??True,?False,??True,??True,??True,?False])

拉風的咖菲貓
TA貢獻1995條經驗 獲得超2個贊
如果您只是不想使用 for 循環,則可以使用列表理解:
mask = [a[index] for index in b]
如果甚至不想使用“for”一詞,您可以隨時將張量轉換為 numpy 并使用 numpy 索引。
mask = torch.tensor(a.numpy()[b.numpy()])
更新
可能誤解了你的問題。在這種情況下,我想說實現這一點的最佳方法是通過列表理解。(切片可能無法實現這一點。
mask = [index for index,value in enumerate(a) if value in b.tolist()]
這會迭代 a 中的每個元素,獲取它們的索引和值,如果該值在 b 內,則獲取索引。
添加回答
舉報
0/150
提交
取消