Commit

adding deferred init as Ke advised
HamidShojanazeri committed Jan 2, 2024
1 parent f3daad1 commit cf74348
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions examples/llama/2d_llama.py
@@ -27,7 +27,8 @@ def modify_view(
 
 # Grab the model
 llama = AutoModelForCausalLM.from_pretrained(
-    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
+    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
+    torch_dtype=torch.float16
 )
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 
@@ -46,8 +47,8 @@ def modify_view(
 mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
 pp_group = mesh_2d["pp"].get_group()
 
-llama.to(device).eval()
-inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
+llama.eval()
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 
 # Cut model by equal number of layers per rank
 layers_per_stage = llama.config.num_hidden_layers // pp_group_size
@@ -90,7 +91,7 @@ def modify_view(
 parallelize_module(
     stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
 )
-
+inputs = inputs.to(device)
 # Run
 if stage_idx == 0:
     args = inputs["input_ids"]
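The net effect of the change is a deferred-initialization ordering: the fp16 checkpoint is loaded on CPU (low_cpu_mem_usage=True, torch_dtype=torch.float16) and stays there while the 2D device mesh is built, the model is cut into pipeline stages, and tensor parallelism is applied to each stage's submodule; only after that are the tokenized inputs moved to the local GPU. Below is a minimal sketch of that ordering under stated assumptions: it assumes a torchrun launch with 8 GPUs, a placeholder prompt, and a pad-token workaround for the Llama tokenizer; the stage-splitting and TP-plan code between steps 2 and 4 is elided, and the import path for init_device_mesh follows recent PyTorch releases and may differ by version. It is an illustration of the pattern, not the full example script.

# Deferred-init sketch (assumed values; not the full examples/llama/2d_llama.py).
import torch
from torch.distributed.device_mesh import init_device_mesh  # location varies by PyTorch version
from transformers import AutoModelForCausalLM, AutoTokenizer

# 1) Load the fp16 checkpoint on CPU only; nothing touches the GPU yet.
llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
llama.eval()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token        # Llama's tokenizer has no pad token by default
prompts = ["How do you make pancakes?"]          # assumed placeholder prompt
inputs = tokenizer(prompts, return_tensors="pt", padding=True)  # still on CPU

# 2) Build the 2D mesh: pipeline-parallel dim x tensor-parallel dim
#    (assumed 2 x 4 layout; run under torchrun with WORLD_SIZE = pp * tp).
pp_group_size, tp_group_size = 2, 4
mesh_2d = init_device_mesh(
    "cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp")
)

# 3) ... cut the model into pipeline stages and apply parallelize_module(...)
#    to each stage's submodule while the weights are still on CPU ...

# 4) Only now move the (small) tokenized inputs to the local GPU and run the stage.
device = torch.device("cuda", torch.cuda.current_device())
inputs = inputs.to(device)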
