TensorFlow 中的回調函數
回調函數是 TensorFlow 訓練之中非常重要的一部分,我們在之前的學習之中或多或少地用到了回調函數。比如在之前的過擬合一節之中,我們就曾經用到了早停回調。那么這節課我們就來學習以下 TensorFlow 之中的回調函數。
1. 什么是回調函數
簡單來說,回調函數就是在訓練到一定階段的時候而執行的函數,我們最常采用的策略是每個Epoch結束之后執行一次回調函數。
回調函數的絕大多數 API 集中在 tf.keras.callbacks 之中,也就是說這是 Keras 之中的一個 API 。由于之前已經學習過早?;卣{,這節課我們來學習一下其他的幾個常用的回調:
- 模型保存回調:tf.keras.callbacks.ModelCheckpoint;
- 學習率回調;tf.keras.callbacks.LearningRateScheduler;
- 自定義回調:tf.keras.callbacks.CallBack。
對于回調的使用方法,也是非常簡單的,假設以下的數組之中定義了我們所需要的全部回調函數:
callbacks = [......]
那么我們在使用回調的時候,之中只需要在訓練函數中指定回調即可:
model.fit(..., ..., callbacks=callbacks)
對于要介紹的回調,我們會首先給出介紹,然后再在統一的代碼之中示例使用。
2. 模型保存回調
模型保存的回調函數為:
tf.keras.callbacks.ModelCheckpoint(
path, monitor='val_loss', verbose=0, save_best_only=False,
save_weights_only=False, save_freq='epoch')
這里只列出來了我們常用的參數,對于其中的每個參數,它們的作用如下:
- path: 保存模型的路徑;
- monitor: 用哪個指標來評價模型的好壞,默認是驗證集上的損失;
- verbose: 輸出日志的等級,只能為 0 或 1;
- save_best_only: 是否只保存最好的模型,模型的好壞由 monitor 指定;
- save_weights_only: 是否只保存權重,默認 False ,也就是保存整個模型;
- save_freq: 保存的頻率,可以為 ‘Epoch’ 或者一個整數,默認為每個 Epoch 保存一次模型;若是一個整數N,則是每訓練 N 個 Batch 保存一次模型。
3. 學習率回調
學習率回調函數為:
tf.keras.callbacks.LearningRateScheduler(
schedule, verbose=0
)
其中 verbose 參數仍然是日志輸出的等級,默認為 0 ;而 schedule 則是一個函數,用來定義一個學習率的變化。其中 schedule 函數的一個示例如下所示:
def my_schedule(epoch, lr):
if epoch < 20:
return lr
else:
return lr * 0.1
該學習率回調是在 20 個 Epoch 之前學習率保持不變,而在 20 個 Epoch 之后,每個 Epoch 學習率變為原來的 0.1 。
可以看出,該 schedule 函數由嚴格的形式,其中第一個參數為訓練的 Epoch ,第二個參數為當前的學習率。
4. 自定義回調
我們在使用回調的過程之中難免會遇到要自定義回調的情況,這時我們便需要編寫類來繼承 tf.keras.callbacks.CallBack 類,從而實現我們的自定義回調。
在自定義回調的過程之中,你可以覆寫不同的函數,從而可以實現在不同的時間來運行我們自定義的函數,這些函數包括:
- on_train_begin(self, logs=None): 在訓練開始時調用;
- on_test_begin(self, logs=None): 在測試開始時調用;
- on_predict_begin(self, logs=None): 在預測開始時調用;
- on_train_end(self, logs=None) 在訓練結束時調用;
- on_test_end(self, logs=None) 在測試結束時調用;
- on_predict_end(self, logs=None) 在預測結束時調用;
- on_train_batch_begin(self, batch, logs=None) 在訓練期間的每個批次之前調用;
- on_test_batch_begin(self, batch, logs=None) 在測試期間的每個批次之前調用;
- on_predict_batch_begin(self, batch, logs=None) 在預測期間的每個批次之前調用;
- on_train_batch_end(self, batch, logs=None) 在訓練期間的每個批次之后調用;
- on_test_batch_end(self, batch, logs=None) 在測試期間的每個批次之后調用;
- on_predict_batch_end(self, batch, logs=None) 在預測期間的每個批次之后調用;
- on_epoch_begin(self, epoch, logs=None) 在每次迭代訓練開始時調用;
- on_epoch_end(self, epoch, logs=None) 在每次迭代訓練結束時調用。
我們可以來使用其中兩個簡單的函數來做一個簡單的示例:
class MyCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, logs=None):
print("Start epoch {}.".format(epoch))
def on_train_begin(self, logs=None):
print("Starting training.")
這個樣子,我們便可以在每次訓練開始,以及每個 Epoch 開始之時進行輸出日志。
5. 程序示例
在這里,我們將同時使用模型保存回調、學習率回調以及自定義回調來做一個簡單的示例:
import tensorflow as tf
model = tf.keras.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
lr = 0.01
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
loss="mse"
)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
def my_schedule(epoch, lr):
print('Learning rate: ' + str(lr))
if epoch < 5:
return lr
else:
return lr * 0.1
lr_callback = tf.keras.callbacks.LearningRateScheduler(my_schedule)
save_model_callback = tf.keras.callbacks.ModelCheckpoint(
filepath='/model/', save_weights_only=True, verbose=1,
monitor='val_loss', mode='min', save_best_only=True)
class MyCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, logs=None):
print("Start epoch {}.".format(epoch))
def on_train_begin(self, logs=None):
print("Starting training.")
model.fit(x_train, y_train,
batch_size=64, epochs=10,
validation_data=(x_test, y_test),
callbacks=[MyCallback(), lr_callback, save_model_callback],
)
在這里,我們按照之前學習的方法定義了三個回調函數,分別是模型保存回調、學習率回調、以及自定義回調。其中模型保存回調會在每次訓練后保存模型、學習率回調會在第五個 Epoch 之后便每個 Epoch 變為原來的 0.1 ,而自定義回調會在訓練開始之前、每個 Epoch 開始之前輸出相應的信息。
于是我們可以得到輸出:
Starting training.
Start epoch 0.
Learning rate: 0.009999999776482582
Epoch 1/10
931/938 [============================>.] - ETA: 0s - loss: 556.1402
Epoch 00001: val_loss improved from inf to 15.96259, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 552.3954 - val_loss: 15.9626
Start epoch 1.
Learning rate: 0.009999999776482582
Epoch 2/10
927/938 [============================>.] - ETA: 0s - loss: 12.4227
Epoch 00002: val_loss improved from 15.96259 to 10.01533, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 12.3927 - val_loss: 10.0153
Start epoch 2.
Learning rate: 0.009999999776482582
Epoch 3/10
914/938 [============================>.] - ETA: 0s - loss: 9.0919
Epoch 00003: val_loss improved from 10.01533 to 8.50834, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 9.0744 - val_loss: 8.5083
Start epoch 3.
Learning rate: 0.009999999776482582
Epoch 4/10
913/938 [============================>.] - ETA: 0s - loss: 8.3514
Epoch 00004: val_loss improved from 8.50834 to 8.26637, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.3450 - val_loss: 8.2664
Start epoch 4.
Learning rate: 0.009999999776482582
Epoch 5/10
920/938 [============================>.] - ETA: 0s - loss: 8.2481
Epoch 00005: val_loss improved from 8.26637 to 8.25048, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2544 - val_loss: 8.2505
Start epoch 5.
Learning rate: 0.009999999776482582
Epoch 6/10
933/938 [============================>.] - ETA: 0s - loss: 8.2504
Epoch 00006: val_loss improved from 8.25048 to 8.25035, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2502 - val_loss: 8.2504
Start epoch 6.
Learning rate: 0.0009999999310821295
Epoch 7/10
932/938 [============================>.] - ETA: 0s - loss: 8.2509
Epoch 00007: val_loss improved from 8.25035 to 8.25034, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 7.
Learning rate: 9.99999901978299e-05
Epoch 8/10
916/938 [============================>.] - ETA: 0s - loss: 8.2600
Epoch 00008: val_loss improved from 8.25034 to 8.25034, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 8.
Learning rate: 9.99999883788405e-06
Epoch 9/10
914/938 [============================>.] - ETA: 0s - loss: 8.2541
Epoch 00009: val_loss did not improve from 8.25034
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 9.
Learning rate: 9.99999883788405e-07
Epoch 10/10
925/938 [============================>.] - ETA: 0s - loss: 8.2446
Epoch 00010: val_loss did not improve from 8.25034
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
<tensorflow.python.keras.callbacks.History at 0x7eff7317f748>
可以看到,我們的三個回調函數都能正確地輸出相應的信息,說明我們的回調函數已經成功生效。
6. 小結
在這節課之中,我們學習了什么是回調函數、模型保存回調、學習率回調以及如何自定義回調。同時我們又通過相應的示例演示了如何使用回調。