diff --git a/examples/llama/2d_llama.py b/examples/llama/2d_llama.py
index f3764e68f..a3e66e054 100644
--- a/examples/llama/2d_llama.py
+++ b/examples/llama/2d_llama.py
@@ -27,7 +27,8 @@ def modify_view(
 
 # Grab the model
 llama = AutoModelForCausalLM.from_pretrained(
-    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
+    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
+    torch_dtype=torch.float16
 )
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 
@@ -46,8 +47,8 @@ def modify_view(
 mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
 pp_group = mesh_2d["pp"].get_group()
 
-llama.to(device).eval()
-inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
+llama.eval()
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 
 # Cut model by equal number of layers per rank
 layers_per_stage = llama.config.num_hidden_layers // pp_group_size
@@ -90,7 +91,7 @@ def modify_view(
 parallelize_module(
     stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
 )
-
+inputs = inputs.to(device)
 # Run
 if stage_idx == 0:
     args = inputs["input_ids"]
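
The loading pattern the diff switches to can be seen in isolation. Below is a minimal sketch (not the full 2D-parallel example; the prompt list, device index, and pipeline/tensor-parallel setup are assumed): load the checkpoint in fp16 on the host, keep the model and inputs on CPU, and move the inputs to the GPU only right before running, mirroring where the diff moves `inputs.to(device)`.

```python
# Hedged sketch of the load-on-CPU / move-late pattern used in the diff.
# Assumes transformers and torch are installed and a CUDA device is available.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-chat-hf"
llama = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,     # avoid materializing an extra full copy on the host
    torch_dtype=torch.float16,  # load weights in fp16 instead of the default fp32
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Llama has no pad token by default

llama.eval()  # model stays on CPU here; sharding/placement happens later

prompts = ["Hello, world!"]  # placeholder prompts for illustration
inputs = tokenizer(prompts, return_tensors="pt", padding=True)  # still on CPU

device = torch.device("cuda", 0)  # assumed local device
inputs = inputs.to(device)        # moved only just before execution
```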