enable TP fp8 allgather with PrepareFloat8ModuleInput (pytorch#393)
This PR is a follow-up to enable fp8 allgather in TP after these PRs landed:
* pytorch/pytorch#128431
* pytorch-labs/float8_experimental#275

One needs to update their pytorch and float8_experimental installations to include those changes in order to train with fp8 allgather.

Since fp8 is not enabled as part of our integration tests yet, there should be no issues on CI or on training runs that do not use fp8.
wanchaol committed Jun 13, 2024
1 parent c782cc7 commit 8663058
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions torchtitan/parallelisms/parallelize_llama.py

@@ -114,7 +114,7 @@ def selective_checkpointing_context_fn():

 def get_tp_parallel_strategy(
     job_config: JobConfig,
-) -> Tuple[RowwiseParallel, ColwiseParallel]:
+) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.

     This function handles the special case of using float8 with tensor parallelism.
@@ -123,10 +123,11 @@ def get_tp_parallel_strategy(
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
+            PrepareFloat8ModuleInput,
         )

-        return Float8RowwiseParallel, Float8ColwiseParallel
-    return RowwiseParallel, ColwiseParallel
+        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
+    return RowwiseParallel, ColwiseParallel, PrepareModuleInput


 def pipeline_llama(
@@ -299,9 +300,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     )

     tp_mesh = world_mesh["tp"]
-    row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
-        job_config
-    )
+    (
+        row_parallel_strategy,
+        col_parallel_strategy,
+        prepare_module_input,
+    ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled

     # 1. Parallelize the first embedding and the last linear proj layer
@@ -327,7 +330,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     # Apply tensor + sequence parallelism to every transformer block
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
-            "attention": PrepareModuleInput(
+            "attention": prepare_module_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
@@ -336,7 +339,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             "attention.wv": col_parallel_strategy(),
             "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
             "attention_norm": SequenceParallel(),
-            "feed_forward": PrepareModuleInput(
+            "feed_forward": prepare_module_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
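For reference, below is a minimal self-contained sketch (not the actual torchtitan code) of how the parallel styles returned by get_tp_parallel_strategy are consumed. It assumes a PyTorch build with the DTensor tensor-parallel APIs and, for the fp8 path, a float8_experimental checkout that includes the two PRs above; the helper name apply_tp_to_block, the enable_fp8 flag, and the transformer_block/tp_mesh parameters are illustrative, while the module FQNs follow the llama plan in the diff.

import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def apply_tp_to_block(
    transformer_block: nn.Module, tp_mesh: DeviceMesh, enable_fp8: bool
) -> None:
    """Apply tensor + sequence parallelism to one transformer block (sketch)."""
    if enable_fp8:
        # fp8 path: these styles cast activations to float8 first, so the
        # Shard(1) -> Replicate() redistribute (the input allgather) runs in fp8
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel as colwise,
            Float8RowwiseParallel as rowwise,
            PrepareFloat8ModuleInput as prepare_input,
        )
    else:
        # default path: plain high-precision TP styles
        colwise, rowwise, prepare_input = (
            ColwiseParallel,
            RowwiseParallel,
            PrepareModuleInput,
        )

    layer_plan = {
        # redistribute the sequence-sharded activation to Replicate before
        # attention; the second (None) entry corresponds to the freqs_cis arg
        "attention": prepare_input(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": colwise(),
        "attention.wk": colwise(),
        "attention.wv": colwise(),
        "attention.wo": rowwise(output_layouts=Shard(1)),
        "attention_norm": SequenceParallel(),
        "feed_forward": prepare_input(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
    }
    parallelize_module(transformer_block, tp_mesh, layer_plan)

Because the float8 styles cast to fp8 before the redistribute, the input allgather moves roughly half as many bytes as the same allgather in bf16, which is the point of this change.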
