
Resuming training on unsharded checkpoint #641

Open
lecifire opened this issue Jul 4, 2024 · 5 comments
Labels
type/bug An issue about a bug

Comments

@lecifire

lecifire commented Jul 4, 2024

🐛 Describe the bug

I tried resuming training from a previous unsharded checkpoint (step 1k). Training resumed with no initial issue; however, when it tried to save a sharded checkpoint I encountered the error shown below. What is causing this issue? For context, the environment and node count are all the same.

Traceback (most recent call last):
File "/mnt/azureml/cr/j/947c8b089dfe4d0484df42634f296716/exe/wd/scripts/train.py", line 345, in <module>
main(cfg)
File "/mnt/azureml/cr/j/947c8b089dfe4d0484df42634f296716/exe/wd/scripts/train.py", line 316, in main
trainer.fit()
File "/workspace/OLMo/olmo/train.py", line 1153, in fit
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/OLMo/olmo/train.py", line 560, in save_checkpoint
result = self.save_sharded_checkpoint()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/OLMo/olmo/train.py", line 468, in save_sharded_checkpoint
result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/OLMo/olmo/train.py", line 428, in _save_checkpoint
checkpointer.save_checkpoint(
File "/workspace/OLMo/olmo/checkpoint.py", line 1000, in save_checkpoint
"optim": FSDP.optim_state_dict(dist_model, optim),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
return FullyShardedDataParallel._optim_state_dict_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1255, in _optim_state_dict_impl
return _optim_state_dict(
^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1972, in _optim_state_dict
fsdp_osd_state = convert_fn(
^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1795, in _convert_state_with_orig_params
_gather_all_orig_param_state(
File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1689, in _gather_all_orig_param_state
output_states = _allgather_orig_param_states(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1519, in _allgather_orig_param_states
dtype, state_buffers = _convert_all_state_info(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1415, in _convert_all_state_info
assert curr_scalar_tensor_value is None or torch.equal(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Rank 4 has different values for step: 1500.0. Other ranks: 500.0
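To illustrate what the assertion is complaining about, here is a simplified, hypothetical sketch (not FSDP's actual implementation) of the consistency check in `torch/distributed/fsdp/_optim_utils.py`: when gathering optimizer state across ranks, scalar entries such as Adam's `step` counter must be identical on every rank, and the resumed run apparently had one rank reporting step 1500 while others reported 500.

```python
# Simplified sketch of FSDP's cross-rank scalar-state consistency check.
# This is an illustration only; the real check lives in
# torch/distributed/fsdp/_optim_utils.py (_convert_all_state_info).

def check_scalar_state_consistent(per_rank_step_values):
    """Raise AssertionError if ranks disagree on a scalar optimizer
    state value, as FSDP.optim_state_dict does for the 'step' entry."""
    reference = None
    for rank, value in per_rank_step_values.items():
        if reference is None:
            reference = value
        elif value != reference:
            raise AssertionError(
                f"Rank {rank} has different values for step: {value}. "
                f"Other ranks: {reference}"
            )

# Consistent state passes silently:
check_scalar_state_consistent({0: 500.0, 1: 500.0, 2: 500.0, 3: 500.0})
```

A state dict like `{0: 500.0, 4: 1500.0}` would reproduce the error message above, which suggests the optimizer state restored from the unsharded checkpoint disagreed with other ranks about how many steps had been taken.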

Versions

.

@lecifire lecifire added the type/bug An issue about a bug label Jul 4, 2024
@2015aroras
Collaborator

Could you share your config and details about the checkpoint you started from (if it's an official one)? Also, it looks like your run is using the default 'legacy' checkpointer. You may have success with another checkpointer like the olmo_core one (pass --sharded_checkpointer=olmo_core to train.py, for example).

@lecifire
Author

lecifire commented Jul 10, 2024

Hi, I am using the official config for OLMo-1B. The only thing I've amended is changing the tokenizer to the Dolma one. Here are the details of the config. I also wanted to check: when we resume, do we resume using the config file saved inside the saved checkpoint's folder, or do we continue using the existing config file we originally started with?

seed: 6198
dry_run: false

wandb:
  name: ${run_name}
  project: olmotest
  entity: q3team

model:
  d_model: 2048
  n_heads: 16
  n_layers: 16
  mlp_ratio: 8
  weight_tying: true
  alibi: false
  rope: true
  flash_attention: true  # not available on AMD
  attention_dropout: 0.0
  attention_layer_norm: false
  multi_query_attention: false
  include_bias: false
  block_type: sequential
  layer_norm_type: default
  layer_norm_with_affine: false
  bias_for_layer_norm: false
  attention_layer_norm_with_affine: false
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 2048
  vocab_size: 50280 
  embedding_size: 50304
  eos_token_id: 50279
  pad_token_id: 1
  init_device: meta
  init_fn: mitchell

compile: null  # causes instability on AMD GPUs

optimizer:
  name: adamw
  learning_rate: 4.0e-4
  weight_decay: 0.1
  betas:
  - 0.9
  - 0.95
  metrics_log_interval: 10

scheduler:
  name: cosine_with_warmup
  t_warmup: 2000
  alpha_f: 0.1

tokenizer:
  identifier: tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json
  truncate_direction: right

save_folder: "./outputs"
save_overwrite: false
# Sharded checkpoints (best for restarts)
save_interval: 500
save_num_checkpoints_to_keep: 9
# Unsharded checkpoints (for final storage)
save_interval_unsharded: 1000
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 739_328  # 3.1T tokens
global_train_batch_size: 2048
device_train_microbatch_size: 8

precision: amp_bf16

fsdp:
  wrapping_strategy: null
  precision: mixed

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 20

eval_interval: ${save_interval}
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
  # lump all the small datasets together (we still get separate metrics).
  - label: v3-small-ppl-validation
    data:
      num_workers: 0
      drop_last: true
      datasets:
        v3-small-c4_en-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/c4_en/val/part-0-00000.npy
        v3-small-dolma_books-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_books/val/part-0-00000.npy
        v3-small-dolma_common-crawl-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_common-crawl/val/part-0-00000.npy
        v3-small-dolma_pes2o-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_pes2o/val/part-0-00000.npy
        v3-small-dolma_reddit-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_reddit/val/part-0-00000.npy
        v3-small-dolma_stack-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_stack/val/part-0-00000.npy
        v3-small-dolma_wiki-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/dolma_wiki/val/part-0-00000.npy
        v3-small-ice-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/ice/val/part-0-00000.npy
        v3-small-m2d2_s2orc-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/m2d2_s2orc/val/part-0-00000.npy
        v3-small-pile-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/pile/val/part-0-00000.npy
        v3-small-wikitext_103-validation:
          - https://olmo-data.org/eval-data/perplexity/v3_small_gptneox20b/wikitext_103/val/part-0-00000.npy

  - label: v2-small-ppl-validation
    data:
      num_workers: 0
      drop_last: true
      datasets:
        v2-small-4chan-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/4chan/val.npy
        v2-small-c4_100_domains-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy
        v2-small-c4_en-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/c4_en/val.npy
        v2-small-gab-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/gab/val.npy
        v2-small-ice-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/ice/val.npy
        v2-small-m2d2_s2orc-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy
        v2-small-m2d2_wiki-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy
        v2-small-manosphere-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/manosphere/val.npy
        v2-small-mc4_en-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/mc4_en/val.npy
        v2-small-pile-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/pile/val.npy
        v2-small-ptb-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/ptb/val.npy
        v2-small-twitterAEE-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/twitterAEE/val.npy
        v2-small-wikitext_103-validation:
        - https://olmo-data.org/eval-data/perplexity/v2_small_gptneox20b/wikitext_103/val.npy

  - label: piqa
    type: downstream

  - label: hellaswag
    type: downstream

  - label: winogrande
    type: downstream

  - label: openbook_qa
    type: downstream

  # - label: boolq  # requires implemention of the pmi_dc matrix
    # type: downstream

  - label: sciq
    type: downstream

  - label: arc_easy
    type: downstream

  # - label: arc_challenge  # requires implemention of the pmi_dc matrix
  #   type: downstream

  - label: copa
    type: downstream

  - label: rte
    type: downstream

  - label: commitment_bank
    type: downstream

  - label: mrpc
    type: downstream

  - label: sst2
    type: downstream

data:
  pad_direction: right
  num_workers: 0
  drop_last: true
  pin_memory: true
  prefetch_factor: 16
  persistent_workers: true
  timeout: 0
  paths:
    - https://buildllmpremjpeast.blob.core.windows.net/dataset1/dolma_tokenized/c4-filtered/part-00-00004.npy

@2015aroras
Collaborator

I normally use the config file that I started training with, but I imagine both can work fine.

Could you share a bit more details about the checkpoint you started with? Is it from an official run or a run you did? Was it unsharded manually (using, say, scripts/unshard.py) or produced as an unsharded checkpoint by training?

@lecifire
Author

We started the checkpoint from a run we did and it was produced as an unsharded checkpoint by training.

@lecifire
Author

I also tried resuming from the sharded checkpoints produced by training, and there were no issues either in resuming or in saving subsequently.
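For reference, this kind of sharded-checkpoint resume can be expressed through the `load_path` field already present in the config above; the step directory name below is a placeholder for illustration, not a verified path from this run:

```yaml
# Resume from a sharded checkpoint directory produced during training
# (directory name is illustrative).
load_path: ./outputs/step1500
```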
