亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

Keras:收集張量更改批量維度

Keras:收集張量更改批量維度

墨色風雨 2022-06-14 16:57:59
我有一個形狀為 (5, 2) 的輸入張量,代表 2D 空間中的五個點。我想取第一點,然后從所有五點中減去它。仔細閱讀,我想我可以用它K.gather來切片和重復第一層。在 Lambda 層中應用它后,批處理維度被覆蓋:_input = Input(shape=(5, 2))x = Reshape((5 * 2,))(_input)x_ = Lambda(lambda t: K.gather(t, [0, 1] * 5))(x)結果是:__________________________________________________________________________________________________Layer (type)                    Output Shape         Param #     Connected to                     ==================================================================================================input_1 (InputLayer)            (None, 5, 2)         0                                            __________________________________________________________________________________________________reshape_1 (Reshape)             (None, 10)           0           input_1[0][0]                    __________________________________________________________________________________________________lambda_1 (Lambda)               (10, 10)             0           reshape_1[0][0]                  __________________________________________________________________________________________________我究竟做錯了什么?另外,有沒有更簡單的方法來做到這一點?
查看完整描述

2 回答

?
幕布斯6054654

TA貢獻1876條經驗 獲得超7個贊

gather函數從批處理(0th)軸返回提供的索引值。因此,它為我們提供了形狀為 (10, 10) 的批次中的第一個 (index:0) 和第二個 (index:1) 樣本 (形狀 (10,)) 的列表 (length=10) 而我們想要第一個批次中每個樣本的(索引:0)和第二(索引:1)特征點。為了解決這個問題,我們可以在使用gather函數之前轉置張量,以便gather函數選擇正確的值,最后生成的張量應該再次轉置。


_input = Input(shape=(5, 2))

x = Reshape((5 * 2,))(_input)

x_ = Lambda(lambda t: K.transpose(K.gather(K.transpose(t), [0, 1]*5)))(x)

輸出:


_________________________________________________________________

Layer (type)                 Output Shape              Param #   

=================================================================

input_1 (InputLayer)         [(None, 5, 2)]            0         

_________________________________________________________________

reshape (Reshape)            (None, 10)                0         

_________________________________________________________________

lambda (Lambda)              (None, 10)                0         

=================================================================


查看完整回答
反對 回復 2022-06-14
?
慕雪6442864

TA貢獻1812條經驗 獲得超5個贊

如果你使用tf.gather(),你可以避免使用@bit01 描述的轉置操作。中有一個axis論點tf.gather()。


_input = Input(shape=(5, 2))

x = Reshape((5 * 2,))(_input)

x_ = Lambda(lambda t: tf.gather(t, [0, 1]*5, axis=1))(x)


# Layer (type)                 Output Shape              Param #   

# =================================================================

# input_2 (InputLayer)         (None, 5, 2)              0         

# _________________________________________________________________

# reshape_2 (Reshape)          (None, 10)                0         

# _________________________________________________________________

# lambda_1 (Lambda)            (None, 10)                0         

# =================================================================


查看完整回答
反對 回復 2022-06-14
  • 2 回答
  • 0 關注
  • 241 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號