2 回答

TA貢獻1876條經驗 獲得超7個贊
gather函數從批處理(0th)軸返回提供的索引值。因此,它為我們提供了形狀為 (10, 10) 的批次中的第一個 (index:0) 和第二個 (index:1) 樣本 (形狀 (10,)) 的列表 (length=10) 而我們想要第一個批次中每個樣本的(索引:0)和第二(索引:1)特征點。為了解決這個問題,我們可以在使用gather函數之前轉置張量,以便gather函數選擇正確的值,最后生成的張量應該再次轉置。
_input = Input(shape=(5, 2))
x = Reshape((5 * 2,))(_input)
x_ = Lambda(lambda t: K.transpose(K.gather(K.transpose(t), [0, 1]*5)))(x)
輸出:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 5, 2)] 0
_________________________________________________________________
reshape (Reshape) (None, 10) 0
_________________________________________________________________
lambda (Lambda) (None, 10) 0
=================================================================

TA貢獻1812條經驗 獲得超5個贊
如果你使用tf.gather(),你可以避免使用@bit01 描述的轉置操作。中有一個axis論點tf.gather()。
_input = Input(shape=(5, 2))
x = Reshape((5 * 2,))(_input)
x_ = Lambda(lambda t: tf.gather(t, [0, 1]*5, axis=1))(x)
# Layer (type) Output Shape Param #
# =================================================================
# input_2 (InputLayer) (None, 5, 2) 0
# _________________________________________________________________
# reshape_2 (Reshape) (None, 10) 0
# _________________________________________________________________
# lambda_1 (Lambda) (None, 10) 0
# =================================================================
添加回答
舉報