
Commit
kwen2501 committed Dec 28, 2023
1 parent e55eef4 commit f3daad1
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion examples/llama/2d_llama.py
@@ -6,12 +6,15 @@
 from torch.distributed._tensor import init_device_mesh


 # Utility
 def modify_view(
     gm: torch.fx.GraphModule,
     tp: int
 ):
     """
-    Adjust dimension size of view ops to make them compatible with tensor parallelism.
+    Adjust dimension size of view ops to make them compatible with tensor
+    parallelism. For example, when TP is 4, we need to adjust `num_heads` from
+    32 to 8. This is needed for attention layers.
     """
     for node in gm.graph.nodes:
         if node.op == "call_method" and (
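The rest of the function body is truncated in the diff above. For context, a minimal sketch of what such a utility could look like is below; the view/reshape matching and the assumption that the head count is the third shape argument are illustrative guesses, not the committed implementation.

import torch
import torch.fx


def modify_view(gm: torch.fx.GraphModule, tp: int):
    """
    Adjust dimension size of view ops to make them compatible with tensor
    parallelism. For example, when TP is 4, we need to adjust `num_heads` from
    32 to 8. This is needed for attention layers.
    """
    for node in gm.graph.nodes:
        if node.op == "call_method" and (
            node.target == "view" or node.target == "reshape"
        ):
            # Assumption: attention reshapes have the form
            # x.view(bsz, seq_len, num_heads, head_dim), so node.args is
            # (self, bsz, seq_len, num_heads, head_dim) and the head count
            # sits at index 3.
            if len(node.args) > 3 and isinstance(node.args[3], int):
                node.update_arg(3, node.args[3] // tp)
    gm.recompile()

In the 2D (pipeline + tensor parallel) example this file belongs to, one would presumably call this once per traced pipeline stage, e.g. modify_view(stage_mod, tp=4) on a 2 x 4 device mesh built with init_device_mesh.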
