Skip to content

Commit

Permalink
2D working without TP self attention
Browse files Browse the repository at this point in the history
  • Loading branch information
kwen2501 committed Dec 21, 2023
1 parent ea7d1d6 commit 654028f
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions examples/llama/2d_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# $ torchrun --nproc-per-node 8 2d_llama.py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
from torch.distributed._tensor import init_device_mesh

# Grab the model
llama = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

prompts = (
"How do you", "I like to", "Can I help", "You need to",
"The weather is", "I found a", "What is your", "You are so",
) # bs = 8
tokenizer.pad_token = tokenizer.eos_token

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

pp_group_size = 2
tp_group_size = 4
mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
pp_group = mesh_2d["pp"].get_group()

llama.to(device).eval()
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)

# Cut model by equal number of layers per rank
layers_per_stage = llama.config.num_hidden_layers // pp_group_size
for i in range(1, pp_group_size):
annotate_split_points(llama,
{f"model.layers.{i * layers_per_stage}": PipeSplitWrapper.SplitPoint.BEGINNING})

# Create a pipeline representation from the model
llama_pipe = Pipe.from_tracing(llama, pp_group_size, example_args=(inputs["input_ids"],))

# Create pipeline stage for each rank
stage_idx = rank // tp_group_size
stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)

# Tensor parallel
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
starting_layer = stage_idx * layers_per_stage
plan = {}
for i in range(layers_per_stage):
# HACK: the right fix is to remove the ".mod" added by PipeSplitWrapper
extra = "_mod" if starting_layer > 0 and i == 0 else ""
layer_name = f"L__self___model_layers_{starting_layer + i}{extra}"
plan.update({
# Parallel self attention not working yet due to the dimension mismatch
# after TP in view operation
#f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
#f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
#f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
#f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
f"{layer_name}_mlp_up_proj": ColwiseParallel(),
f"{layer_name}_mlp_down_proj": RowwiseParallel(),
})
tp_mesh = mesh_2d["tp"]
parallelize_module(stage.submod, tp_mesh, plan)

# Run
if stage_idx == 0:
args = inputs["input_ids"]
else:
args = None
output = stage(args)

# Decode
if output is not None:
next_token_logits = output[0][:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
print(tokenizer.batch_decode(next_token))

0 comments on commit 654028f

Please sign in to comment.