From bedb16a513462ef1da24d072eca2179302f2ed0a Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Wed, 12 Jun 2024 16:14:35 -0700
Subject: [PATCH] enable TP fp8 allgather with PrepareFloat8ModuleInput

This PR is a follow-up to enable fp8 allgather in TP after these PRs landed:

* https://github.com/pytorch/pytorch/pull/128431
* https://github.com/pytorch-labs/float8_experimental/pull/275

One needs to update their pytorch/float8_experimental to include those
changes in order to train with fp8 allgather.

Since fp8 is not enabled as part of our integration tests yet, there
should be no issues on CI.
---
 torchtitan/parallelisms/parallelize_llama.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 2d8e2150..754fc6a3 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -114,7 +114,7 @@ def selective_checkpointing_context_fn():
 
 def get_tp_parallel_strategy(
     job_config: JobConfig,
-) -> Tuple[RowwiseParallel, ColwiseParallel]:
+) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.
 
     This function handles the special case of using float8 with tensor parallelism.
@@ -123,10 +123,11 @@ def get_tp_parallel_strategy(
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
+            PrepareFloat8ModuleInput,
         )
 
-        return Float8RowwiseParallel, Float8ColwiseParallel
-    return RowwiseParallel, ColwiseParallel
+        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
+    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
 
 
 def pipeline_llama(
@@ -299,9 +300,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             )
 
         tp_mesh = world_mesh["tp"]
-        row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
-            job_config
-        )
+        (
+            row_parallel_strategy,
+            col_parallel_strategy,
+            prepare_module_input,
+        ) = get_tp_parallel_strategy(job_config)
         loss_parallel = parallel_dims.loss_parallel_enabled
 
         # 1. Parallelize the first embedding and the last linear proj layer
@@ -327,7 +330,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         # Apply tensor + sequence parallelism to every transformer block
         for layer_id, transformer_block in model.layers.items():
             layer_plan = {
-                "attention": PrepareModuleInput(
+                "attention": prepare_module_input(
                     input_layouts=(Shard(1), None),
                     desired_input_layouts=(Replicate(), None),
                 ),
@@ -336,7 +339,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "attention.wv": col_parallel_strategy(),
                 "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
                 "attention_norm": SequenceParallel(),
-                "feed_forward": PrepareModuleInput(
+                "feed_forward": prepare_module_input(
                     input_layouts=(Shard(1),),
                     desired_input_layouts=(Replicate(),),
                 ),
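
Note for reviewers: below is a minimal sketch (not part of the patch) of how the strategy triple returned by `get_tp_parallel_strategy` is meant to be consumed when building a TP plan for an attention block. The `enable_fp8` flag is a stand-in for the job_config check torchtitan uses, and the module names mirror the llama layer plan touched above; it assumes a float8_experimental checkout that already contains pytorch-labs/float8_experimental#275.

```python
# Sketch: select plain TP strategies or their float8 variants, then build the
# attention layer plan exactly as parallelize_llama does after this patch.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def build_attention_plan(enable_fp8: bool) -> dict:
    if enable_fp8:
        # Requires the float8_experimental changes referenced above.
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise, colwise, prepare_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise, colwise, prepare_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )

    # Redistributing the Shard(1) activation to Replicate() is what triggers the
    # allgather; with the float8 variant the input is cast to fp8 first, so the
    # allgather itself runs in fp8.
    return {
        "attention": prepare_input(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": colwise(),
        "attention.wk": colwise(),
        "attention.wv": colwise(),
        "attention.wo": rowwise(output_layouts=Shard(1)),
    }
```

The resulting dict is passed to `torch.distributed.tensor.parallel.parallelize_module(transformer_block, tp_mesh, plan)` the same way the existing layer plan is applied in parallelize_llama.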