亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

預期輸入batch_size (64) 與目標batch_size (30) 匹配

預期輸入batch_size (64) 與目標batch_size (30) 匹配

PHP
慕后森 2023-11-09 16:57:49
我目前正在訓練一個神經網絡來對食物圖像的食物組進行分類,從而產生 5 個輸出類別。然而,每當我開始訓練網絡時,我都會收到此錯誤:ValueError: Expected input batch_size (64) to match target batch_size (30).這是我的神經網絡定義和訓練代碼。我真的很感謝幫助,我對 pytorch 比較陌生,無法弄清楚我的代碼中到底有什么問題。謝謝!#Define the Network Architechturemodel = nn.Sequential(nn.Linear(7500, 4950),                      nn.ReLU(),                      nn.Linear(4950, 1000),                      nn.ReLU(),                      nn.Linear(1000, 250),                      nn.ReLU(),                      nn.Linear(250, 5),                      nn.LogSoftmax(dim = 1))#Define losscriterion = nn.NLLLoss()#Initial forward passimages, labels = next(iter(trainloader))images = images.view(images.shape[0], -1)print(images.shape)logits = model(images)print(logits.size)loss = criterion(logits, labels)print(loss)#Define Optimizeroptimizer = optim.SGD(model.parameters(), lr = 0.01)訓練網絡:epochs = 10for e in range(epochs):    running_loss = 0    for image, labels in trainloader:        #Flatten Images        images = images.view(images.shape[0], -1)        #Set gradients to 0        optimizer.zero_grad()        #Output        output = model(images)        loss = criterion(output, labels) #Where the error occurs        loss.backward()        #Gradient Descent Step        optimizer.step()        running_loss += loss.item()    else:        print(f"Training loss: {running_loss/len(trainloader)}")
查看完整描述

2 回答

?
幕布斯6054654

TA貢獻1876條經驗 獲得超7個贊

不是 100% 確定,但我認為錯誤在于這一行:

nn.Linear(7500, 4950)

除非您絕對確定您的輸入是 7500,否則請輸入 1 而不是 7500。請記住,第一個值始終是您的輸入大小。通過設置 1,您將確保您的模型可以處理任何尺寸的圖像。

順便說一句,PyTorch 有一個扁平化功能。使用nn.Flatten而不是使用,images.view()因為您不想犯任何形狀錯誤并必然浪費更多時間。

您犯的另一個小錯誤是您繼續images and image在 for 循環中用作變量和參數。這是非常糟糕的做法,因為每當別人閱讀你的代碼時,你都會讓他們感到困惑。確保不要一遍又一遍地重復使用相同的變量。

另外,您能否提供有關您的數據的更多信息?比如灰度、image_size 等。


查看完整回答
反對 回復 2023-11-09
?
拉丁的傳說

TA貢獻1789條經驗 獲得超8個贊

錯誤結果出現在“for image, labels in trainloader:”行中(應該是圖像)。修復了它,模型現在訓練得很好。



查看完整回答
反對 回復 2023-11-09
  • 2 回答
  • 0 關注
  • 472 瀏覽

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號