From f3daad1ad70fa7730390add0bf92f8a45446764a Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Thu, 28 Dec 2023 09:36:48 -0800
Subject: [PATCH] Comments

---
 examples/llama/2d_llama.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py
index e09342134..f3764e68f 100644
--- a/examples/llama/2d_llama.py
+++ b/examples/llama/2d_llama.py
@@ -6,12 +6,15 @@
 from torch.distributed._tensor import init_device_mesh
 
 
+# Utility
 def modify_view(
     gm: torch.fx.GraphModule,
     tp: int
 ):
     """
-    Adjust dimension size of view ops to make them compatible with tensor parallelism.
+    Adjust dimension size of view ops to make them compatible with tensor
+    parallelism. For example, when TP is 4, we need to adjust `num_heads` from
+    32 to 8. This is needed for attention layers.
     """
     for node in gm.graph.nodes:
         if node.op == "call_method" and (
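
Note: the hunk above is truncated, so the body of modify_view is not shown in full.
For context, here is a minimal sketch of what a view-op rewrite pass of this kind
could look like. The signature (gm, tp) and the docstring come from the patch; the
body, the inclusion of `reshape`, and the placeholder `num_heads = 32` are assumptions
for illustration only, not the actual implementation in examples/llama/2d_llama.py.

    import torch
    import torch.fx

    def modify_view(gm: torch.fx.GraphModule, tp: int):
        """
        Hypothetical sketch: shrink head-count dimensions in view/reshape ops
        by the tensor-parallel degree `tp` (e.g. 32 heads -> 8 when TP is 4).
        """
        num_heads = 32  # assumed unsharded head count, for illustration
        for node in gm.graph.nodes:
            if node.op == "call_method" and node.target in ("view", "reshape"):
                # For call_method nodes, args[0] is the input tensor node and the
                # remaining args are the requested sizes; rescale any size that
                # matches the unsharded head count to the per-rank head count.
                node.args = tuple(
                    arg // tp if isinstance(arg, int) and arg == num_heads else arg
                    for arg in node.args
                )
        gm.recompile()

A pass like this would typically run on the traced GraphModule before the model is
sharded across tensor-parallel ranks, so the per-rank attention reshapes see the
reduced head count.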