我已經在 Pytorch 中為 224x224 大小的圖像和 4 個類訓練了這個網絡。class CustomConvNet(nn.Module): def __init__(self, num_classes): super(CustomConvNet, self).__init__() self.layer1 = self.conv_module(3, 64) self.layer2 = self.conv_module(64, 128) self.layer3 = self.conv_module(128, 256) self.layer4 = self.conv_module(256, 256) self.layer5 = self.conv_module(256, 512) self.gap = self.global_avg_pool(512, num_classes) #self.linear = nn.Linear(512, num_classes) #self.relu = nn.ReLU() #self.softmax = nn.Softmax() def forward(self, x): out = self.layer1(x) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = self.layer5(out) out = self.gap(out) out = out.view(-1, 4) #out = self.linear(out) return out def conv_module(self, in_num, out_num): return nn.Sequential( nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=(2, 2), stride=None)) def global_avg_pool(self, in_num, out_num): return nn.Sequential( nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1), #nn.BatchNorm2d(out_num), #nn.LeakyReLU(), nn.ReLU(), nn.Softmax(), nn.AdaptiveAvgPool2d((1, 1)))我從第一個 Conv2D 得到了權重,它的大小torch.Size([64, 3, 3, 3])我已將其保存為:weightsCNN = net.layer1[0].weight.datanp.save('CNNweights.npy', weightsCNN)這是我在 Tensorflow 中構建的模型。我想將從 Pytorch 模型中保存的權重傳遞到這個 Tensorflow CNN 中。 model = models.Sequential() model.add(layers.Conv2D(64, (3, 3), activation='relu', input_shape=(224, 224, 3))) model.add(layers.MaxPooling2D((2, 2))) model.add(layers.Conv2D(128, (3, 3), activation='relu')) model.add(layers.MaxPooling2D((2, 2)))我應該怎么做?Tensorflow 需要什么形狀的權重?謝謝!
1 回答

郎朗坤
TA貢獻1921條經驗 獲得超9個贊
keras
您可以非常簡單地檢查所有層的所有權重的形狀:
for layer in model.layers: print([tensor.shape for tensor in layer.get_weights()])
這將為您提供所有權重的形狀(包括偏差),因此您可以numpy
相應地準備加載的權重。
要設置它們,請執行類似的操作:
for torch_weight, layer in zip(model.layers, torch_weights): layer.set_weights(torch_weight)
wheretorch_weights
應該是一個列表,np.array
其中包含您必須加載的列表。
通常每個元素torch_weights
都包含一個np.array
權重和一個偏差。
請記住,從打印中收到的形狀必須與您放入的形狀完全相同set_weights
。
有關更多信息,請參閱文檔。
順便提一句。確切的形狀取決于模型執行的層和操作,有時您可能必須轉置一些數組以“適應它們”。
添加回答
舉報
0/150
提交
取消