diff --git a/examples/basic_tutorials/load_pytorch_parameters_to_tensorlayerx.py b/examples/basic_tutorials/load_pytorch_parameters_to_tensorlayerx.py index 5815be6..c656897 100644 --- a/examples/basic_tutorials/load_pytorch_parameters_to_tensorlayerx.py +++ b/examples/basic_tutorials/load_pytorch_parameters_to_tensorlayerx.py @@ -88,7 +88,7 @@ def def_torch_weight_reshape(weight): # Step1: save pytorch model parameters to a.pth # On the first run, uncomment lines 90 and 91. # b = B() - # torch.save(a.state_dict(), 'a.pth') + # torch.save(b.state_dict(), 'a.pth') a = A() # Step2: Converts pytorch a.pth to the model parameter format of tensorlayerx