1 回答

TA貢獻1804條經驗 獲得超2個贊
您擁有的文件是 ,它們只是圖層名稱到權重偏差和類似值的映射(有關更全面的介紹,請參閱此處)。state_dict
tensor
這意味著你需要一個模型,以便可以映射那些節省的權重和偏差,但首先要做的事情是:
1. 模型準備
克隆模型定義所在的存儲庫并打開文件 。我們需要進行一些修改才能使其與 . 導出器需要僅作為(或/個)傳遞,而類需要和參數)。/pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py
onnx
onnx
input
torch.tensor
list
dict
Generator
int
float
簡單的解決方案是稍微修改一下函數(文件中的行,可以在GitHub上驗證它)到以下內容:forward
80
def forward(self, x, depth, alpha):
"""
forward pass of the Generator
:param x: input noise
:param depth: current depth from where output is required
:param alpha: value of alpha for fade-in effect
:return: y => output
"""
# THOSE TWO LINES WERE ADDED
# We will pas tensors but unpack them here to `int` and `float`
depth = depth.item()
alpha = alpha.item()
# THOSE TWO LINES WERE ADDED
assert depth < self.depth, "Requested output depth cannot be produced"
y = self.initial_block(x)
if depth > 0:
for block in self.layers[: depth - 1]:
y = block(y)
residual = self.rgb_converters[depth - 1](self.temporaryUpsampler(y))
straight = self.rgb_converters[depth](self.layers[depth - 1](y))
out = (alpha * straight) + ((1 - alpha) * residual)
else:
out = self.rgb_converters[0](y)
return out
此處僅添加了解包方式。每個不屬于類型的輸入都應在函數定義中打包為一個,并在函數頂部盡快解壓縮。它不會破壞您創建的檢查點,所以不用擔心,因為它只是映射。item()Tensorlayer-weight
2. 模型導出
將此腳本放在 (位置也位于):/pro_gan_pytorchREADME.md
import torch
from pro_gan_pytorch import PRO_GAN as pg
gen = torch.nn.DataParallel(pg.Generator(depth=9))
gen.load_state_dict(torch.load("GAN_GEN_SHADOW_8.pth"))
module = gen.module.to("cpu")
# Arguments like depth and alpha may need to be changed
dummy_inputs = (torch.randn(1, 512), torch.tensor([5]), torch.tensor([0.1]))
torch.onnx.export(module, dummy_inputs, "GAN_GEN8.onnx", verbose=True)
請注意以下幾點:
我們必須在加載權重之前創建模型,因為它是唯一的。
state_dict
torch.nn.DataParallel
是必需的,因為這是模型的訓練對象(不確定您的情況,請相應地進行調整)。加載后,我們可以通過屬性獲取模塊本身。module
一切都被扔到了,我認為沒有必要在這里。如果你堅持的話,你可以把一切都扔到。
CPU
GPU
GPU
生成器的虛擬輸入不能是圖像(我使用了存儲庫作者在其Google云端硬盤上提供的文件),它必須是帶有元素的噪音。
512
運行它,你的文件應該在那里。.onnx
哦,由于您遵循不同的檢查點,您可能希望遵循類似的過程,盡管不能保證一切都會正常工作(盡管它看起來確實如此)。
添加回答
舉報