亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

在卷積神經網絡中設置層的維度

在卷積神經網絡中設置層的維度

桃花長相依 2021-09-14 17:44:24
假設我有 4 個批次的 3x100x100 圖像作為輸入,并且我正在嘗試使用 pytorch 制作我的第一個卷積神經網絡。我真的不確定我的卷積神經網絡是否正確,因為當我通過以下安排訓練我的輸入時,我遇到了錯誤:Expected input batch_size (1) to match target batch_size (4).以下是我的轉發nnet:那么如果我要通過它:nn.Conv2d(3, 6, 5)我會得到 6 層地圖,每層都有尺寸(100-5+1)。那么如果我要通過它:nn.MaxPool2d(2, 2)我會得到 6 層地圖,每層都有尺寸 (96/2)然后,如果我要通過它:nn.Conv2d(6, 16, 5)我會得到 16 層地圖,每層都有尺寸 (48-5+1)那么如果我要通過它:self.fc1 = nn.Linear(44*44*16, 120)我會得到 120 個神經元那么如果我要通過它:self.fc2 = nn.Linear(120, 84)我會得到 84 個神經元那么如果我要通過它:self.fc3 = nn.Linear(84, 3)我會得到 3 個輸出,這將是完美的,因為我有 3 類標簽。但正如我之前所說,這會導致一個非常令人驚訝的錯誤,因為這對我來說很有意義。完整的神經網絡代碼:import torch.nn as nnimport torch.nn.functional as Fclass Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        self.conv1 = nn.Conv2d(3, 6, 5)        self.pool = nn.MaxPool2d(2, 2)        self.conv2 = nn.Conv2d(6, 16, 5)        self.fc1 = nn.Linear(44*44*16, 120)        self.fc2 = nn.Linear(120, 84)        self.fc3 = nn.Linear(84, 3)    def forward(self, x):        x = self.pool(F.relu(self.conv1(x)))        x = self.pool(F.relu(self.conv2(x)))        x = x.view(-1, 16 *44*44)        x = F.relu(self.fc1(x))        x = F.relu(self.fc2(x))        x = self.fc3(x)        return xnet = Net()net.to(device)
查看完整描述

1 回答

?
慕村225694

TA貢獻1880條經驗 獲得超4個贊

你的理解是正確的,非常詳細。


但是,您使用了兩個池化層(請參閱下面的相關代碼)。所以第二步之后的輸出將是16個44/2=22維度的地圖。


x = self.pool(F.relu(self.conv1(x)))

x = self.pool(F.relu(self.conv2(x)))

要解決此問題,要么不池化,要么將全連接層的維度更改為22*22*16。


要通過不池化來修復,請修改您的轉發功能,如下所示。


def forward(self, x):

    x = self.pool(F.relu(self.conv1(x)))

    x = F.relu(self.conv2(x))

    x = x.view(-1, 16 *44*44)

    x = F.relu(self.fc1(x))

    x = F.relu(self.fc2(x))

    x = self.fc3(x)

    return x

要通過更改全連接層的維度來修復,請更改網絡的聲明如下。


def __init__(self):

    super(Net, self).__init__()

    self.conv1 = nn.Conv2d(3, 6, 5)

    self.pool = nn.MaxPool2d(2, 2)

    self.conv2 = nn.Conv2d(6, 16, 5)

    self.fc1 = nn.Linear(22*22*16, 120)

    self.fc2 = nn.Linear(120, 84)

    self.fc3 = nn.Linear(84, 10)


查看完整回答
反對 回復 2021-09-14
  • 1 回答
  • 0 關注
  • 335 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號