Resuming training on unsharded checkpoint #641
Could you share your config and details about the checkpoint you started from (if it's an official one)? Also, it looks like your run is using the default 'legacy' checkpointer. You may have success with another checkpointer like the …
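For context, the sharded checkpointer is selected in the training config. Below is a minimal sketch of switching it, assuming a recent OLMo checkout where `TrainConfig` has a `sharded_checkpointer` field backed by a `ShardedCheckpointerType` enum; the config path and the `torch_new` member are assumptions, so check the names in your own `olmo/config.py`.

```python
# Hedged sketch: switch off the legacy sharded checkpointer via the config.
# Assumes olmo.config exposes TrainConfig and ShardedCheckpointerType with a
# `torch_new` member (check your checkout); the YAML path is hypothetical.
from olmo.config import ShardedCheckpointerType, TrainConfig

cfg = TrainConfig.load("configs/official/OLMo-1B.yaml")
cfg.sharded_checkpointer = ShardedCheckpointerType.torch_new
```

The same setting can usually also be passed to `scripts/train.py` as a `--key=value` override on the command line.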
Hi, I am using the official config for OLMo 1B. The only thing I've amended is changing the tokenizer to the Dolma one. Here are the details of the config. Also wanted to check: when we resume, do we resume using the config file saved inside the checkpoint folder, or do we continue using the config file we originally started with?
I normally use the config file that I started training with, but I imagine both can work fine. Could you share a bit more detail about the checkpoint you started with? Is it from an official run or a run you did? Was it unsharded manually (using, say, …)?
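As a sanity check on the config question, here is a minimal sketch for comparing the two configs before resuming. It assumes the trainer writes a `config.yaml` into each checkpoint directory and that `olmo.config.TrainConfig.load` can read it; the paths are hypothetical.

```python
# Hedged sketch: diff the launch config against the config saved with the
# checkpoint, so resuming with either one is a safe choice.
from olmo.config import TrainConfig

launch_cfg = TrainConfig.load("configs/official/OLMo-1B.yaml")             # what you launched with
ckpt_cfg = TrainConfig.load("runs/my-run/step1000-unsharded/config.yaml")  # saved by the trainer

for field in ("model", "optimizer", "scheduler", "fsdp"):
    if getattr(launch_cfg, field, None) != getattr(ckpt_cfg, field, None):
        print(f"config mismatch in '{field}'")
```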
The checkpoint is from a run we did, and it was produced as an unsharded checkpoint by training.
I also tried resuming from the sharded checkpoints produced by training, and there were no issues either in resuming or in saving subsequently.
🐛 Describe the bug
I tried resuming training from a previous unsharded checkpoint at step 1k. Training resumed with no initial issue; however, when it tried to save a sharded checkpoint I encountered the error shown below. What is causing this issue? For context, the environment and number of nodes used are all the same.
```
Traceback (most recent call last):
  File "/mnt/azureml/cr/j/947c8b089dfe4d0484df42634f296716/exe/wd/scripts/train.py", line 345, in <module>
    main(cfg)
  File "/mnt/azureml/cr/j/947c8b089dfe4d0484df42634f296716/exe/wd/scripts/train.py", line 316, in main
    trainer.fit()
  File "/workspace/OLMo/olmo/train.py", line 1153, in fit
    checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 560, in save_checkpoint
    result = self.save_sharded_checkpoint()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 468, in save_sharded_checkpoint
    result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 428, in _save_checkpoint
    checkpointer.save_checkpoint(
  File "/workspace/OLMo/olmo/checkpoint.py", line 1000, in save_checkpoint
    "optim": FSDP.optim_state_dict(dist_model, optim),
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1255, in _optim_state_dict_impl
    return _optim_state_dict(
           ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1972, in _optim_state_dict
    fsdp_osd_state = convert_fn(
                     ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1795, in _convert_state_with_orig_params
    _gather_all_orig_param_state(
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1689, in _gather_all_orig_param_state
    output_states = _allgather_orig_param_states(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1519, in _allgather_orig_param_states
    dtype, state_buffers = _convert_all_state_info(
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1415, in _convert_all_state_info
    assert curr_scalar_tensor_value is None or torch.equal(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Rank 4 has different values for step: 1500.0. Other ranks: 500.0
```
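The assertion is raised while FSDP gathers the optimizer state for saving: every rank must report the same scalar `step` value for a given parameter, but here rank 4 reports 1500.0 while the other ranks report 500.0. One way to narrow this down is to check whether the unsharded checkpoint you resumed from already contains mixed `step` values. Below is a minimal sketch, assuming the checkpoint directory holds a standard optimizer `state_dict` in an `optim.pt` file; the path is hypothetical.

```python
# Hedged sketch: list the distinct per-parameter `step` values stored in the
# unsharded checkpoint's optimizer state; a healthy checkpoint should have one.
import torch

osd = torch.load("runs/my-run/step1000-unsharded/optim.pt", map_location="cpu")
steps = {float(s["step"]) for s in osd["state"].values() if "step" in s}
print("distinct optimizer 'step' values:", steps)
```

If more than one value shows up, the inconsistency was already present in the checkpoint rather than introduced by the resume; if there is only one, the mismatch is most likely being introduced at load time on some ranks.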
Versions
.