LLaMA PRO training resume problem #1231

Open
germanjke opened this issue May 23, 2024 · 6 comments
Labels
question Further information is requested

Comments

@germanjke

Hello,

I'm currently training LLaMA PRO. Initially, I expanded the model from 32 layers to 40 layers and proceeded to train only the newly added 8 layers (every fifth layer). Therefore, I froze 32 out of the 40 layers.

layer_freezing: 
    layer_names: [ 
    'model._fsdp_wrapped_module.model.layers.36._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.16._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.18._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.1._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.27._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.32._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.35._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.0._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.10._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.3._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.37._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.28._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.22._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.12._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.2._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.5._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.8._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.20._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.17._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.25._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.30._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.38._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.7._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.33._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.6._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.31._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.13._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.15._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.11._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.21._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.26._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param',
    'model._fsdp_wrapped_module.model.layers.23._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param'
    ]
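
(For reference, the list above can be generated instead of written by hand. The sketch below is only illustrative: it assumes the same 40-layer expansion with the new blocks at indices 4, 9, ..., 39 and the FSDP flat-parameter naming shown above.)

# Sketch only: rebuild the frozen-parameter list for layer_freezing.
# Assumes 40 total layers, with every fifth block (indices 4, 9, ..., 39) newly
# added and kept trainable; verify the name pattern against model.named_parameters().
NUM_LAYERS = 40
NEW_BLOCK_STRIDE = 5

trainable = {i for i in range(NUM_LAYERS) if (i + 1) % NEW_BLOCK_STRIDE == 0}

frozen_layer_names = [
    f"model._fsdp_wrapped_module.model.layers.{i}."
    "_fsdp_wrapped_module._checkpoint_wrapped_module._flat_param"
    for i in range(NUM_LAYERS)
    if i not in trainable
]
# The root flat param (embeddings, final norm, lm_head) is frozen as well.
frozen_layer_names.append("model._fsdp_wrapped_module._flat_param")

print(len(frozen_layer_names))  # 33 entries: 32 frozen blocks + the root flat param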

Training went well, and only the layers I need were being trained.

However, after a hardware failure, I tried to resume training using load_path and ran into the following error:

[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 198, in local_step
[rank6]:     local_plan = planner.create_local_plan()
[rank6]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 185, in create_local_plan
[rank6]:     return create_default_local_load_plan(self.state_dict, self.metadata)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 235, in create_default_local_load_plan
[rank6]:     md = metadata.state_dict_metadata[fqn]
[rank6]:          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
[rank6]: KeyError: 'state.optimizers.DecoupledAdamW.state.model.model.embed_tokens.weight.exp_avg'
[rank6]: Traceback (most recent call last): (RANK 14)
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/utils.py", line 163, in reduce_scatter
[rank6]:     local_data = map_fun()
[rank6]:                  ^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 198, in local_step
[rank6]:     local_plan = planner.create_local_plan()
[rank6]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 185, in create_local_plan
[rank6]:     return create_default_local_load_plan(self.state_dict, self.metadata)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 235, in create_default_local_load_plan
[rank6]:     md = metadata.state_dict_metadata[fqn]
[rank6]:          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
[rank6]: KeyError: 'state.optimizers.DecoupledAdamW.state.model.model.embed_tokens.weight.exp_avg'
[rank6]: Traceback (most recent call last): (RANK 15)
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/utils.py", line 163, in reduce_scatter
[rank6]:     local_data = map_fun()
[rank6]:                  ^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 198, in local_step
[rank6]:     local_plan = planner.create_local_plan()
[rank6]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 185, in create_local_plan
[rank6]:     return create_default_local_load_plan(self.state_dict, self.metadata)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/lib/python3/dist-packages/torch/distributed/checkpoint/default_planner.py", line 235, in create_default_local_load_plan
[rank6]:     md = metadata.state_dict_metadata[fqn]
[rank6]:          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
[rank6]: KeyError: 'state.optimizers.DecoupledAdamW.state.model.model.embed_tokens.weight.exp_avg'

My ep0-ba4500/.metadata looks like this:

(pickled binary; only the key names and tensor metadata types are readable)

state_dict_metadata:
    state.model.model.model.embed_tokens.weight                        TensorStorageMetadata (torch.float32, 8 chunks)
    state.model.model.model.layers.2.mlp.up_proj.weight                TensorStorageMetadata (torch.float32, 8 chunks)
    state.model.model.model.layers.2.input_layernorm.weight            TensorStorageMetadata (torch.float32, 8 chunks)
    state.model.model.model.layers.2.post_attention_layernorm.weight   TensorStorageMetadata (torch.float32, 8 chunks)
    state.model.model.model.layers.3.self_attn.q_proj.weight           TensorStorageMetadata (torch.float32, 8 chunks)
    state.model.model.model.layers.3.self_attn.o_proj.weight           TensorStorageMetadata (torch.float32, 8 chunks)
    state.model.model.model.layers.3.mlp.up_proj.weight                TensorStorageMetadata (torch.float32, 8 chunks)
    state.model.model.model.layers.3.input_layernorm.weight            TensorStorageMetadata (torch.float32, 8 chunks)
etc...

Have you experienced similar issues?

@dakinggg
Collaborator

We haven't thoroughly tested layer freezing + FSDP, so it's possible something complicated is going on here. However, we have seen this error when you try to load a checkpoint that doesn't have optimizer state, so it's possible that loading a checkpoint that only contains optimizer state for some of the parameters does not work properly.
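
One way to check that hypothesis is to list which optimizer-state entries actually made it into the sharded checkpoint's .metadata. A minimal sketch (the checkpoint folder path is a placeholder; the key prefix is taken from the KeyError above):

# Sketch: inspect the sharded checkpoint's .metadata and check whether frozen
# parameters (e.g. embed_tokens) have any saved Adam state.
from torch.distributed.checkpoint import FileSystemReader

reader = FileSystemReader("path/to/ep0-ba4500")  # folder containing .metadata
metadata = reader.read_metadata()

optim_keys = [
    k for k in metadata.state_dict_metadata
    if k.startswith("state.optimizers.DecoupledAdamW.state.")
]
print(f"{len(optim_keys)} optimizer-state entries in the checkpoint")
print(any("embed_tokens" in k for k in optim_keys))  # False would explain the KeyError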

@Riccorl

Riccorl commented May 29, 2024

I'm facing a similar issue with the latest release (0.8.0). When resuming from a monolithic checkpoint with HYBRID_SHARD I get the following error (KeyError: 'state'):

[rank24]: ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
[rank24]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank24]: │ ../../train/train.py:786 in <module>                                         │
[rank24]: │                                                                              │
[rank24]: │   783 │   cfg = om.merge(yaml_cfg, cli_cfg)                                  │
[rank24]: │   784 │   om.resolve(cfg)                                                    │
[rank24]: │   785 │   assert isinstance(cfg, DictConfig)                                 │
[rank24]: │ ❱ 786 │   main(cfg)                                                          │
[rank24]: │   787                                                                        │
[rank24]: │                                                                              │
[rank24]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank24]: │ ../../train/train.py:717 in main                                             │
[rank24]: │                                                                              │
[rank24]: │   714 │                                                                      │
[rank24]: │   715 │   # Build the Trainer                                                │
[rank24]: │   716 │   log.info('Building trainer...')                                    │
[rank24]: │ ❱ 717 │   trainer = Trainer(                                                 │
[rank24]: │   718 │   │   run_name=run_name,                                             │
[rank24]: │   719 │   │   seed=seed,                                                     │
[rank24]: │   720 │   │   model=model,                                                   │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/trainer/trainer.py:1715 in        │
[rank24]: │ __init__                                                                     │
[rank24]: │                                                                              │
[rank24]: │   1712 │   │   │   │   if wandb.run is None:                                 │
[rank24]: │   1713 │   │   │   │   │   load_object_store.init(self.state, self.logger)   │
[rank24]: │   1714 │   │   │   _, _, parsed_load_path = parse_uri(load_path)             │
[rank24]: │ ❱ 1715 │   │   │   self._rng_state = checkpoint.load_checkpoint(             │
[rank24]: │   1716 │   │   │   │   state=self.state,                                     │
[rank24]: │   1717 │   │   │   │   logger=self.logger,                                   │
[rank24]: │   1718 │   │   │   │   path=parsed_load_path,                                │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/utils/checkpoint.py:531 in        │
[rank24]: │ load_checkpoint                                                              │
[rank24]: │                                                                              │
[rank24]: │    528 │   │   │   │   │   fsdp_sharded_state_dict_enabled=state.fsdp_sharde │
[rank24]: │    529 │   │   │   │   │   deepspeed_sharded_checkpoint=is_model_deepspeed(s │
[rank24]: │    530 │   │   │   │   )                                                     │
[rank24]: │ ❱  531 │   │   │   │   rng_state_dicts = _restore_checkpoint(                │
[rank24]: │    532 │   │   │   │   │   state,                                            │
[rank24]: │    533 │   │   │   │   │   logger,                                           │
[rank24]: │    534 │   │   │   │   │   composer_states_filepath,                         │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/utils/checkpoint.py:999 in        │
[rank24]: │ _restore_checkpoint                                                          │
[rank24]: │                                                                              │
[rank24]: │    996 │   │   │   algorithm_passes=algorithm_passes,                        │
[rank24]: │    997 │   │   )                                                             │
[rank24]: │    998 │   if not load_weights_only:                                         │
[rank24]: │ ❱  999 │   │   state.load_state_dict(                                        │
[rank24]: │   1000 │   │   │   state_dict['state'],                                      │
[rank24]: │   1001 │   │   │   logger,                                                   │
[rank24]: │   1002 │   │   │   exclude_algorithms=exclude_algorithms,                    │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/core/state.py:1418 in             │
[rank24]: │ load_state_dict                                                              │
[rank24]: │                                                                              │
[rank24]: │   1415 │   │   │   if attribute_name == 'dataset_state':                     │
[rank24]: │   1416 │   │   │   │   self._load_dataset_state(serialized_value)            │
[rank24]: │   1417 │   │   │   elif attribute_name == 'optimizers':                      │
[rank24]: │ ❱ 1418 │   │   │   │   self.load_optim_state(state)                          │
[rank24]: │   1419 │   │   │   elif attribute_name == 'train_metrics':                   │
[rank24]: │   1420 │   │   │   │   # Get current metrics object and populate each metric │
[rank24]: │   1421 │   │   │   │   # in serialization with serialized data via load_stat │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/core/state.py:1331 in             │
[rank24]: │ load_optim_state                                                             │
[rank24]: │                                                                              │
[rank24]: │   1328 │   │   │   │   # errors) before discarding the output. Accordingly,  │
[rank24]: │   1329 │   │   │   │   # See: https://github.com/pytorch/pytorch/issues/1251 │
[rank24]: │   1330 │   │   │   │   optim_state_dict = MagicMock() if optim_state_dict is │
[rank24]: │ ❱ 1331 │   │   │   │   set_optimizer_state_dict(                             │
[rank24]: │   1332 │   │   │   │   │   model=self.model,                                 │
[rank24]: │   1333 │   │   │   │   │   optimizers=optimizer,                             │
[rank24]: │   1334 │   │   │   │   │   optim_state_dict=optim_state_dict,                │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/composer/trainer/mosaic_fsdp_utils.py:719  │
[rank24]: │ in set_optimizer_state_dict                                                  │
[rank24]: │                                                                              │
[rank24]: │   716 │   │   │   info = _verify_options(model, optimizers, optim_only=True, │
[rank24]: │   717 │   │   │                                                              │
[rank24]: │   718 │   │   │   _verify_state_dict({}, optim_state_dict, info)             │
[rank24]: │ ❱ 719 │   │   │   _load_optim_state_dict(model, optimizers, optim_state_dict │
[rank24]: │   720                                                                        │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict.py │
[rank24]: │ :616 in _load_optim_state_dict                                               │
[rank24]: │                                                                              │
[rank24]: │    613 │   │   │   │   │   │   osd_state[k.replace(fqn, fqn_with_compiler)]  │
[rank24]: │    614 │   │   │                                                             │
[rank24]: │    615 │   │   │   with info.fsdp_context():                                 │
[rank24]: │ ❱  616 │   │   │   │   optim_state_dict = FSDP.optim_state_dict_to_load(     │
[rank24]: │    617 │   │   │   │   │   model, optim, optim_state_dict                    │
[rank24]: │    618 │   │   │   │   )                                                     │
[rank24]: │    619                                                                       │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_ │
[rank24]: │ parallel.py:1928 in optim_state_dict_to_load                                 │
[rank24]: │                                                                              │
[rank24]: │   1925 │   │   │   │   Default: ``None``)                                    │
[rank24]: │   1926 │   │   """                                                           │
[rank24]: │   1927 │   │   state_dict_settings = FullyShardedDataParallel.get_state_dict │
[rank24]: │ ❱ 1928 │   │   result = FullyShardedDataParallel._optim_state_dict_to_load_i │
[rank24]: │   1929 │   │   │   optim_state_dict=optim_state_dict,                        │
[rank24]: │   1930 │   │   │   model=model,                                              │
[rank24]: │   1931 │   │   │   optim_input=None,                                         │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_ │
[rank24]: │ parallel.py:1319 in _optim_state_dict_to_load_impl                           │
[rank24]: │                                                                              │
[rank24]: │   1316 │   │                                                                 │
[rank24]: │   1317 │   │   if rank0_only and dist.get_rank(group) > 0:                   │
[rank24]: │   1318 │   │   │   optim_state_dict = {}                                     │
[rank24]: │ ❱ 1319 │   │   sharded_osd = _flatten_optim_state_dict(                      │
[rank24]: │   1320 │   │   │   optim_state_dict,                                         │
[rank24]: │   1321 │   │   │   model=model,                                              │
[rank24]: │   1322 │   │   │   use_orig_params=use_orig_params,                          │
[rank24]: │                                                                              │
[rank24]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank24]: │ venv/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py:461 │
[rank24]: │  in _flatten_optim_state_dict                                                │
[rank24]: │                                                                              │
[rank24]: │    458 │                                                                     │
[rank24]: │    459 │   # Construct the "state" part                                      │
[rank24]: │    460 │   flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {}        │
[rank24]: │ ❱  461 │   unflat_osd_state = unflat_osd["state"]                            │
[rank24]: │    462 │   all_state_keys = set(unflat_osd_state.keys())                     │
[rank24]: │    463 │                                                                     │
[rank24]: │    464 │   for param, fqns in param_to_fqns.items():                         │
[rank24]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank24]: KeyError: 'state'

With a prior version of llm-foundry I didn't have this issue (albeit I was using the FULL_SHARD strategy), so I tried switching back to the old Composer code (< 0.22.0) for restoring the optimizer state, but that doesn't work either and I get a different error:

[rank118]: ╭───────────────────── Traceback (most recent call last) ──────────────────────╮
[rank118]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank118]: │ ../../train/train.py:786 in <module>                                         │
[rank118]: │                                                                              │
[rank118]: │   783 │   cfg = om.merge(yaml_cfg, cli_cfg)                                  │
[rank118]: │   784 │   om.resolve(cfg)                                                    │
[rank118]: │   785 │   assert isinstance(cfg, DictConfig)                                 │
[rank118]: │ ❱ 786 │   main(cfg)                                                          │
[rank118]: │   787                                                                        │
[rank118]: │                                                                              │
[rank118]: │ /leonardo/home/userexternal/rorland1/llm-foundry/scripts/slurm/base_scripts/ │
[rank118]: │ ../../train/train.py:717 in main                                             │
[rank118]: │                                                                              │
[rank118]: │   714 │                                                                      │
[rank118]: │   715 │   # Build the Trainer                                                │
[rank118]: │   716 │   log.info('Building trainer...')                                    │
[rank118]: │ ❱ 717 │   trainer = Trainer(                                                 │
[rank118]: │   718 │   │   run_name=run_name,                                             │
[rank118]: │   719 │   │   seed=seed,                                                     │
[rank118]: │   720 │   │   model=model,                                                   │
[rank118]: │                                                                              │
[rank118]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank118]: │ venv/lib/python3.11/site-packages/composer/trainer/trainer.py:1715 in        │
[rank118]: │ __init__                                                                     │
[rank118]: │                                                                              │
[rank118]: │   1712 │   │   │   │   if wandb.run is None:                                 │
[rank118]: │   1713 │   │   │   │   │   load_object_store.init(self.state, self.logger)   │
[rank118]: │   1714 │   │   │   _, _, parsed_load_path = parse_uri(load_path)             │
[rank118]: │ ❱ 1715 │   │   │   self._rng_state = checkpoint.load_checkpoint(             │
[rank118]: │   1716 │   │   │   │   state=self.state,                                     │
[rank118]: │   1717 │   │   │   │   logger=self.logger,                                   │
[rank118]: │   1718 │   │   │   │   path=parsed_load_path,                                │
[rank118]: │                                                                              │
[rank118]: │ /leonardo_scratch/large/userexternal/rorland1/python-envs/llm-foundry-0.8.0- │
[rank118]: │ venv/lib/python3.11/site-packages/composer/utils/checkpoint.py:558 in        │
[rank118]: │ load_checkpoint                                                              │
[rank118]: │                                                                              │
[rank118]: │    555 │   dist.all_reduce(max_step_to_resume_from, reduce_operation='MAX')  │
[rank118]: │    556 │   dist.all_reduce(min_step_to_resume_from, reduce_operation='MIN')  │
[rank118]: │    557 │   if max_step_to_resume_from.data != min_step_to_resume_from.data:  │
[rank118]: │ ❱  558 │   │   raise RuntimeError(                                           │
[rank118]: │    559 │   │   │   textwrap.dedent(                                          │
[rank118]: │    560 │   │   │   │   f'Timestamp mismatch error: batch to resume from {ste │
[rank118]: │    561 │   │   │   │   'This usually occurs when at least one rank fails to  │
[rank118]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank118]: RuntimeError: Timestamp mismatch error: batch to resume from 10000 is not the 
[rank118]: same on all ranks. This usually occurs when at least one rank fails to save the 
[rank118]: last checkpoint while using sharded checkpointing + autoresume. Please manually 
[rank118]: resume by disabling autoresume and explicitly setting load_path to the most 
[rank118]: recent checkpoints that all ranks have saved. E.g. for the 10th batch: trainer =
[rank118]: Trainer(autoresume=False, load_path="/path/to/checkpoint/ba10-rank{rank}.pt", 
[rank118]: ...). Remember to keep the {rank} placeholder!

Are you experiencing a similar issue? Or do you have any hints?

@germanjke
Author

I also have 0.8.0 and HYBRID_SHARD

@irenedea linked a pull request May 29, 2024 that will close this issue
@eracah
Contributor

eracah commented Jun 6, 2024

@Riccorl, it seems like your problem is separate from @germanjke's. Can you file a new issue with some more information, such as:

  • details of the model you are loading (is this a LLaMA Pro issue?)
  • details on the checkpoint you are loading
  • ideally, enough detail that the problem can be easily reproduced

@nik-mosaic
Contributor

@germanjke can you try freezing layers via the optimizer rather than the Composer layer-freezing algorithm? Freezing via the optimizer is better tested. For example:

optimizer:
    lr: <your_learning_rate>
    name: decoupled_adamw
    disable_grad: ^((?!(Wqkv|out_proj)).)*$ # Regex which disables gradients except for attention and out_proj
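
Adapted to your setup it might look something like the following. This is only a sketch: the regex and the model.model.layers prefix are assumptions, so check them against model.named_parameters() before relying on it.

optimizer:
    lr: <your_learning_rate>
    name: decoupled_adamw
    # Assumed regex: disables gradients for everything except the newly added
    # blocks at indices 4, 9, ..., 39, keeping embeddings/norm/lm_head frozen.
    disable_grad: ^(?!model\.model\.layers\.(4|9|14|19|24|29|34|39)\.).*$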

@mvpatel2000
Collaborator

I'm facing a similar issue with the latest release (0.8.0). When resuming from a monolithic checkpoint with HYBRID_SHARD I get the following error (KeyError: 'state'):

@Riccorl I have identified this as a PyTorch issue and opened a bug report on their end + a PR to fix it

@dakinggg dakinggg added the question Further information is requested label Jul 2, 2024