亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

將 CNN Pytorch 中的預訓練權重傳遞給 Tensorflow 中的 CNN

將 CNN Pytorch 中的預訓練權重傳遞給 Tensorflow 中的 CNN

千巷貓影 2022-10-18 17:05:47
我已經在 Pytorch 中為 224x224 大小的圖像和 4 個類訓練了這個網絡。class CustomConvNet(nn.Module):    def __init__(self, num_classes):        super(CustomConvNet, self).__init__()        self.layer1 = self.conv_module(3, 64)        self.layer2 = self.conv_module(64, 128)        self.layer3 = self.conv_module(128, 256)        self.layer4 = self.conv_module(256, 256)        self.layer5 = self.conv_module(256, 512)        self.gap = self.global_avg_pool(512, num_classes)        #self.linear = nn.Linear(512, num_classes)        #self.relu = nn.ReLU()        #self.softmax = nn.Softmax()    def forward(self, x):        out = self.layer1(x)        out = self.layer2(out)        out = self.layer3(out)        out = self.layer4(out)        out = self.layer5(out)        out = self.gap(out)        out = out.view(-1, 4)        #out = self.linear(out)        return out    def conv_module(self, in_num, out_num):        return nn.Sequential(            nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1),            nn.ReLU(),            nn.MaxPool2d(kernel_size=(2, 2), stride=None))    def global_avg_pool(self, in_num, out_num):        return nn.Sequential(            nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1),            #nn.BatchNorm2d(out_num),            #nn.LeakyReLU(),            nn.ReLU(),            nn.Softmax(),            nn.AdaptiveAvgPool2d((1, 1)))我從第一個 Conv2D 得到了權重,它的大小torch.Size([64, 3, 3, 3])我已將其保存為:weightsCNN = net.layer1[0].weight.datanp.save('CNNweights.npy', weightsCNN)這是我在 Tensorflow 中構建的模型。我想將從 Pytorch 模型中保存的權重傳遞到這個 Tensorflow CNN 中。    model = models.Sequential()    model.add(layers.Conv2D(64, (3, 3), activation='relu', input_shape=(224, 224, 3)))    model.add(layers.MaxPooling2D((2, 2)))    model.add(layers.Conv2D(128, (3, 3), activation='relu'))    model.add(layers.MaxPooling2D((2, 2)))我應該怎么做?Tensorflow 需要什么形狀的權重?謝謝!
查看完整描述

1 回答

?
郎朗坤

TA貢獻1921條經驗 獲得超9個贊

keras您可以非常簡單地檢查所有層的所有權重的形狀:

for layer in model.layers:  
  print([tensor.shape for tensor in layer.get_weights()])

這將為您提供所有權重的形狀(包括偏差),因此您可以numpy相應地準備加載的權重。

要設置它們,請執行類似的操作:

for torch_weight, layer in zip(model.layers, torch_weights):
    layer.set_weights(torch_weight)

wheretorch_weights應該是一個列表,np.array其中包含您必須加載的列表。

通常每個元素torch_weights都包含一個np.array權重和一個偏差。

請記住,從打印中收到的形狀必須與您放入的形狀完全相同set_weights。

有關更多信息,請參閱文檔。

順便提一句。確切的形狀取決于模型執行的層和操作,有時您可能必須轉置一些數組以“適應它們”。


查看完整回答
反對 回復 2022-10-18
  • 1 回答
  • 0 關注
  • 115 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號