diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py
index e09342134..f3764e68f 100644
--- a/examples/llama/2d_llama.py
+++ b/examples/llama/2d_llama.py
@@ -6,12 +6,15 @@
 from torch.distributed._tensor import init_device_mesh
+# Utility
 def modify_view(
     gm: torch.fx.GraphModule, tp: int
 ):
     """
-    Adjust dimension size of view ops to make them compatible with tensor parallelism.
+    Adjust dimension size of view ops to make them compatible with tensor
+    parallelism. For example, when TP is 4, we need to adjust `num_heads` from
+    32 to 8. This is needed for attention layers.
     """
     for node in gm.graph.nodes:
         if node.op == "call_method" and (
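
For context, below is a minimal, self-contained sketch (not taken from this PR) of the kind of FX graph rewrite the docstring describes: dividing a hard-coded head count in traced `view`/`reshape` calls by the tensor-parallel degree. The toy module, the function name `adjust_view_heads`, and the argument index it patches are illustrative assumptions, not the actual implementation in `2d_llama.py`.

```python
# Illustrative sketch only -- not part of the diff above.
import torch
import torch.fx


class TinyAttentionReshape(torch.nn.Module):
    """Toy stand-in for an attention reshape with a hard-coded head count."""
    def forward(self, x):
        # x: (batch, seq, num_heads * head_dim) with num_heads=32, head_dim=64
        return x.view(x.size(0), x.size(1), 32, 64)


def adjust_view_heads(gm: torch.fx.GraphModule, tp: int) -> None:
    """Divide the hard-coded head count of view/reshape ops by `tp`
    (e.g. 32 -> 8 when tp == 4) so they match the locally sharded tensor."""
    for node in gm.graph.nodes:
        if node.op == "call_method" and node.target in ("view", "reshape"):
            # In this toy layout args = (input, batch, seq, num_heads, head_dim),
            # so the head count sits at positional index 3.
            if len(node.args) >= 4 and isinstance(node.args[3], int):
                node.update_arg(3, node.args[3] // tp)
    gm.recompile()


if __name__ == "__main__":
    tp = 4
    gm = torch.fx.symbolic_trace(TinyAttentionReshape())
    adjust_view_heads(gm, tp)
    # After TP sharding, the local hidden size is 32 * 64 / tp = 512.
    out = gm(torch.randn(2, 5, 32 * 64 // tp))
    print(out.shape)  # torch.Size([2, 5, 8, 64])
```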