亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

如何使用 tf.keras.utils.Sequence API 擴充訓練集?

如何使用 tf.keras.utils.Sequence API 擴充訓練集?

喵喔喔 2023-07-27 16:31:06
TensorFlow 文檔有以下示例,可以說明當訓練集太大而無法放入內存時,如何創建批量生成器以將訓練集批量提供給模型:from skimage.io import imreadfrom skimage.transform import resizeimport tensorflow as tfimport numpy as npimport math# Here, `x_set` is list of path to the images# and `y_set` are the associated classes.class CIFAR10Sequence(tf.keras.utils.Sequence):    def __init__(self, x_set, y_set, batch_size):        self.x, self.y = x_set, y_set        self.batch_size = batch_size    def __len__(self):        return math.ceil(len(self.x) / self.batch_size)    def __getitem__(self, idx):        batch_x = self.x[idx * self.batch_size:(idx + 1) *        self.batch_size]        batch_y = self.y[idx * self.batch_size:(idx + 1) *        self.batch_size]        return np.array([            resize(imread(file_name), (200, 200))               for file_name in batch_x]), np.array(batch_y)我的目的是通過將每個圖像旋轉 3 倍 90° 來進一步增加訓練集的多樣性。在訓練過程的每個 Epoch 中,模型將首先輸入“0° 訓練集”,然后分別輸入 90°、180° 和 270° 旋轉集。如何修改前面的代碼以在CIFAR10Sequence()數據生成器中執行此操作?請不要使用tf.keras.preprocessing.image.ImageDataGenerator(),以免答案失去對其他類型不同性質的類似問題的普遍性。注意:這個想法是在模型被輸入時“實時”創建新數據,而不是(提前)創建并在磁盤上存儲一個比稍后使用的原始訓練集更大的新的增強訓練集(也在批次)在模型的訓練過程中。
查看完整描述

1 回答

?
米脂

TA貢獻1836條經驗 獲得超3個贊

使用自定義Callback并掛鉤到on_epoch_end. 每個紀元結束后更改數據迭代器對象的角度。


示例(內聯記錄)

from skimage.io import imread

from skimage.transform import resize, rotate

import numpy as np


import tensorflow as tf

from tensorflow import keras

from tensorflow.keras import layers

from keras.utils import Sequence 

from keras.models import Sequential

from keras.layers import Conv2D, Activation, Flatten, Dense


# Model architecture  (dummy)

model = Sequential()

model.add(Conv2D(32, (3, 3), input_shape=(15, 15, 4)))

model.add(Activation('relu'))

model.add(Flatten())

model.add(Dense(1))

model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',

              optimizer='rmsprop',

              metrics=['accuracy'])


# Data iterator 

class CIFAR10Sequence(Sequence):

    def __init__(self, filenames, labels, batch_size):

        self.filenames, self.labels = filenames, labels

        self.batch_size = batch_size

        self.angles = [0,90,180,270]

        self.current_angle_idx = 0


    # Method to loop throught the available angles

    def change_angle(self):

      self.current_angle_idx += 1

      if self.current_angle_idx >= len(self.angles):

        self.current_angle_idx = 0

  

    def __len__(self):

        return int(np.ceil(len(self.filenames) / float(self.batch_size)))


    # read, resize and rotate the image and return a batch of images

    def __getitem__(self, idx):

        angle = self.angles[self.current_angle_idx]

        print (f"Rotating Angle: {angle}")


        batch_x = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]

        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]

        return np.array([

            rotate(resize(imread(filename), (15, 15)), angle)

               for filename in batch_x]), np.array(batch_y)


# Custom call back to hook into on epoch end

class CustomCallback(keras.callbacks.Callback):

    def __init__(self, sequence):

      self.sequence = sequence


    # after end of each epoch change the rotation for next epoch

    def on_epoch_end(self, epoch, logs=None):

      self.sequence.change_angle()               



# Create data reader

sequence = CIFAR10Sequence(["f1.PNG"]*10, [0, 1]*5, 8)

# fit the model and hook in the custom call back

model.fit(sequence, epochs=10, callbacks=[CustomCallback(sequence)])

輸出:


Rotating Angle: 0

Epoch 1/10

Rotating Angle: 0

Rotating Angle: 0

2/2 [==============================] - 2s 755ms/step - loss: 1.0153 - accuracy: 0.5000

Epoch 2/10

Rotating Angle: 90

Rotating Angle: 90

2/2 [==============================] - 0s 190ms/step - loss: 0.6975 - accuracy: 0.5000

Epoch 3/10

Rotating Angle: 180

Rotating Angle: 180

2/2 [==============================] - 2s 772ms/step - loss: 0.6931 - accuracy: 0.5000

Epoch 4/10

Rotating Angle: 270

Rotating Angle: 270

2/2 [==============================] - 0s 197ms/step - loss: 0.6931 - accuracy: 0.5000

Epoch 5/10

Rotating Angle: 0

Rotating Angle: 0

2/2 [==============================] - 0s 189ms/step - loss: 0.6931 - accuracy: 0.5000

Epoch 6/10

Rotating Angle: 90

Rotating Angle: 90

2/2 [==============================] - 2s 757ms/step - loss: 0.6932 - accuracy: 0.5000

Epoch 7/10

Rotating Angle: 180

Rotating Angle: 180

2/2 [==============================] - 2s 757ms/step - loss: 0.6931 - accuracy: 0.5000

Epoch 8/10

Rotating Angle: 270

Rotating Angle: 270

2/2 [==============================] - 2s 761ms/step - loss: 0.6932 - accuracy: 0.5000

Epoch 9/10

Rotating Angle: 0

Rotating Angle: 0

2/2 [==============================] - 1s 744ms/step - loss: 0.6932 - accuracy: 0.5000

Epoch 10/10

Rotating Angle: 90

Rotating Angle: 90

2/2 [==============================] - 0s 192ms/step - loss: 0.6931 - accuracy: 0.5000

<tensorflow.python.keras.callbacks.History at 0x7fcbdf8bcdd8>



查看完整回答
反對 回復 2023-07-27
  • 1 回答
  • 0 關注
  • 164 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號