Modify view ops to make them compatible with TP
kwen2501 committed Dec 28, 2023
1 parent d2e8e62 commit e55eef4
Showing 1 changed file with 31 additions and 7 deletions.
examples/llama/2d_llama.py: 38 changes (31 additions & 7 deletions)
@@ -5,6 +5,23 @@
 from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
 from torch.distributed._tensor import init_device_mesh
 
+
+def modify_view(
+    gm: torch.fx.GraphModule,
+    tp: int
+):
+    """
+    Adjust dimension size of view ops to make them compatible with tensor parallelism.
+    """
+    for node in gm.graph.nodes:
+        if node.op == "call_method" and (
+            node.target == "view" or node.target == "reshape"
+        ):
+            assert len(node.args) >= 4
+            node.update_arg(3, node.args[3] // tp)
+    gm.recompile()
+
+
 # Grab the model
 llama = AutoModelForCausalLM.from_pretrained(
     "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
@@ -42,27 +59,34 @@
 stage_idx = rank // tp_group_size
 stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)
 
+modify_view(stage.submod, tp_group_size)
+
 # Tensor parallel
 from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
 starting_layer = stage_idx * layers_per_stage
-plan = {}
+attn_plan = {}
+mlp_plan = {}
 for i in range(layers_per_stage):
     # HACK: the right fix is to remove the ".mod" added by PipeSplitWrapper
     extra = "_mod" if starting_layer > 0 and i == 0 else ""
     layer_name = f"L__self___model_layers_{starting_layer + i}{extra}"
-    plan.update({
+    attn_plan.update({
         # Parallel self attention not working yet due to the dimension mismatch
         # after TP in view operation
-        #f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
+        f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
+    })
+    mlp_plan.update({
         f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
         f"{layer_name}_mlp_up_proj": ColwiseParallel(),
         f"{layer_name}_mlp_down_proj": RowwiseParallel(),
     })
 tp_mesh = mesh_2d["tp"]
-parallelize_module(stage.submod, tp_mesh, plan)
+parallelize_module(
+    stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
+)
 
 # Run
 if stage_idx == 0:
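Taken together: ColwiseParallel on q/k/v leaves each TP rank with only num_heads // tp_group_size heads' worth of output features, which is why modify_view is applied to stage.submod before parallelize_module, and RowwiseParallel on o_proj consumes the sharded heads and all-reduces the partial results. A quick shape check with illustrative Llama-2-7B sizes, not part of the commit:

# Illustrative shape arithmetic, assuming Llama-2-7B sizes (hidden=4096,
# 32 heads of dim 128) and a TP degree of 2.
hidden, num_heads, head_dim, tp = 4096, 32, 128, 2

# ColwiseParallel shards q/k/v output features: each rank produces hidden // tp.
local_features = hidden // tp                         # 2048

# The original view expects num_heads * head_dim = 4096 features per rank ...
assert num_heads * head_dim == hidden
# ... while the rewritten view(bsz, q_len, num_heads // tp, head_dim) matches the shard.
assert (num_heads // tp) * head_dim == local_features

# RowwiseParallel on o_proj then takes the 2048 sharded features as its local
# input and all-reduces the partial outputs back to the full hidden size.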
