enable TP fp8 allgather with PrepareFloat8ModuleInput
This PR is a follow-up to enable fp8 allgather in TP, now that these PRs have
landed:
* pytorch/pytorch#128431
* pytorch-labs/float8_experimental#275

You need to update your pytorch and float8_experimental checkouts to include
those changes in order to train with fp8 allgather.

Since fp8 is not yet enabled as part of our integration tests, there should be
no impact on CI.
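
For readers who want to try the new path outside of torchtitan, below is a
minimal sketch of how PrepareFloat8ModuleInput slots into a tensor-parallel
plan. It is not part of this commit: the ToyAttention/ToyBlock modules are
made up for illustration, it assumes a torchrun launch on CUDA GPUs, and it
assumes a pytorch/float8_experimental build that already contains the two PRs
above (module paths are as of that float8_experimental version and may differ
in later releases).

import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

# float8_experimental imports; PrepareFloat8ModuleInput requires
# pytorch-labs/float8_experimental#275.
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)


class ToyAttention(nn.Module):
    # Stand-in for a transformer attention block; submodule names mirror the
    # plan below (illustrative only, not torchtitan's model).
    def __init__(self, dim: int = 256) -> None:
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x, freqs_cis=None):
        return self.wo(self.wq(x))


class ToyBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attention = ToyAttention()


# Assumes `torchrun --nproc_per_node=<tp degree> this_script.py`.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

block = ToyBlock().cuda()
# Swap nn.Linear for dynamically scaled float8 linears, as torchtitan does
# before applying the TP plan.
swap_linear_with_float8_linear(block, Float8DynamicLinear)

# PrepareFloat8ModuleInput redistributes the sequence-sharded activation to
# Replicate (like PrepareModuleInput) but casts it to float8 first, so the
# TP all-gather moves fp8 data instead of bf16/fp32.
layer_plan = {
    "attention": PrepareFloat8ModuleInput(
        input_layouts=(Shard(1), None),
        desired_input_layouts=(Replicate(), None),
    ),
    "attention.wq": Float8ColwiseParallel(),
    "attention.wo": Float8RowwiseParallel(output_layouts=Shard(1)),
}
parallelize_module(block, tp_mesh, layer_plan)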
wanchaol committed Jun 12, 2024
1 parent 763b810 commit bedb16a
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions torchtitan/parallelisms/parallelize_llama.py

@@ -114,7 +114,7 @@ def selective_checkpointing_context_fn():
 
 def get_tp_parallel_strategy(
     job_config: JobConfig,
-) -> Tuple[RowwiseParallel, ColwiseParallel]:
+) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
     """Get the parallel strategy for the transformer model.
 
     This function handles the special case of using float8 with tensor parallelism.
@@ -123,10 +123,11 @@ def get_tp_parallel_strategy(
         from float8_experimental.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
+            PrepareFloat8ModuleInput,
         )
 
-        return Float8RowwiseParallel, Float8ColwiseParallel
-    return RowwiseParallel, ColwiseParallel
+        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
+    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
 
 
 def pipeline_llama(
@@ -299,9 +300,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         )
 
         tp_mesh = world_mesh["tp"]
-        row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
-            job_config
-        )
+        (
+            row_parallel_strategy,
+            col_parallel_strategy,
+            prepare_module_input,
+        ) = get_tp_parallel_strategy(job_config)
         loss_parallel = parallel_dims.loss_parallel_enabled
 
         # 1. Parallelize the first embedding and the last linear proj layer
@@ -327,7 +330,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         # Apply tensor + sequence parallelism to every transformer block
        for layer_id, transformer_block in model.layers.items():
            layer_plan = {
-                "attention": PrepareModuleInput(
+                "attention": prepare_module_input(
                    input_layouts=(Shard(1), None),
                    desired_input_layouts=(Replicate(), None),
                ),
@@ -336,7 +339,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                "attention.wv": col_parallel_strategy(),
                "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
                "attention_norm": SequenceParallel(),
-                "feed_forward": PrepareModuleInput(
+                "feed_forward": prepare_module_input(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Replicate(),),
                ),
