亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

如何索引具有形狀 (batch_size, 200, 256) 的張量以獲得

如何索引具有形狀 (batch_size, 200, 256) 的張量以獲得

HUX布斯 2022-06-22 18:27:00
我有形狀為 (batch_size, 200, 256) 的 LSTM 層的輸出,其中 200 是標記序列的長度,256 是 LSTM 輸出維度。我還有另一個形狀為 (batch_size) 的張量,它是我想從批次中的每個樣本序列中切出的標記的索引列表。如果令牌索引不是 -1,我將切出一個令牌向量表示(長度 = 256)。如果令牌索引為 -1,我將給出零向量(長度 = 256)。預期的輸出結果具有形狀 (batch_size, 1, 256)。我該怎么做?謝謝這是我到目前為止嘗試過的bidir = concatenate([forward, backward]) # shape = (batch_size, 200, 256) dropout = Dropout(params['dropout_rate'])(bidir)def slice_by_tensor(x):    matrix_to_slice = x[0]    index_tensor = x[1]    out_tensor = tf.where(index_tensor == -1,                           tf.zeros(tf.shape(tf.gather(matrix_to_slice,                                                       index_tensor, axis=1))),                           tf.gather(matrix_to_slice, index_tensor, axis=1))    return out_tensorrepresentation_stack0 = Lambda(lambda x: slice_by_tensor(x))([dropout,stack_idx0]) # stack_idx0 shape is (batch_size) # I got output with shape (batch_size, batch_size, 256) with this code
查看完整描述

1 回答

?
慕娘9325324

TA貢獻1783條經驗 獲得超4個贊

a=tf.reshape(tf.range(2*3*4),shape=(2,3,4))

#     [[[ 0,  1,  2,  3],

#        [ 4,  5,  6,  7],

#        [ 8,  9, 10, 11]],


#      [[12, 13, 14, 15],

#      [16, 17, 18, 19],

#       [20, 21, 22, 23]]]


b=tf.constant([-1,2]) 


aa=tf.pad(a,[[0,0],[1,0],[0,0]]) 


bb=b+1 


index=tf.stack([tf.range(tf.size(b)),bb],axis=-1) 

res=tf.expand_dims(tf.gather_nd(aa, index),axis=1)

#[[[ 0,  0,  0,  0]],

#[[20, 21, 22, 23]]]

當 index 為 -1 時,我們需要像張量這樣的零。所以我們可以先沿第二個軸填充原始張量。然后將索引增加 1。在此之后,使用tf.gather_nd將返回答案。


查看完整回答
反對 回復 2022-06-22
  • 1 回答
  • 0 關注
  • 114 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號