亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

Keras 的 BatchNormalization 和 PyTorch 的 BatchNorm2d

Keras 的 BatchNormalization 和 PyTorch 的 BatchNorm2d

心有法竹 2022-07-19 10:35:24
我有一個在 Keras 和 PyTorch 中實現的示例微型 CNN。當我打印兩個網絡的摘要時,可訓練參數的總數相同,但參數總數和批量標準化的參數數不匹配。這是 Keras 中的 CNN 實現:inputs = Input(shape = (64, 64, 1)). # Channel Last: (NHWC)model = Conv2D(filters=32, kernel_size=(3, 3), padding='SAME', activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 1))(inputs)model = BatchNormalization(momentum=0.15, axis=-1)(model)model = Flatten()(model)dense = Dense(100, activation = "relu")(model)head_root = Dense(10, activation = 'softmax')(dense)以上模型打印的摘要是:Model: "model_8"_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================input_9 (InputLayer)         (None, 64, 64, 1)         0         _________________________________________________________________conv2d_10 (Conv2D)           (None, 64, 64, 32)        320       _________________________________________________________________batch_normalization_2 (Batch (None, 64, 64, 32)        128       _________________________________________________________________flatten_3 (Flatten)          (None, 131072)            0         _________________________________________________________________dense_11 (Dense)             (None, 100)               13107300  _________________________________________________________________dense_12 (Dense)             (None, 10)                1010      =================================================================Total params: 13,108,758Trainable params: 13,108,694Non-trainable params: 64_________________________________________________________________正如您在上面的結果中看到的,Keras 中的批量標準化比 PyTorch 具有更多的參數(準確地說是 2 倍)。那么上述 CNN 架構有什么區別呢?如果它們是等效的,那么我在這里缺少什么?
查看完整描述

1 回答

?
慕斯王

TA貢獻1864條經驗 獲得超2個贊

Keras 將許多將在層中“保存/加載”的東西視為參數(權重)。

雖然這兩種實現都自然具有批次的累積“均值”和“方差”,但這些值無法通過反向傳播進行訓練。

然而,這些值每批都會更新,Keras 將它們視為不可訓練的權重,而 PyTorch 只是將它們隱藏起來。這里的“不可訓練”一詞的意思是“不能通過反向傳播訓練”,但并不意味著這些值被凍結了。

總的來說,它們是BatchNormalization一層的 4 組“權重”??紤]到選定的軸(默認 = -1,層大小 = 32)

  • scale(32) - 可訓練

  • offset(32) - 可訓練

  • accumulated means(32) - 不可訓練,但每批更新

  • accumulated std (32) - 不可訓練,但每批更新

在 Keras 中這樣做的好處是,當您保存圖層時,您還可以保存均值和方差值,就像您自動保存圖層中的所有其他權重一樣。當你加載圖層時,這些權重會一起加載。


查看完整回答
反對 回復 2022-07-19
  • 1 回答
  • 0 關注
  • 215 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號