This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

RuntimeError: Error(s) in loading state_dict: Unexpected key(s) when recovering results from main process during Trainer.fit() #246

Open
davzaman opened this issue Jan 26, 2023 · 0 comments


I am trying to get multi-GPU training working for hyperparameter tuning with Ray Tune, but I am getting the following error:

```text
  File "/home/davina/Private/repos/autopopulus/autopopulus/models/ap.py", line 182, in _fit
    self.trainer.fit(self.ae, datamodule=data)
  File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 838, in fit
    self._call_and_handle_interrupt(
  File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 783, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(
  File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/ray_lightning/launchers/ray_launcher.py", line 73, in launch
    self._recover_results_in_main_process(ray_output, trainer)
  File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/ray_lightning/launchers/ray_launcher.py", line 399, in _recover_results_in_main_process
    trainer.lightning_module.load_state_dict(state_dict)
  File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for AEDitto:
        Unexpected key(s) in state_dict: "fc_mu.weight", "fc_mu.bias", "fc_var.weight", "fc_var.bias", "decoder.0.weight", "decoder.0.bias".
```
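
A minimal sketch that reproduces the same error in plain PyTorch, assuming (not confirmed in this issue) that `AEDitto` creates its layers lazily, for example in `LightningModule.setup()`, rather than in `__init__`. `LazyAE` and `build()` are hypothetical names; the layer names are only chosen to match the keys in the error. If the main-process copy of the model never runs the step that creates the layers, its `state_dict()` is empty and a strict `load_state_dict()` rejects every incoming key:

```python
import torch
from torch import nn

torch.manual_seed(0)  # deterministic initialization for the demo


class LazyAE(nn.Module):
    """Hypothetical autoencoder whose layers only exist after build().

    Stands in for a LightningModule that creates its layers in setup()
    instead of __init__ (an assumption, not confirmed for AEDitto).
    """

    def __init__(self, n_features: int = 4, latent_dim: int = 2) -> None:
        super().__init__()
        self.n_features = n_features
        self.latent_dim = latent_dim

    def build(self) -> None:
        # Layer names mirror the keys reported in the error message.
        self.fc_mu = nn.Linear(self.n_features, self.latent_dim)
        self.fc_var = nn.Linear(self.n_features, self.latent_dim)
        self.decoder = nn.Sequential(nn.Linear(self.latent_dim, self.n_features))


# Worker-side copy: build() ran during training, so it holds weights.
worker_model = LazyAE()
worker_model.build()
trained_state = worker_model.state_dict()

# Main-process copy: build() never ran, so there is nothing to load into.
main_model = LazyAE()
print(main_model.state_dict())  # OrderedDict()

try:
    main_model.load_state_dict(trained_state)  # strict=True by default
except RuntimeError as err:
    print(err)  # reports the unexpected fc_mu/fc_var/decoder keys

# Building the layers first makes the keys line up and the load succeeds.
fixed_model = LazyAE()
fixed_model.build()
fixed_model.load_state_dict(trained_state)
```

If lazy layer creation is indeed the cause, making sure the layers exist in the main process before the launcher calls `load_state_dict()` (or creating them in `__init__`) would let the keys line up. This is a guess from the symptoms, not a confirmed fix.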

Looking at the following lines in `RayLauncher._recover_results_in_main_process()`:

```python
if ray_output.weights_path is not None:
    state_stream = ray_output.weights_path
    # DDPSpawnPlugin.__recover_child_process_weights begin
    # Difference here is that instead of writing the model weights to a
    # file and loading it, we use the state dict of the model directly.
    state_dict = load_state_stream(state_stream, to_gpu=self._strategy.use_gpu)
    # Set the state for PTL using the output from remote training.
    trainer.lightning_module.load_state_dict(state_dict)
```
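
To make the mechanism in that snippet concrete, here is a minimal sketch of the round trip the comments describe: the worker serializes its weights into an in-memory stream, and the main process deserializes them and loads them into its own copy of the module. It assumes the stream is raw bytes produced by `torch.save`; the helper names `to_state_stream` and `from_state_stream` are hypothetical and only loosely modeled on ray_lightning's `load_state_stream`, whose exact implementation may differ. The key point is that the receiving module must already have matching layers.

```python
import io

import torch
from torch import nn


def to_state_stream(model: nn.Module) -> bytes:
    """Serialize a model's state dict to raw bytes (worker side)."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getvalue()


def from_state_stream(state_stream: bytes) -> dict:
    """Deserialize bytes back into a state dict on the CPU (driver side)."""
    return torch.load(io.BytesIO(state_stream), map_location="cpu")


# Worker: training omitted; ship the weights back as bytes.
worker_model = nn.Sequential(nn.Linear(4, 2))
state_stream = to_state_stream(worker_model)

# Driver: the receiving module must already have the same layers,
# otherwise a strict load_state_dict() raises a RuntimeError.
driver_model = nn.Sequential(nn.Linear(4, 2))
driver_model.load_state_dict(from_state_stream(state_stream))
```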

When I compare `state_dict` with `trainer.lightning_module.state_dict()`, the latter is completely empty: it just prints `OrderedDict()`. The former contains every key listed in the error, with actual weight data. So for some reason the LightningModule in the main process does not seem to be set up (or something along those lines?), which leaves it with no state at all. This is not an issue when I run plain Ray Tune with one GPU per trial, without ray_lightning.
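
The probe described above can be turned into a small helper that lists which keys are missing or unexpected without loading anything. `diff_state_dict_keys` is a hypothetical name, not part of PyTorch or ray_lightning. `module.load_state_dict(state_dict, strict=False)` reports the same `missing_keys` and `unexpected_keys`, but it also loads whatever does match, which modifies the module.

```python
from typing import Dict, List, Tuple

import torch
from torch import nn


def diff_state_dict_keys(
    module: nn.Module, state_dict: Dict[str, torch.Tensor]
) -> Tuple[List[str], List[str]]:
    """Compare an incoming state dict against a module's own keys.

    Returns (missing, unexpected): keys the module expects but the state
    dict lacks, and keys the state dict has but the module does not.
    Nothing is loaded, so the module is left untouched.
    """
    own_keys = set(module.state_dict().keys())
    incoming_keys = set(state_dict.keys())
    missing = sorted(own_keys - incoming_keys)
    unexpected = sorted(incoming_keys - own_keys)
    return missing, unexpected


# Example: an empty module receiving a Linear layer's weights, the same
# kind of mismatch as an empty trainer.lightning_module.
incoming = nn.Linear(3, 1).state_dict()
missing, unexpected = diff_state_dict_keys(nn.Module(), incoming)
print(missing)     # []
print(unexpected)  # ['bias', 'weight']
```

Calling `diff_state_dict_keys(trainer.lightning_module, state_dict)` just before the launcher's `load_state_dict()` call would show whether the driver-side module is missing all of its layers or only some of them.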

For reference, see how I'm running tuning.

Other info:

```text
ray-core                  2.2.0            py39h4d85f9a_1    conda-forge
ray-dashboard             2.2.0            py39h9a2ef2b_1    conda-forge
ray-default               2.2.0            py39hf3d152e_1    conda-forge
ray-lightning             0.3.0                    pypi_0    pypi
ray-tune                  2.2.0            py39hf3d152e_1    conda-forge
```

Python 3.9.15

OS: Ubuntu 18.04.4 LTS (Bionic Beaver)

Other relevant information:
cudatoolkit=10.2
pytorch=1.12.1
pytorch-lightning=1.6.5
cudnn=7.6.5

Specs:
4 GeForce RTX 2080 Ti GPUs
32 CPUs (x86_64)
