Commit 5958928

Add kwargs to mute many outputs; change world_size to 2
kwen2501 committed Jul 24, 2024
1 parent 1525384 commit 5958928
Showing 1 changed file with 9 additions and 2 deletions.
examples/llama/meta_init.py (11 changes: 9 additions & 2 deletions)
@@ -38,6 +38,7 @@
     "meta-llama/Llama-2-7b-chat-hf"
 )
 
+llama.eval()
 print(llama)
 
 # Cast the model to FakeTensor with real device (from meta device) because
@@ -60,7 +61,7 @@
 
 # Beginning of distributed
 # [Note 2]: change world size here
-world_size = 4
+world_size = 2
 print(f"{world_size=}")
 
 # Cut model by equal number of layers per rank
@@ -72,7 +73,12 @@
 }
 
 # Convert model into a pipeline
-pipe = pipeline(llama, mb_args=(fake_ids,), split_spec=split_spec)
+pipe = pipeline(
+    llama,
+    mb_args=(fake_ids,),
+    mb_kwargs={"output_attentions": False, "output_hidden_states": False, "use_cache": False,},
+    split_spec=split_spec,
+)
 
 # Materialize each stage
 # [Note 3]: remove this for loop if you are running this script in a
@@ -81,4 +87,5 @@
     stage_module = pipe.get_stage_module(rank)
     print(f"Loading weights into stage {rank}")
     load_weights(stage_module)
+    stage_module.print_readable()

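For context (not part of the commit): the keys passed in mb_kwargs are standard Hugging Face LlamaForCausalLM.forward arguments. Supplying them as example micro-batch kwargs means the traced pipeline stages skip the optional attention, hidden-state, and KV-cache outputs, so each stage only carries activations and, at the end, logits. A minimal standalone sketch of the muting effect, assuming the same checkpoint is available locally and using a plain (non-pipelined) forward pass:

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

# Same checkpoint as the example script (assumes access to the gated repo)
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id)
model.eval()

input_ids = tokenizer("Hello, world", return_tensors="pt").input_ids

with torch.no_grad():
    out = model(
        input_ids,
        output_attentions=False,     # out.attentions stays None
        output_hidden_states=False,  # out.hidden_states stays None
        use_cache=False,             # out.past_key_values stays None
    )

# Only the logits remain populated, which keeps the traced pipeline's
# per-stage outputs small and simple.
print(out.logits.shape)  # (1, seq_len, vocab_size)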