當 JIT 保存具有許多自定義類的復雜 pytorch 模型的“model.pt”時,我遇到了 pytorch 不知道這些自定義類之一的類型注釋的錯誤。換句話說,以下代碼(對原始代碼進行了徹底總結)在第七行失?。篿mport torchfrom gan import Generatorfrom gan.blocks import SpadeBlockgenerator = Generator()generator.load_weights("path/to/weigts")jitted = torch.jit.script(generator)torch.jit.save(jitted, "model.pt")錯誤:Tracebck (most recent call last): File "pth2onnx.py", line 72, in <module> to_torch_jit(generator) File "pth2onnx.py", line 24, in to_torch_jit jitted = torch.jit.script(generator) File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/__init__.py", line 1516, in script return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile) File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 310, in create_script_module concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module) File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 269, in get_or_create_concrete_type concrete_type_builder = infer_concrete_type_builder(nn_module) File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 138, in infer_concrete_type_builder sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item) File "/home/a.nieuwland/.conda/envs/python3.6/lib/python3.6/site-packages/torch/jit/_recursive.py", line 269, in get_or_create_concrete_type它抱怨的類型確實是我們自己編寫并在加載的Generator. 我將不勝感激有關可能導致此問題的原因或如何調查此問題的指示!我嘗試了以下方法:在調用 torch.jit.script 的腳本中顯式導入 SpadeBlock確保它繼承自 nn.Module (生成器也是如此)使用 pip install --user -e 確保安裝了 gan 軟件包有任何想法嗎?提前致謝!
1 回答

弒天下
TA貢獻1818條經驗 獲得超8個贊
問題原來是我使用的類變量名稱被破壞了。例子:
class Generator(nn.Module): __main: nn.Module
兩個前導下劃線就是原因。將它們更改為單個下劃線或無下劃線。解決問題。
class Generator(nn.Module): main: nn.Module
添加回答
舉報
0/150
提交
取消