2 回答

TA貢獻1797條經驗 獲得超6個贊
當我首先初始化模型并將其作為額外參數添加到回調方法時,它會起作用。所以解決辦法如下:
class LossCallback(tf.keras.callbacks.Callback):
def __init__(self, model):
super(LossCallback, self).__init__()
model.beta_x = tf.Variable(1.0, trainable=False, name='weight1', dtype=tf.float32)
def on_epoch_begin(self, epoch, logs=None):
tf.keras.backend.set_value(self.model.beta_x, tf.constant(0.5) * epoch)
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
logs['beta_x'] = tf.keras.backend.get_value(self.model.beta_x)
model = create_model() # initialize custom keras model
callback = LossCallback(model)
model.fit(..., callbacks=[callback])

TA貢獻1876條經驗 獲得超7個贊
避免直接編輯變量。您必須像這樣訪問 keras 變量
import tensorflow as tf
from tensorflow import keras
import numpy as np
def warm_up(epoch, logs):
val = keras.backend.get_value(model.optimizer.lr)
val *= 1.1
tf.keras.backend.set_value(model.optimizer.lr, val)
callback = tf.keras.callbacks.LambdaCallback(on_epoch_begin=warm_up)
model = tf.keras.models.Sequential([
keras.layers.Dense(10, 'relu'),
keras.layers.Dense(1, 'sigmoid')
])
model.compile(loss='binary_crossentropy')
X_train = tf.random.uniform((10,10))
y_train = tf.ones((10,))
model.fit(X_train, y_train,
callbacks = [callback])
請注意我如何獲取當前值,例如val = keras.backend.get_value(model.optimizer.lr)。這是在運行時獲取正確值的正確方法。另外,不要在循環內使用或聲明新變量。您可能可以new_value通過閱讀和更改舊的內容來獲得新的內容。另外,請避免在回調內部使用除 Tensorflow 之外的任何其他庫,尤其是當您的回調經常被調用時。不要使用numpy,使用tensorflow。實際上總有一種張量流操作可以滿足您的需要。
編輯:如果您有一些自定義值要更新,您可以使用如下模式:
class LossCallback(tf.keras.callbacks.Callback):
def __init__(self):
super(LossCallback, self).__init__()
self.someValue = tf.Variable(1.0, trainable=False, name='weight1', dtype=tf.float32)
def on_epoch_end(self, epoch, logs=None):
tf.keras.backend.set_value(self.model.loss.someValue, self.someValue * epoch)
或者您仍然可以嘗試使用 lambda 回調。
從回調中,您可以訪問模型的任何變量。像這樣self.model.someVariable。您還可以訪問模型自定義__init__函數中定義的任何自定義變量,如下所示:
#in model's custom __init__
def __init__(self, someArgs):
...
self.someArg = someArgs
...
#in callback's "on_epoch_..." method
...
keras.backend.set_value(self.model.someArg, 42)
...
請注意,您不能self.model在回調__init____init__函數中使用,因為調用回調時模型仍未初始化。
這有幫助嗎?
添加回答
舉報