Skip to content

Commit

Permalink
Fix offload generate tests (#3334)
Browse files · Browse the repository at this point in the history
* Fix tests

* format
  • Loading branch information
SunMarc authored Jan 13, 2025
1 parent 95f34d6 commit 8c423cf
Showing 1 changed file with 10 additions and 25 deletions.
35 changes: 10 additions & 25 deletions tests/test_big_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,8 @@ def test_cpu_offload_gpt2(self):

gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
cpu_offload(gpt2, execution_device=0)
outputs = gpt2.generate(inputs["input_ids"])
assert (
tokenizer.decode(outputs[0].tolist())
== "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo"
)
outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
assert tokenizer.decode(outputs[0].tolist()) == "Hello world! My name is Kiyoshi, and I'm a student at"

def test_disk_offload(self):
model = ModelForTest()
Expand Down Expand Up @@ -301,11 +298,8 @@ def test_disk_offload_gpt2(self):
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
with TemporaryDirectory() as tmp_dir:
disk_offload(gpt2, tmp_dir, execution_device=0)
outputs = gpt2.generate(inputs["input_ids"])
assert (
tokenizer.decode(outputs[0].tolist())
== "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo"
)
outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
assert tokenizer.decode(outputs[0].tolist()) == "Hello world! My name is Kiyoshi, and I'm a student at"

@require_non_cpu
def test_dispatch_model_and_remove_hook(self):
Expand Down Expand Up @@ -686,22 +680,16 @@ def test_dispatch_model_gpt2_on_two_devices(self):
device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1

gpt2 = dispatch_model(gpt2, device_map)
outputs = gpt2.generate(inputs["input_ids"])
assert (
tokenizer.decode(outputs[0].tolist())
== "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo"
)
outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
assert tokenizer.decode(outputs[0].tolist()) == "Hello world! My name is Kiyoshi, and I'm a student at"

# Dispatch with a bit of CPU offload
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
for i in range(4):
device_map[f"transformer.h.{i}"] = "cpu"
gpt2 = dispatch_model(gpt2, device_map)
outputs = gpt2.generate(inputs["input_ids"])
assert (
tokenizer.decode(outputs[0].tolist())
== "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo"
)
outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
assert tokenizer.decode(outputs[0].tolist()) == "Hello world! My name is Kiyoshi, and I'm a student at"
# Dispatch with a bit of CPU and disk offload
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
for i in range(2):
Expand All @@ -713,11 +701,8 @@ def test_dispatch_model_gpt2_on_two_devices(self):
}
offload_state_dict(tmp_dir, state_dict)
gpt2 = dispatch_model(gpt2, device_map, offload_dir=tmp_dir)
outputs = gpt2.generate(inputs["input_ids"])
assert (
tokenizer.decode(outputs[0].tolist())
== "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo"
)
outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
assert tokenizer.decode(outputs[0].tolist()) == "Hello world! My name is Kiyoshi, and I'm a student at"

@require_non_cpu
def test_dispatch_model_with_unused_submodules(self):
Expand Down

0 comments on commit 8c423cf

Please sign in to comment.