使用圖像數據來訓練模型
在之前的學習中,我們曾經學習過使用 Keras 進行圖片分類。具體來說,我們學習了:
- 將二位圖片數據進行扁平化處理;
- 將圖片數據使用卷積神經網絡進行處理。
然而在實際的機器學習之中,當我們使用圖片數據來訓練模型的時候,我們會用到更多的操作。因此在這節課之中我們便整體地了解一下如何使用圖像數據來構建數據集。
在實際的應用過程中,我們最常用的圖片數據加載方式一共有三種,因此這節課我們主要學習這三種主要地圖片加載方式:
- 使用 TFRecord 構建圖片數據集;
- 使用 tf.keras.preprocessing.image.ImageDataGenerator 構建圖片數據集;
- 使用 tf.data.Dataset 原生方法構建數據集。
在這節課之中,我們使用之前用過的貓狗分類的數據集之中的貓的訓練集的圖片進行測試,具體來說,我們可以通過以下代碼準備具體的數據集:
import tensorflow as tf
import os
dataset_url = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_download = os.path.dirname(tf.keras.utils.get_file('cats_and_dogs.zip', origin=dataset_url, extract=True))
cat_train_dir = path_download + '/cats_and_dogs_filtered/train/cats'
這樣,cat_train_dir 就是我們要測試的圖片的路徑。
1. 使用TFRecord構建圖片數據集
TFRecord 是一種二進制的數據文件,也正是因為 TFRecord 是一種二進制的數據文件,因此他的讀寫速度較快,同時也不會產生編碼錯誤之類的問題。
使用 TFRecord 主要包括兩個步驟:
- 生成 TFRecord 文件并進行存儲;
- 讀取 TFRecord 文件,并用于訓練。
1. 生成 TFRecord 文件并進行存儲
既然我們已經獲得了圖片文件所在的目錄,那么我們便可以生成 TFRecord 文件:
from PIL import Image
# 打開TFRecord文件
writer = tf.io.TFRecordWriter('./cat_data')
for img_path in os.listdir(cat_train_dir):
# 讀取并將圖片Resize
img = os.path.join(cat_train_dir, img_path)
img = Image.open(img)
img = img.convert('RGB').resize((32,32)).tobytes()
# 定義標簽,假設貓的標簽是0
label = 0 # 0:cat, 1:dog
# 構建一條數據
example = tf.train.Example(
features = tf.train.Features(
feature = {
'label': tf.train.Feature(int64_list=tf.train.Int64List (value=[int(label)])),
'data' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[img]))
}
)
)
# 將數據寫入
writer.write(example.SerializeToString())
writer.close()
如上述代碼所示,我們首先需要打開 TFRecord 文件,然后再保存結束時再將其關閉。
其次我們首先使用讀取了圖片文件,然后將其進行了以下處理:
- 轉化為 RGB 模式;
- Resize 到 (32,32 )大小;
- 轉化為二進制字節數據。
最后我們使用 tf.train.Example 函數將每一條數據按照 label 和 data 的形式進行封裝,并寫入到 TFRecord 文件之中。
2. 讀取 TFRecord 文件
在讀取的時候,我們會將 TFRecord 文件讀入到內存之中,并且轉化為 tf.data.Dataset ,以便日后使用。
cat_reader = tf.data.TFRecordDataset('./cat_data')
def decode_image(example):
# 加載單條數據
single_example = tf.io.parse_single_example(
example,
{
'data' : tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64)
}
)
img = single_example['data']
label = single_example['label']
# 圖片處理
img = tf.io.decode_raw(img, tf.uint8)
img = tf.reshape(img, [32, 32, 3])
return (img, label)
# 映射并分批次
cat_dataset = cat_reader.map(decode_image).batch(32)
print(cat_dataset)
這其中有幾點需要注意:
- 首先我們需要根據存儲的路徑來載入 TFRecord ;
- 我們需要使用一個函數來處理每一條數據,這個函數可以通過 cat_reader.map() 來調用;
- 在 decode_image 之中:
- tf.io.parse_single_example 函數用于加載每一條數據,它接收兩個參數,第一個是當前數據,第二個是數據的格式;
- 我們又采用了 tf.io.decode_raw 函數來對圖片進行了解碼,將其轉化為數字類型。
- 最后我們將圖片數據分批次,大小為32 。
于是我們可以得到輸出為:
<BatchDataset shapes: ((None, 32, 32, 3), (None,)), types: (tf.uint8, tf.int64)>
由此可見,我們正確地加載了該數據集。
2.使用 tf.keras.preprocessing.image.ImageDataGenerator 構建圖片數據集
使用這種方式會非常簡單,我們只需要一條語句即可實現:
cat_generator = tf.keras.preprocessing.image.ImageDataGenerator().flow_from_directory(
directory=path_download + '/cats_and_dogs_filtered/train',
target_size=(32, 32),
batch_size=32,
shuffle = True,
class_mode='binary')
print(cat_generator)
我們可以得到如下輸出:
Found 2000 images belonging to 2 classes.
<tensorflow.python.keras.preprocessing.image.DirectoryIterator object at 0x7f28d0c4a048>
在使用的過程中, directory 參數需要我們注意,該路徑應該是圖片路徑之外的一層路徑。
也就是說,如果圖片路徑為“/a/b/c.jpg”,那么我們要傳入的路徑應該是“/a”。
其余的參數為:
- target_size: 圖片的大小;
- batch_size: 批次大小;
- shuffle: 是否亂序;
- class_modle: 若是binary則為二分類,multi則為多分類。
由于我們得到的數據集是一個迭代器,因此我們不能使用常用的 fit 方式來訓練,我們可以通過以下方式進行訓練:
model.fit_generator(cat_generator)
3. 使用 tf.data.Dataset 原生方法構建數據集
使用這種方法也非常簡單,我們需要兩個步驟來進行數據集的構建:
- 定義圖片加載函數;
- 使用 tf.data.Dataset 構建數據集。
于是我們可以使用如下代碼進行數據集的構建:
def load_image(img_path):
label = tf.constant(0,tf.int8)
img = tf.io.read_file(img_path)
img = tf.image.decode_jpeg(img)
img = tf.image.resize(img, (32, 32))
return (img,label)
cat_dataset = tf.data.Dataset.list_files(cat_train_dir).map(load_image).batch(32)
print(cat_dataset)
在這段程序中,我們首先在載入圖片函數中進行了如下處理:
- 定義標簽,因為全部是貓,因此我們設置為 0 ;
- 使用 tf.io.read_file 讀取文件;
- 因為我們的圖片都是 jpeg 格式,因此我們使用 tf.image.decode_jpeg 來解碼圖片;
- 最后使用 tf.image.resize 來對圖片進行尺寸調整,統一為(32, 32)。
然后我們使用 tf.data.Dataset.list_files() 函數構建了數據集,它接收的第一個參數就是圖片所在的文件夾。
我們可以得到輸出:
<BatchDataset shapes: ((None, 32, 32, None), (None,)), types: (tf.float32, tf.int8)>
可見我們已經成功地構建了數據集。
4. 小結
在這節課之中,我們學習了三種圖片數據加載的方式,他們分別是:
- 使用 TFRecord 構建圖片數據集;
- 使用 tf.keras.preprocessing.image.ImageDataGenerator 構建圖片數據集;
- 使用 tf.data.Dataset 原生方法構建數據集。
其中第一種方式最為快速,而第二種方式更為方便,我們可以根據自己的實際需求來進行選擇。