在 TensorFlow 之中進行圖像分割
在之前的學習之中,對于圖像數據,我們進行過分類等一些常見的任務;這節課我們便來學習一下對于圖像數據的另外一種任務:圖像分割。
1. 什么是圖像分割
圖像分割,顧名思義,就是對圖像數據進行分割,而分類的物體一般是我們認為進行指定的。比如物品分割、人臉分割、醫學病灶分割等。
舉個例子,如下圖所示,原來的圖像是一個馬路的圖片,通過圖像分割,我們會按照不同的物體進行不同的分割,比如車分為一類、人分為一類、建筑分為一類、馬路分為一類等。
圖像分割是很多任務的前提,有很多的任務只有進行了有效的分割之后才能進行有效的處理,比如:
- 醫學病灶識別;
- 人臉情緒識別;
- 路況檢測;
- 自動駕駛;
- 等等。
2. 如何進行圖像分割
圖像分割看上去是一個很復雜的任務,但是實現起來的原理卻是非常簡單,具體來說分為以下幾步:
- 確定要分類的類別,比如,我們可以將圖片中所有的物體分割為 10 類,包括車、人等;
- 對于每個像素點進行數字分類,數字分類的類別數量對應于上述的類別,這里是 10 ;
- 將每個數字類別對應于分類的類別,比如 0 代表車、1 代表人。
可以看出,圖像分割任務其實就是一個分類任務,只不過是對于每個像素點進行分類,也就是確定每個像素點所對應的類別。
在這節課之中,我們會使用圖像分割的基礎數據集:oxford_iiit_pet 圖像分割數據集來進行演示。與此同時,我們也會采用之前學習到的遷移學習的方式來進行模型的構建,從而完成圖像分割的任務。
3. 使用 TensorFlow 進行圖像分割的程序示例
在 oxford_iiit_pet 之中,所有的圖片都是寵物,我們的任務是將圖片中的寵物分割出來,所有的像素點都被分為三類:
- 1: 對應于寵物的一部分;
- 2: 對應于寵物的邊界;
- 3: 不屬于寵物的一部分。
在這里,我們使用代碼有一部分來自 TensorFlow 官方的一個例子,這個例子非常的簡單易懂,作為圖像分割任務的入門是再適合不過的了。
我們會逐步進行代碼的解釋與理解,從而幫助大家學習圖像分割的任務的特點。
1. 首先我們獲取數據集
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
這里會下載數據集,因為是圖片數據,因此數據集相對比較大。
2. 定義歸一化處理函數
def normalize(input_image, input_mask):
input_image = tf.cast(input_image, tf.float32) / 255.0
return input_image, input_mask
它接收兩個參數,第一個參數是圖片,我們會將其歸一化到 [0, 1] ,第二個參數是圖像的標簽。
3. 構建數據集
def load_image_train(data):
input_image = tf.image.resize(data['image'], (128, 128))
input_mask = tf.image.resize(data['segmentation_mask'], (128, 128))
input_image, input_mask = normalize(input_image, input_mask)
return input_image, input_mask
def load_image_test(data):
input_image = tf.image.resize(data['image'], (128, 128))
input_mask = tf.image.resize(data['segmentation_mask'], (128, 128))
input_image, input_mask = normalize(input_image, input_mask)
return input_image, input_mask
num_examples = info.splits['train'].num_examples
BATCH = 64
step_per_epch = num_examples // BATCH
train = dataset['train'].map(load_image_train)
test = dataset['test'].map(load_image_test)
train_dataset = train.cache().shuffle(1000).batch(BATCH).repeat()
test_dataset = test.batch(BATCH)
在構建數據集函數之中,我們做了兩件事情:
- 將圖像與標簽重新調整大小到 [128, 128] ;
- 將數據歸一化。
然后我們進行了分批的處理,這里取批次的大小為 64 ,大家可以根據自己的內存或現存大小靈活調整。
4. 構建網絡模型
output_channels = 3
# 獲取基礎模型
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)
# 定義要使用其輸出的基礎模型網絡層
layer_names = [
'block_1_expand_relu', # 64x64
'block_3_expand_relu', # 32x32
'block_6_expand_relu', # 16x16
'block_13_expand_relu', # 8x8
'block_16_project', # 4x4
]
layers = [base_model.get_layer(name).output for name in layer_names]
# 創建特征提取模型
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)
down_stack.trainable = False
# 進行降頻采樣
up_stack = [
pix2pix.upsample(512, 3), # 4x4 -> 8x8
pix2pix.upsample(256, 3), # 8x8 -> 16x16
pix2pix.upsample(128, 3), # 16x16 -> 32x32
pix2pix.upsample(64, 3), # 32x32 -> 64x64
]
# 定義UNet網絡模型
def unet_model(output_channels):
inputs = tf.keras.layers.Input(shape=[128, 128, 3])
x = inputs
# 在模型中降頻取樣
skips = down_stack(x)
x = skips[-1]
skips = reversed(skips[:-1])
# 升頻取樣然后建立跳躍連接
for up, skip in zip(up_stack, skips):
x = up(x)
concat = tf.keras.layers.Concatenate()
x = concat([x, skip])
# 這是模型的最后一層
last = tf.keras.layers.Conv2DTranspose(
output_channels, 3, strides=2,
padding='same') #64x64 -> 128x128
x = last(x)
return tf.keras.Model(inputs=inputs, outputs=x)
model = unet_model(output_channels)
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
在這里,我們首先得到了一個預訓練的 MobileNetV2 用于特征提取,在這里我們并沒有包含它的輸出層,因為我們要根據自己的任務靈活調節。
然后定義了我們要使用的 MobileNetV2 的網絡層的輸出,我們使用這些輸出來作為我們提取的特征。
然后我們定義了我們的網絡模型,這個模型的理解有些困難,大家可能不用詳細了解網絡的具體原理。大家只需要知道,這個網絡大致經過的步驟包括:
- 先將數據壓縮(便于數據的處理);
- 然后進行數據的處理;
- 最后將數據解壓返回到原來的大小,從而完成網絡的任務。
最后我們編譯該模型,我們使用 adam 優化器,交叉熵損失函數(因為圖像分割是個分類任務)。
5. 模型的訓練
epoch = 20
valid_steps = info.splits['test'].num_examples//BATCH
model_history = model.fit(train_dataset, epochs=epoch,
steps_per_epoch=step_per_epch,
validation_steps=valid_steps,
validation_data=test_dataset)
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']
這邊就是一個簡單的訓練過程,我們可以得到如下輸出:
Epoch 1/20
57/57 [==============================] - 296s 5s/step - loss: 0.4928 - accuracy: 0.7995 - val_loss: 0.6747 - val_accuracy: 0.7758
......
Epoch 20/20
57/57 [==============================] - 276s 5s/step - loss: 0.2586 - accuracy: 0.9218 - val_loss: 0.2821 - val_accuracy: 0.9148
我們可以看到我們最后達到了 91% 的準確率,還是一個可以接受的結果。
感興趣的同學可以嘗試一下進行結果的可視化,從而更加直觀的查看到結果。
4. 小結
在這節課之中,我們學習了什么是圖像分割,同時了解了圖像分割的簡單的實現方式,最終我們通過一個示例來了解了如何在 TensorFlow 之中進行圖像分割。