feat(tests): added test showing gemma7b sharding and prefill works
tengomucho committed Apr 9, 2024
1 parent 550e1fb commit 2215595
Showing 2 changed files with 35 additions and 4 deletions.
21 changes: 21 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,21 @@
import pytest

# See https://stackoverflow.com/a/61193490/217945 for run_slow
def pytest_addoption(parser):
    parser.addoption(
        "--runslow", action="store_true", default=False, help="run slow tests"
    )


def pytest_configure(config):
    config.addinivalue_line("markers", "slow: mark test as slow to run")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--runslow"):
        # --runslow given in cli: do not skip slow tests
        return
    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
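The hooks above register a "slow" marker and, unless pytest is invoked with --runslow, attach a skip marker to every test that carries it. A minimal sketch of how a test opts in (the test name here is hypothetical, not part of this commit):

import pytest


@pytest.mark.slow
def test_something_expensive():
    # Collected but skipped by default: pytest_collection_modifyitems in
    # conftest.py adds a skip marker unless --runslow is passed.
    assert True

A plain pytest run reports such a test as skipped with "need --runslow option to run"; pytest --runslow executes it.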
18 changes: 14 additions & 4 deletions tests/test_distributed_model.py
@@ -2,6 +2,7 @@
 from optimum.tpu.distributed_model import DistributedModel
 from transformers import AutoTokenizer
 import torch
+import pytest
 
 
 def sample_greedy(logits):
@@ -10,10 +11,9 @@ def sample_greedy(logits):
     return next_token_id
 
 
-def test_distributed_model_prefill():
-    # This model will not actually shard gpt2, but it ensures model can be loaded in a parallel way and
+def _test_distributed_model_prefill(model_id):
+    # This test ensures model can be loaded in a parallel way and
     # that the "proxy" distributed model can be used to prefill the model.
-    model_id = "openai-community/gpt2"
     # Disable tokenizers parallelism to avoid deadlocks
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -29,16 +29,26 @@ def test_distributed_model_prefill():
     tokens = torch.cat([tokens, next_tokens], dim=-1)
 
     # Data can be decoded even before leaving
-    decoded_texts = tokenizer.batch_decode(tokens)
+    decoded_texts = tokenizer.batch_decode(tokens, skip_special_tokens=True)
     print()
     print("------------------------------------------")
     print("Decoded texts:")
     print(decoded_texts[0])
     print("------------------------------------------")
+    # Even if models are different, for this simple test results are the same.
     expected_text = "Running something in parallel means that"
     assert expected_text == decoded_texts[0]
 
 
+def test_distributed_model_prefill_gpt2():
+    _test_distributed_model_prefill("openai-community/gpt2")
+
+
+@pytest.mark.slow
+def test_distributed_model_prefill_gemma7b():
+    _test_distributed_model_prefill("google/gemma-7b")
+
+
 def test_distributed_model_config():
     model_id = "openai-community/gpt2"
     model = DistributedModel(model_id, sample_greedy)
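With the conftest change in place, the new Gemma 7B test stays out of a default run and has to be selected explicitly. A sketch of one way to invoke it through pytest's Python API, equivalent to passing --runslow and -k gemma7b on the command line (the path assumes the repository layout shown in this commit):

import pytest

# A default run skips the slow test; --runslow includes it and -k narrows
# collection to the Gemma 7B prefill test added above.
pytest.main(["tests/test_distributed_model.py", "--runslow", "-k", "gemma7b"])

Actually running it against google/gemma-7b assumes the host has enough TPU devices and memory to shard the model, which is presumably why the test is marked slow.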
