From 86630585e60f80d387acb269196b4fb4a5b60586 Mon Sep 17 00:00:00 2001
From: Wanchao
Date: Wed, 12 Jun 2024 17:23:13 -0700
Subject: [PATCH] enable TP fp8 allgather with PrepareFloat8ModuleInput (#393)

This PR is a follow-up to enable fp8 allgather in TP after these PRs landed:

* https://github.com/pytorch/pytorch/pull/128431
* https://github.com/pytorch-labs/float8_experimental/pull/275

One needs to update their pytorch/float8_experimental to include those changes
in order to train with fp8. Since fp8 is not yet enabled as part of our
integration tests, there should be no issues on CI or on training runs that do
not use fp8.
---
 torchtitan/parallelisms/parallelize_llama.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 2d8e2150..754fc6a3 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -114,7 +114,7 @@ def selective_checkpointing_context_fn():

 def get_tp_parallel_strategy(
     job_config: JobConfig,
-) -> Tuple[RowwiseParallel, ColwiseParallel]:
+) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.

     This function handles the special case of using float8 with tensor parallelism.
@@ -123,10 +123,11 @@ def get_tp_parallel_strategy(
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
+            PrepareFloat8ModuleInput,
         )

-        return Float8RowwiseParallel, Float8ColwiseParallel
-    return RowwiseParallel, ColwiseParallel
+        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
+    return RowwiseParallel, ColwiseParallel, PrepareModuleInput


 def pipeline_llama(
@@ -299,9 +300,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             )

         tp_mesh = world_mesh["tp"]
-        row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
-            job_config
-        )
+        (
+            row_parallel_strategy,
+            col_parallel_strategy,
+            prepare_module_input,
+        ) = get_tp_parallel_strategy(job_config)
         loss_parallel = parallel_dims.loss_parallel_enabled

         # 1. Parallelize the first embedding and the last linear proj layer
@@ -327,7 +330,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         # Apply tensor + sequence parallelism to every transformer block
         for layer_id, transformer_block in model.layers.items():
             layer_plan = {
-                "attention": PrepareModuleInput(
+                "attention": prepare_module_input(
                     input_layouts=(Shard(1), None),
                     desired_input_layouts=(Replicate(), None),
                 ),
                 "attention.wq": col_parallel_strategy(),
                 "attention.wk": col_parallel_strategy(),
                 "attention.wv": col_parallel_strategy(),
                 "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
                 "attention_norm": SequenceParallel(),
-                "feed_forward": PrepareModuleInput(
+                "feed_forward": prepare_module_input(
                     input_layouts=(Shard(1),),
                     desired_input_layouts=(Replicate(),),
                 ),
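
For context (not part of the patch itself), below is a minimal sketch of how the tuple returned by get_tp_parallel_strategy is consumed when parallelizing one transformer block. The helper name apply_tp_to_block and the enable_float8 flag are illustrative only, and the float8 imports assume a float8_experimental version that already ships PrepareFloat8ModuleInput (i.e., with the PRs linked above).

# Illustrative sketch only; mirrors how this patch wires the strategy classes
# into the per-block tensor-parallel plan.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def apply_tp_to_block(transformer_block, tp_mesh, enable_float8: bool = False):
    # Hypothetical helper for this example; torchtitan does the equivalent
    # inside parallelize_llama.
    if enable_float8:
        # Same classes the patch selects when fp8 is enabled; the prepare-input
        # class is what makes the TP all-gather run in fp8.
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        row, col, prepare_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        row, col, prepare_input = RowwiseParallel, ColwiseParallel, PrepareModuleInput

    layer_plan = {
        # Replicate the sequence-sharded activations before attention.
        "attention": prepare_input(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": col(),
        "attention.wk": col(),
        "attention.wv": col(),
        "attention.wo": row(output_layouts=Shard(1)),
        "attention_norm": SequenceParallel(),
    }
    return parallelize_module(transformer_block, tp_mesh, layer_plan)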