亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

PyTorch - 將 ProGAN 代理從 pth 轉換為 onnx

PyTorch - 將 ProGAN 代理從 pth 轉換為 onnx

人到中年有點甜 2022-09-13 10:01:46
我使用此 PyTorch 重新實現訓練了一個 ProGAN 代理,并將該代理另存為 .現在我需要將代理轉換為格式,我正在使用此scipt執行此操作:.pth.onnxfrom torch.autograd import Variableimport torch.onnximport torchvisionimport torchdevice = torch.device("cuda")dummy_input = torch.randn(1, 3, 64, 64)state_dict = torch.load("GAN_agent.pth", map_location = device)torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")一旦我運行它,我得到錯誤(下面的完整提示)。據我所知,問題在于將代理轉換為.onnx需要更多信息。我錯過了什么嗎?AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'---------------------------------------------------------------------------AttributeError                            Traceback (most recent call last)<ipython-input-2-c64481d4eddd> in <module>     10 state_dict = torch.load("GAN_agent.pth", map_location = device)     11 ---> 12 torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)    146                         operator_export_type, opset_version, _retain_param_name,    147                         do_constant_folding, example_outputs,--> 148                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs)    149     150 
查看完整描述

1 回答

?
慕婉清6462132

TA貢獻1804條經驗 獲得超2個贊

您擁有的文件是 ,它們只是圖層名稱到權重偏差和類似值的映射(有關更全面的介紹,請參閱此處)。state_dicttensor

這意味著你需要一個模型,以便可以映射那些節省的權重和偏差,但首先要做的事情是:

1. 模型準備

克隆模型定義所在的存儲庫并打開文件 。我們需要進行一些修改才能使其與 . 導出器需要僅作為(或/個)傳遞,而類需要和參數)。/pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.pyonnxonnxinputtorch.tensorlistdictGeneratorintfloat

簡單的解決方案是稍微修改一下函數(文件中的行,可以在GitHub上驗證它)到以下內容:forward80

def forward(self, x, depth, alpha):

    """

    forward pass of the Generator

    :param x: input noise

    :param depth: current depth from where output is required

    :param alpha: value of alpha for fade-in effect

    :return: y => output

    """


    # THOSE TWO LINES WERE ADDED

    # We will pas tensors but unpack them here to `int` and `float`

    depth = depth.item()

    alpha = alpha.item()

    # THOSE TWO LINES WERE ADDED

    assert depth < self.depth, "Requested output depth cannot be produced"


    y = self.initial_block(x)


    if depth > 0:

        for block in self.layers[: depth - 1]:

            y = block(y)


        residual = self.rgb_converters[depth - 1](self.temporaryUpsampler(y))

        straight = self.rgb_converters[depth](self.layers[depth - 1](y))


        out = (alpha * straight) + ((1 - alpha) * residual)


    else:

        out = self.rgb_converters[0](y)


    return out

此處僅添加了解包方式。每個不屬于類型的輸入都應在函數定義中打包為一個,并在函數頂部盡快解壓縮。它不會破壞您創建的檢查點,所以不用擔心,因為它只是映射。item()Tensorlayer-weight


2. 模型導出

將此腳本放在 (位置也位于):/pro_gan_pytorchREADME.md


import torch


from pro_gan_pytorch import PRO_GAN as pg


gen = torch.nn.DataParallel(pg.Generator(depth=9))

gen.load_state_dict(torch.load("GAN_GEN_SHADOW_8.pth"))


module = gen.module.to("cpu")


# Arguments like depth and alpha may need to be changed

dummy_inputs = (torch.randn(1, 512), torch.tensor([5]), torch.tensor([0.1]))

torch.onnx.export(module, dummy_inputs, "GAN_GEN8.onnx", verbose=True)

請注意以下幾點:

  • 我們必須在加載權重之前創建模型,因為它是唯一的。state_dict

  • torch.nn.DataParallel是必需的,因為這是模型的訓練對象(不確定您的情況,請相應地進行調整)。加載后,我們可以通過屬性獲取模塊本身。module

  • 一切都被扔到了,我認為沒有必要在這里。如果你堅持的話,你可以把一切都扔到。CPUGPUGPU

  • 生成器的虛擬輸入不能是圖像(我使用了存儲庫作者在其Google云端硬盤上提供的文件),它必須是帶有元素的噪音。512

運行它,你的文件應該在那里。.onnx

哦,由于您遵循不同的檢查點,您可能希望遵循類似的過程,盡管不能保證一切都會正常工作(盡管它看起來確實如此)。


查看完整回答
反對 回復 2022-09-13
  • 1 回答
  • 0 關注
  • 89 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號