亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

如何使用經過訓練和保存的前饋神經網絡來預測新數據

如何使用經過訓練和保存的前饋神經網絡來預測新數據

慕俠2389804 2023-09-12 17:26:59
我正在嘗試使用經過訓練和保存的模型對新數據進行預測。我的新數據與用于構建已保存模型的數據形狀不同。我嘗試過使用 model.save() 和 model.save_weights(),因為我仍然想保留訓練配置,但它們都會產生相同的錯誤。即使形狀不同,有沒有辦法在新數據上使用保存的模型?from tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Activation, Densemodel = Sequential([    Dense(units=11, activation='relu', input_shape = (42,), kernel_regularizer=keras.regularizers.l2(0.001)),    Dense(units=1, activation='sigmoid')])new_model.load_weights('Fin_weights.h5')y_pred = new_model.predict(X)ValueError: Error when checking input: expected dense_6_input to have shape (44,) but got array with shape (42,)
查看完整描述

1 回答

?
繁星淼淼

TA貢獻1775條經驗 獲得超11個贊

不,您必須完全匹配相同的輸入形狀。


您的模型代碼(model = Sequential([...行)應與您保存的模型完全對應,并且您的輸入數據(X行y_pred = new_model.predict(X))應與保存的模型()中的形狀相同'Fin_weights.h5'。


您唯一能做的就是以某種方式用例如零填充新數據。但只有當其余值對應相同的特征或信號時,這才有幫助。


例如,假設您正在訓練 NN 來識別形狀 (2, 3) 的灰度圖像,如下所示:


1 2 3

4 5 6

然后您訓練了模型并將其保存以供以后使用。之后,您決定將神經網絡用于更小或更大尺寸的圖像,如下所示


1 2

3 4

或這個


1? 2? 3? 4

5? 6? 7? 8

9 10 11 12

而且您幾乎可以肯定,您的神經網絡仍然會在不同形狀的輸入上給出良好的結果。


然后,您只需在右側填充第一個不匹配的圖像,并在右側添加額外的零,如下所示:


1 2 0

3 4 0

或另一種填充方式,在左側


0 1 2

0 3 4

第二張圖片你剪了一點


1? 2? 3

5? 6? 7

(或從另一邊剪掉)。


只有這樣,您才能將神經網絡應用到經過處理的輸入圖像。


與您的情況相同,您必須添加兩個零。但僅限于編碼輸入信號或特征的序列幾乎相同的情況。


如果您的預測數據大小錯誤,請執行以下操作:


y_pred = new_model.predict(

? ? np.pad(X, ((0, 0), (0, 2)))

)

這會在右側填充您的數據的兩個零,盡管您可能希望將其填充在左側((2, 0)而不是(0, 2))或兩側((1, 1)而不是(0, 2))。


如果您保存的權重具有不同的形狀,則模型的代碼會在模型代碼中執行此操作(更改42 --> 44):


model = Sequential([

? ? Dense(units=11, activation='relu', input_shape = (44,), kernel_regularizer=keras.regularizers.l2(0.001)),

? ? Dense(units=1, activation='sigmoid')

])

您可能應該執行上述兩件事,以匹配您保存的模型/權重。


如果針對44數字輸入訓練的神經網絡對于任何數據填充都會給出完全錯誤的結果42,那么唯一的方法是重新訓練神經網絡的42輸入并再次保存模型。


但是你必須考慮到這樣一個事實,input_shape = (44,)在 keras 庫中實際上意味著X輸入的最終數據model.predict(X)應該是二維形狀(10, 44)(其中 10 是你的神經網絡要識別的不同對象的數量),keras 隱藏第 0 維,即所謂的批量維度。批次(第 0 個)維度實際上可能會有所不同,您可以提供 5 個對象(即 shape 數組(5, 44))或 7 個對象(形狀 (7, 44))或任何其他數量的對象。批處理僅意味著 keras 在一次調用中并行處理多個對象,只是為了快速/高效。但每個單個對象都是形狀的一維子數組(44,)。您可能誤解了數據如何輸入網絡并表示的一些內容。44 不是數據集的大小(對象數量),而是單個對象的特征數量,例如,如果網絡識別/分類一個人,那么 44 可以表示一個人的 44 個特征,例如年齡、性別、身高、體重,出生月份,種族,膚色,每天卡路里,月收入,月支出,工資等總共1個人類對象的44個不同的固定特征。他們可能不會改變。但是,如果您獲得了一些僅具有42或36特征的其他數據,而您只需要0精確地將其放置在 中缺少的特征位置44,則在右側或左側填充零是不正確的,您必須放置0s 正好位于 中缺失的那些位置44。


但是你的 44 和 42 和 36 可能意味著不同輸入對象的數量,每個對象都只有1特征。想象一個任務,當你有一個50只有兩列數據的人類數據集(表) salary,country然后你可能想要構建猜測country到salary那時你將擁有的神經網絡input_shape = (1,)(對應于 1 個數字的一維數組 - salary),但絕對是不是input_shape = (50,)(表中的人數)。input_shape只講述 1 個物體、1 個人的形狀。50 是對象(人類)的數量,它是 numpy 數組中用于預測的批量(第 0 個)維度,因此您的X數組model.predict(X)的形狀為(50, 1),但input_shape = (1,)在模型中?;旧?keras 省略(隱藏)第 0 個批次維度。如果44在您的情況下,實際上意味著數據集大?。▽ο髷盗浚敲茨e誤地訓練了 NN,應該使用 ,input_shape = (1,)作為44批量維度進行重新訓練,這44可能會根據訓練或測試數據集的大小而有所不同。

如果您要重新訓練您的網絡,那么整個訓練/評估過程的簡單形式如下:

  1. 假設您有一個 CSV 文件格式的數據集data.csv。例如,總共有 126 行和 17 列。

  2. 以某種方式讀入您的數據,例如通過np.loadtxt或通過pd.read_csv或通過標準 python 的csv.reader()。將數據轉換為數字(浮點數)。

  3. 按行將數據隨機分成兩部分training/evaluation大約相應的大小90%/10%行,例如 110 行用于訓練,16 行用于評估(總共 126 行)。

  4. 決定將預測數據中的哪些列,您可以預測任意數量的列,假設我們要預測兩列,即第 16 列和第 17 列?,F在,您的數據列被分為兩部分X(15 列,編號為 1-15)和Y(2 列,編號為 16-17)。

  5. 在網絡層的代碼中,第一層設置input_shape = (15,)(15 是 中的列數X),Dense(2)最后一層(2 是 中的列數Y)。

  6. 使用model.fit(X, Y, epochs = 1000, ...)方法在訓練數據集上訓練網絡。

  7. 將經過訓練的網絡保存到模型文件model.save(...)net.h5.

  8. 通過 加載您的網絡model.load(...)。

  9. 通過 測試網絡質量predicted_Y = model.predict(testing_X),與 進行比較testing_Y,如果網絡模型選擇正確,則testing_Y應該接近predicted_Y,例如80%正確(這個比率稱為準確性)。

  10. 為什么我們將數據集分成訓練/測試部分。因為訓練階段只看到訓練數據集子部分。網絡訓練的任務是很好地記住整個訓練數據,并通過找到X和之間的一些隱藏依賴關系來概括預測Y。因此,如果調用model.predict(...)訓練數據,應該給出接近的100%準確性,因為網絡會看到所有這些訓練數據并記住它。但測試數據它根本看不到,因此需要聰明并真正預測通過 X 測試 Y,因此測試的準確性較低,例如80%

  11. 如果測試結果的質量不好,您必須改進網絡架構并從頭開始重新運行整個訓練過程。

  12. 如果您需要預測部分數據,例如,當您的X數據中只有 15 列中的 12 列時,則用零填充缺失的列值,例如,如果您缺少第 7 列和第 11 列,則將零插入到第 7 列和第 11 位。這樣總列數又是 15。您的網絡將僅支持 model.predict() 的輸入,即訓練時使用的列數,即 15,該數字在 中提供input_shape = (15,)


查看完整回答
反對 回復 2023-09-12
  • 1 回答
  • 0 關注
  • 116 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號