This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

feat: add unified reshard script #734

Open
wants to merge 1 commit into
base: main

mattmazzola
Contributor

@mattmazzola mattmazzola commented Jun 9, 2023

Issue

Training requires flattened models, with any MP and FSDP factors.
Inference requires unflattened models with FSDP 1.

We wanted AML jobs that train a model (producing flattened checkpoints), reshard those checkpoints to unflattened weights with FSDP 1 for inference, and then evaluate the model outputs.

  • Did not have a way to convert flattened checkpoints from one model parallel (MP) factor to another
  • Did not have a way to convert an inference checkpoint to a training checkpoint
  • Converting a trained checkpoint for inference requires several steps

Solution

  • Add unified reshard script
    • Allows modifying both MP and FSDP in a single script
    • Supports flattening/unflattening the model files
    • Automatically infers and adds shard_metadata
    • Improves logging output so the expected transformation can be verified
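
To illustrate how a single script can combine these transformations, here is a minimal sketch of the step planning. It is not the script's actual code: `ModelState` and `plan_reshard` are hypothetical names, and the ordering rule (consolidate to unflattened FSDP 1 before changing MP) is inferred from the log message shown further down in this PR.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ModelState:
    """Hypothetical description of a checkpoint layout."""
    mp: int          # model parallel factor
    fsdp: int        # fully sharded data parallel factor
    flattened: bool  # True if weights are stored as flat parameters


def plan_reshard(src: ModelState, dst: ModelState) -> List[str]:
    """Return the ordered steps needed to go from src to dst (illustrative only)."""
    steps: List[str] = []
    mp, fsdp, flattened = src.mp, src.fsdp, src.flattened

    # Assumption taken from the log output: MP can only be changed on
    # unflattened weights with FSDP 1, so consolidate first when needed.
    mp_change_needs_consolidation = dst.mp != mp and (flattened or fsdp != 1)
    wants_unflattened = flattened and not dst.flattened
    if mp_change_needs_consolidation or wants_unflattened:
        steps.append("unflatten and consolidate to FSDP 1")
        flattened, fsdp = False, 1

    if dst.mp != mp:
        steps.append(f"reshard MP {mp} -> {dst.mp}")
        mp = dst.mp

    # Re-flatten and re-shard only if the target layout asks for it.
    if dst.flattened and not flattened:
        steps.append("flatten")
        flattened = True
    if dst.fsdp != fsdp:
        steps.append(f"reshard FSDP {fsdp} -> {dst.fsdp}")

    return steps


if __name__ == "__main__":
    # The 2x8 flattened -> 8x1 unflattened case from the log in this PR.
    for step in plan_reshard(ModelState(2, 8, True), ModelState(8, 1, False)):
        print(step)
```

Planning the steps up front also makes it easy to log the intended transformation before any weights are touched.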

Testing steps

Did not test, although this file came directly from our fork without modification, so it likely works as is.

Video or Screenshots

Flow Chart of Steps in Script

MPxFSDP (2x8 -> 8x1)

2023-06-01 19:06:53 | INFO | __main__ | Found 16 sharded checkpoints (checkpoint_last-model_part-0-shard0.pt to checkpoint_last-model_part-1-shard7.pt)
2023-06-01 19:06:53 | INFO | __main__ | Loading all sharded checkpoints to CPU
2023-06-01 19:07:01 | INFO | __main__ | Detected Input Model State:
2023-06-01 19:07:01 | INFO | __main__ | - Model Parallel (MP) factor:			2
2023-06-01 19:07:01 | INFO | __main__ | - Fully Sharded Data Parallel (FSDP) factor:	8
2023-06-01 19:07:01 | INFO | __main__ | - Model Weights:				Flattened
2023-06-01 19:07:01 | INFO | __main__ | Desired Output Model State:
2023-06-01 19:07:01 | INFO | __main__ | - Model Parallel (MP) factor:			8
2023-06-01 19:07:01 | INFO | __main__ | - Fully Sharded Data Parallel (FSDP) factor:	1
2023-06-01 19:07:01 | INFO | __main__ | - Model Weights:				Unflattened
2023-06-01 19:07:01 | INFO | __main__ | Resharding model parallel from 2 to 8
2023-06-01 19:07:01 | INFO | __main__ | You attempted to change MP from 2 to 8 but the models was flattened. It must be unflattened with FSDP 1 to change model parallel factor. Unflattening and consolidating the weights.
2023-06-01 19:07:01 | INFO | __main__ | Resharding state dicts into 1 fsdp shard(s)
2023-06-01 19:07:01 | INFO | __main__ | Resharding model weights into 1 shard(s)
2023-06-01 19:07:02 | INFO | __main__ | Resharding state dicts into 1 fsdp shard(s)
2023-06-01 19:07:02 | INFO | __main__ | Resharding model weights into 1 shard(s)
2023-06-01 19:07:02 | INFO | __main__ | Allocating memory for unsharded checkpoint
2023-06-01 19:07:03 | WARNING | __main__ | Max value discrepancy for key 'decoder.embed_positions.weight': 4.7684e-06
2023-06-01 19:07:04 | INFO | __main__ | Resharding state dict for model parallel part 0
2023-06-01 19:07:04 | INFO | __main__ | Resharding state dict for model parallel part 1
2023-06-01 19:07:04 | INFO | __main__ | Resharding state dict for model parallel part 2
2023-06-01 19:07:04 | INFO | __main__ | Resharding state dict for model parallel part 3
2023-06-01 19:07:04 | INFO | __main__ | Resharding state dict for model parallel part 4
2023-06-01 19:07:04 | INFO | __main__ | Resharding state dict for model parallel part 5
2023-06-01 19:07:04 | INFO | __main__ | Resharding state dict for model parallel part 6
2023-06-01 19:07:04 | INFO | __main__ | Resharding state dict for model parallel part 7
2023-06-01 19:07:05 | INFO | __main__ | n_layers: 24
2023-06-01 19:07:06 | INFO | __main__ | Writing a resharded state dict to //amltbf97debed90d3226669ab66b89d50d4a/projects/mattm-nlg-distill/amlt-results/7314329603.77649-0eb9a9d5-5ed5-4908-9e39-7cafc2438bcd/opt-1.3b-e2e_nlg-human--full-20230601.1144/reshard/reshard-model_part-0.pt
2023-06-01 19:07:10 | INFO | __main__ | Writing a resharded state dict to //amltbf97debed90d3226669ab66b89d50d4a/projects/mattm-nlg-distill/amlt-results/7314329603.77649-0eb9a9d5-5ed5-4908-9e39-7cafc2438bcd/opt-1.3b-e2e_nlg-human--full-20230601.1144/reshard/reshard-model_part-1.pt
2023-06-01 19:07:14 | INFO | __main__ | Writing a resharded state dict to //amltbf97debed90d3226669ab66b89d50d4a/projects/mattm-nlg-distill/amlt-results/7314329603.77649-0eb9a9d5-5ed5-4908-9e39-7cafc2438bcd/opt-1.3b-e2e_nlg-human--full-20230601.1144/reshard/reshard-model_part-2.pt
2023-06-01 19:07:18 | INFO | __main__ | Writing a resharded state dict to //amltbf97debed90d3226669ab66b89d50d4a/projects/mattm-nlg-distill/amlt-results/7314329603.77649-0eb9a9d5-5ed5-4908-9e39-7cafc2438bcd/opt-1.3b-e2e_nlg-human--full-20230601.1144/reshard/reshard-model_part-3.pt
2023-06-01 19:07:23 | INFO | __main__ | Writing a resharded state dict to //amltbf97debed90d3226669ab66b89d50d4a/projects/mattm-nlg-distill/amlt-results/7314329603.77649-0eb9a9d5-5ed5-4908-9e39-7cafc2438bcd/opt-1.3b-e2e_nlg-human--full-20230601.1144/reshard/reshard-model_part-4.pt
2023-06-01 19:07:27 | INFO | __main__ | Writing a resharded state dict to //amltbf97debed90d3226669ab66b89d50d4a/projects/mattm-nlg-distill/amlt-results/7314329603.77649-0eb9a9d5-5ed5-4908-9e39-7cafc2438bcd/opt-1.3b-e2e_nlg-human--full-20230601.1144/reshard/reshard-model_part-5.pt
2023-06-01 19:07:31 | INFO | __main__ | Writing a resharded state dict to //amltbf97debed90d3226669ab66b89d50d4a/projects/mattm-nlg-distill/amlt-results/7314329603.77649-0eb9a9d5-5ed5-4908-9e39-7cafc2438bcd/opt-1.3b-e2e_nlg-human--full-20230601.1144/reshard/reshard-model_part-6.pt
2023-06-01 19:07:35 | INFO | __main__ | Writing a resharded state dict to //amltbf97debed90d3226669ab66b89d50d4a/projects/mattm-nlg-distill/amlt-results/7314329603.77649-0eb9a9d5-5ed5-4908-9e39-7cafc2438bcd/opt-1.3b-e2e_nlg-human--full-20230601.1144/reshard/reshard-model_part-7.pt
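
The "Found 16 sharded checkpoints" and "Detected Input Model State" lines above suggest how the input layout can be recognized. As a rough sketch only (the script may well read this from the checkpoint contents instead), the MP and FSDP factors could be inferred from filenames of the form `checkpoint_last-model_part-<mp>-shard<fsdp>.pt`. The `infer_factors` helper and its regular expression are assumptions made for illustration.

```python
import re
from typing import Iterable, Tuple

# Pattern assumed from the filenames in the log above,
# e.g. "checkpoint_last-model_part-1-shard7.pt".
_SHARD_RE = re.compile(r"-model_part-(\d+)-shard(\d+)\.pt$")


def infer_factors(filenames: Iterable[str]) -> Tuple[int, int]:
    """Return (mp_factor, fsdp_factor) inferred from checkpoint filenames."""
    seen = set()
    for name in filenames:
        match = _SHARD_RE.search(name)
        if match is None:
            raise ValueError(f"Unrecognized checkpoint filename: {name}")
        seen.add((int(match.group(1)), int(match.group(2))))
    if not seen:
        raise ValueError("No checkpoint files given")

    mp_factor = max(part for part, _ in seen) + 1
    fsdp_factor = max(shard for _, shard in seen) + 1

    # Every (part, shard) combination must be present exactly once.
    if len(seen) != mp_factor * fsdp_factor:
        raise ValueError(
            f"Expected {mp_factor * fsdp_factor} files, found {len(seen)}"
        )
    return mp_factor, fsdp_factor


if __name__ == "__main__":
    names = [
        f"checkpoint_last-model_part-{part}-shard{shard}.pt"
        for part in range(2)
        for shard in range(8)
    ]
    print(infer_factors(names))
```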

Flattened to Unflattened Output (MP and FSDP unchanged):

Found 2 sharded checkpoints (checkpoint_last-model_part-0-shard0.pt to checkpoint_last-model_part-1-shard0.pt)
Loading all sharded checkpoints to CPU
Detected Input Model State:
- Model Parallel (MP) factor:                   2
- Fully Sharded Data Parallel (FSDP) factor:    1
- Model Weights:                                Flattened
Desired Output Model State:
- Model Parallel (MP) factor:                   2
- Fully Sharded Data Parallel (FSDP) factor:    1
- Model Weights:                                Unflattened
Current model Flattened but desired model is Unflattened
Unflattened model
Resharding state dicts into 1 fsdp shard(s)
Resharding model weights into 1 shard(s)
Resharding state dicts into 1 fsdp shard(s)
Resharding model weights into 1 shard(s)
Writing a resharded state dict to _results/1.3b-2x1-unflattened/reshard-model_part-0-shard0.pt
Writing a resharded state dict to _results/1.3b-2x1-unflattened/reshard-model_part-1-shard0.pt
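
For readers unfamiliar with the flattened format, the following sketch shows the general idea behind unflattening: a single flat buffer is sliced back into named parameters using the names and shapes recorded in the shard metadata. This is a simplified, standard-library-only illustration, not the script's implementation. The `unflatten` helper and the metadata layout (parallel lists of names and shapes, with any padding at the end of the buffer) are assumptions.

```python
from math import prod
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def unflatten(
    flat: Sequence[float],
    names: Sequence[str],
    shapes: Sequence[Shape],
) -> Dict[str, Tuple[Shape, List[float]]]:
    """Split a flat buffer into {name: (shape, values)} entries.

    Values stay in flat (row-major) order; a tensor library would
    reshape each slice to its recorded shape.
    """
    if len(names) != len(shapes):
        raise ValueError("names and shapes must have the same length")

    total = sum(prod(shape) for shape in shapes)
    if total > len(flat):
        raise ValueError(f"Flat buffer has {len(flat)} values, need {total}")

    params: Dict[str, Tuple[Shape, List[float]]] = {}
    offset = 0
    for name, shape in zip(names, shapes):
        numel = prod(shape)
        params[name] = (shape, list(flat[offset:offset + numel]))
        offset += numel
    # Anything after `offset` is treated as padding and dropped (assumption).
    return params


if __name__ == "__main__":
    flat = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0]  # last two values are padding
    result = unflatten(flat, ["fc.weight", "fc.bias"], [(2, 2), (2,)])
    for name, (shape, values) in result.items():
        print(name, shape, values)
```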

Related to #726

Most of the work was done by @sahajgg
