1 回答

TA貢獻1802條經驗 獲得超6個贊
在不了解模型詳細信息的情況下,以下代碼段可能會有所幫助:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input
# Train your initial model
def get_initial_model():
...
return model
model = get_initial_model()
model.fit(...)
model.save_weights('initial_model_weights.h5')
# Use Model API to create another model, built on your initial model
initial_model = get_initial_model()
initial_model.load_weights('initial_model_weights.h5')
nn_input = Input(...)
x = initial_model(nn_input)
x = Dense(...)(x) # This is the additional layer, connected to your initial model
nn_output = Dense(...)(x)
# Combine your model
full_model = Model(inputs=nn_input, outputs=nn_output)
# Compile and train as usual
full_model.compile(...)
full_model.fit(...)
基本上,你訓練你的初始模型,保存它。然后再次重新加載它,并使用 API 將其與其他層包裝在一起。如果您不熟悉API,可以在此處查看Keras文檔(afaik API對于Tensorflow.Keras 2.0保持不變)。ModelModel
請注意,您需要檢查初始模型的最終層的輸出形狀是否與其他層兼容(例如,如果您只是執行特征提取,則可能需要從初始模型中刪除最終的密集層)。
添加回答
舉報