在 TensorFlow 之中進行遷移學習
在之前的學習之中,我們都是從定義模型開始,逐步的獲取數據并且對數據進行處理,最終訓練模型以達到一個良好的效果。這些任務都是從零開始訓練的例子,那么我們能不能使用別人已經訓練好的模型來幫助我們來進行相似的工作呢?答案是肯定的,這就是我們這節課要學習到的 “遷移學習”。
1. 什么是遷移學習
遷移學習,顧名思義,就是將學習任務遷移的意思。在實際的應用之中,我們遇到的好多學習任務都具有很強的相似性,比如圖片分割任務和圖片分類任務就很相似,因為他們都是對圖片進行處理的任務。
而對相似數據類型進行處理的任務的模型往往可以互相遷移使用,而不必重新訓練一個新的模型,從而節省時間和空間的開支。
在遷移學習的領域之中,圖片處理的任務往往占據大多數,因為圖片任務的處理往往都含有相似的部分 —— 提取特征。在實際的任務之中,我們往往會使用已經在大型數據集(比如 ImageNet )上訓練得到的模型作為遷移學習的基本模型,以此來提取圖片的特征,從而進行下一步的處理。
簡單來說就是:使用別人訓練好的模型來做自己的學習任務。
2. 遷移學習的基本思路
遷移學習是一個非常寬泛的概念,其的種類包括很多,我們這里以圖片任務為例來講解遷移學習的基本思路:
- 選擇遷移學習的基本模型,一般為在大型數據集上訓練的大型網絡,比如:
- ResNet 網絡;
- GoogLeNet 網絡;
- Xception 網絡;
- 然后選擇使用網絡的哪些部分,一般使用除了頂層的所有部分;
- 編寫剩余的部分,也就是自己接下來的處理過程;
- 訓練自己編寫的處理過程。
這幾個步驟看起來非常簡單,在實際過程之中也是非常簡單的,接下來我們就以在 ImageNet 超大數據集上訓練的 Xception 模型作為基本模型進行遷移學習的演示。
3. 使用遷移學習的實例
這次,我們依然使用貓狗分類的例子來進行實現,具體的代碼如下所示:
注意:部分代碼來自 TensorFlow 官方 API 。
import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np
train_data, validation_data = tfds.load(
"cats_vs_dogs",
split=["train[:80%]", "train[80%:]"],
as_supervised=True,
)
# 重新調整大小
train_data = train_data.map(lambda x, y: (tf.image.resize(x, (150, 150)), y))
validation_data = validation_data.map(lambda x, y: (tf.image.resize(x, (150, 150)), y))
# 分批次
train_data = train_data.batch(32)
validation_data = validation_data.batch(32)
# 遷移模型
base_model = tf.keras.applications.Xception(
weights="imagenet",
input_shape=(150, 150, 3),
include_top=False,
)
base_model.trainable = False
# 定義輸入
inputs = tf.keras.Input(shape=(150, 150, 3))
# 數據正則化
norm_layer = tf.keras.layers.experimental.preprocessing.Normalization()
x = norm_layer(inputs)
mean = np.array([127.5] * 3)
norm_layer.set_weights([mean, mean ** 2])
# 數據經過遷移模型
x = base_model(x, training=False)
# 數據經過自定義網絡
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
model.summary()
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.BinaryAccuracy()],
)
model.fit(train_ds, epochs=20, validation_data=validation_ds)
在這里的代碼之中,我們有幾處需要注意的地方:
- 在數據獲取方面,我們采用了 tfds.load 函數,該函數能夠直接獲取相應的內置數據集,同時進行相應的分割,這里我們按照 8:2 的比例來進行訓練集、測試集的劃分;
- 我們使用 map 函數,來將所有的數據的圖片重新調整至(150, 150)大小,我們將圖片調整至相同大小是為了方便后面的處理;
- 使用 tf.keras.applications.Xception API 來獲取已經預訓練的 Xception 模型,在該 API 之中,包含三個參數:
- weights:表示在哪個數據集上訓練;
- input_shape:表示輸入圖片的形狀;
- include_top=False:表示不含頂層網絡,因為我們要定義自己的網絡。
- 然后我們使用 base_model.trainable=False 語句來將基本模型的訓練參數凍結,這樣我們就不能訓練 Xception 的參數。
- 我們使用了 tf.keras.layers.experimental.preprocessing.Normalization 這個 API 來進行數據的正則化,我們需要通過 norm_layer.set_weights () 設定它的權重:
- 第一個參數是輸入的每個通道的平均值,這里是 255/2=127.5;
- 第二個參數是第一個參數的平方;
- 最后我們采用了一種新的定義模型的方式:先定義一個 Input ,然后將該 Input 逐次經過自己需要處理的網絡層得到 output,最后通過 tf.keras.Model (inputs, output) 來讓 TensorFlow s 根據數據的流動過程來自動生成網絡模型。
最終我們可以得到結果:
Model: "functional_5"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_10 (InputLayer) [(None, 150, 150, 3)] 0
_________________________________________________________________
normalization_3 (Normalizati (None, 150, 150, 3) 7
_________________________________________________________________
xception (Functional) (None, 5, 5, 2048) 20861480
_________________________________________________________________
global_average_pooling2d_2 ( (None, 2048) 0
_________________________________________________________________
dropout_2 (Dropout) (None, 2048) 0
_________________________________________________________________
dense_2 (Dense) (None, 1) 2049
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________
Epoch 1/20
291/291 [==============================] - 9s 31ms/step - loss: 0.1607 - binary_accuracy: 0.9313 - val_loss: 0.0872 - val_binary_accuracy: 0.9703
Epoch 2/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1181 - binary_accuracy: 0.9501 - val_loss: 0.0869 - val_binary_accuracy: 0.9690
......
Epoch 20/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0914 - binary_accuracy: 0.9841 - val_loss: 0.0875 - val_binary_accuracy: 0.9765
我們可以看到,我們的模型最終達到了 97% 的分類準確率,這是一個非常高的準確率,而這得益于 Xception 模型強大的特征提取能力。
4. 小結
在這節課之中,我們學習了什么是遷移學習,同時了解了遷移學習的一般思路,同時我們有手動實現了一個使用遷移學習進行分類的例子。在示例之中,我們學習到了一種新的模型定義方式。