2 回答

TA貢獻1795條經驗 獲得超7個贊
這里的關鍵特性是將張量的值lengths作為 的索引傳遞x。這里簡化的例子,我交換了容器的尺寸,所以 index dimenson 首先:
container = torch.arange(0, 50 )
container = f.reshape((5, 10))
>>>tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])
indices = torch.arange( 2, 7, dtype=torch.long )
>>>tensor([2, 3, 4, 5, 6])
print( container[ range( len(indices) ), indices] )
>>>tensor([ 2, 13, 24, 35, 46])
注意:我們從一行中得到一件事(range( len(indices) )產生連續的行號),列號由索引[ row_number ]

TA貢獻1827條經驗 獲得超9個贊
在被這種行為難住之后,我對此進行了更多挖掘,發現它與多維 NumPy 數組的索引行為一致。使這種違反直覺的原因是兩個數組必須具有相同的長度這一不太明顯的事實,即在這種情況下len(lengths)。
事實上,它的工作原理如下: *lengths確定您訪問第一個維度的順序。即,如果您有一個一維數組a = [0, 1, 2, ...., 500],并使用 list 訪問它b = [300, 200, 100],那么結果a[b] = [301, 201, 101](這也解釋了lengths - 1運算符,它只會導致訪問的值與分別在b、 或lengths中使用的索引相同)。*range(len(lengths))然后 * 只需選擇第 - 行i中的第 - 個元素i。如果您有一個方陣,您可以將其解釋為矩陣的對角線。由于您只能訪問前兩個維度上每個位置的單個元素,因此可以將其存儲在一個維度中(從而將您的 3D 張量減少到 2D)。后一個維度簡單地保持“原樣”。
如果你想玩這個,我強烈建議將range()值更改為更長/更短的值,這將導致以下錯誤:
IndexError:形狀不匹配:索引數組無法與形狀(x,)(y,)一起廣播
其中x和y是您的特定長度值。
要以長形式編寫此訪問方法以了解“幕后”發生的情況,還請考慮以下示例:
import torch
x = torch.randint(500, 50, 1)
lengths = torch.tensor([2, 30, 1, 4]) # random examples to explore
diag = list(range(len(lengths))) # [0, 1, 2, 3]
result = []
for i, row in enumerate(lengths):
temp_tensor = x[row, :, :] # temp_tensor.shape = [1, 50, 1]
temp_tensor = temp_tensor.squeeze(0)[diag[i]] # temp_tensor.shape = [1, 1]
result.append(temp.tensor)
# back to pytorch
result = torch.tensor(result)
result.shape # [4, 1]
添加回答
舉報