
Commit

Apply suggestions from code review
- Capture `container_uri` from the environment variable before running the test,
  and remove the default value to prevent issues when testing
- Remove `num_train_epochs=-1`, as it is not required since `max_steps` is
  already specified
- Rename `test_transformers` to `test_huggingface_inference_toolkit`
alvarobartt committed Sep 2, 2024
1 parent 7c4bf87 commit 3af2bcf
Showing 4 changed files with 26 additions and 27 deletions.
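
The first bullet makes every test fail fast when its DLC environment variable is
missing or empty, instead of silently falling back to a possibly stale default
image URI. A minimal sketch of that guard pattern as applied across the tests
below (`SOME_DLC` and `test_example` are placeholder names; the real tests use
`INFERENCE_DLC`, `TRAINING_DLC`, `TEI_DLC`, and `TGI_DLC`):

    import logging
    import os

    import pytest


    def test_example(caplog: pytest.LogCaptureFixture) -> None:
        caplog.set_level(logging.INFO)

        # Fail fast if the image URI is not provided; there is no default fallback.
        container_uri = os.getenv("SOME_DLC", None)
        if container_uri is None or container_uri == "":
            assert False, "SOME_DLC environment variable is not set"

        # ...the container is then started from `container_uri`...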
13 changes: 6 additions & 7 deletions tests/pytorch/inference/test_huggingface_inference_toolkit.py
@@ -48,24 +48,23 @@
         ),
     ],
 )
-def test_transformers(
+def test_huggingface_inference_toolkit(
     caplog: pytest.LogCaptureFixture,
     hf_model_id: str,
     hf_task: str,
     prediction_payload: dict,
 ) -> None:
     caplog.set_level(logging.INFO)
 
+    container_uri = os.getenv("INFERENCE_DLC", None)
+    if container_uri is None or container_uri == "":
+        assert False, "INFERENCE_DLC environment variable is not set"
+
     client = docker.from_env()
 
     logging.info(f"Starting container for {hf_model_id}...")
     container = client.containers.run(
-        os.getenv(
-            "INFERENCE_DLC",
-            "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-inference-cpu.2-2.transformers.4-44.ubuntu2204.py311"
-            if not CUDA_AVAILABLE
-            else "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-inference-cu121.2-2.transformers.4-44.ubuntu2204.py311",
-        ),
+        container_uri,
         ports={"8080": 8080},
         environment={
             "HF_MODEL_ID": hf_model_id,
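
For context on the replacement, `docker-py`'s `containers.run` takes the image
reference as its first positional argument, so the inline `os.getenv(...)`
expression with a hardcoded default is swapped for the pre-validated
`container_uri`. A hypothetical sketch of the call shape (the model ID and
`detach=True` are illustrative assumptions, not taken from the diff):

    import os

    import docker

    container_uri = os.environ["INFERENCE_DLC"]  # raises KeyError if unset
    client = docker.from_env()
    container = client.containers.run(
        container_uri,
        ports={"8080": 8080},  # container port -> host port
        environment={"HF_MODEL_ID": "distilbert-base-uncased"},  # illustrative model
        detach=True,  # assumption: run in the background so the test can poll it
    )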
20 changes: 10 additions & 10 deletions tests/pytorch/training/test_trl.py
@@ -19,14 +19,15 @@ def test_trl(caplog: pytest.LogCaptureFixture, tmp_path: PosixPath) -> None:
     """Adapted from https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py"""
     caplog.set_level(logging.INFO)
 
+    container_uri = os.getenv("TRAINING_DLC", None)
+    if container_uri is None or container_uri == "":
+        assert False, "TRAINING_DLC environment variable is not set"
+
     client = docker.from_env()
 
     logging.info("Running the container for TRL...")
     container = client.containers.run(
-        os.getenv(
-            "TRAINING_DLC",
-            "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-training-cu121.2-3.transformers.4-42.ubuntu2204.py310",
-        ),
+        container_uri,
         command=[
             "trl",
             "sft",
@@ -38,7 +39,6 @@ def test_trl(caplog: pytest.LogCaptureFixture, tmp_path: PosixPath) -> None:
             "--gradient_accumulation_steps=1",
             "--output_dir=/opt/huggingface/trained_model",
             "--logging_steps=1",
-            "--num_train_epochs=-1",
             "--max_steps=10",
             "--gradient_checkpointing",
         ],
@@ -81,14 +81,15 @@ def test_trl_peft(caplog: pytest.LogCaptureFixture, tmp_path: PosixPath) -> None:
     """Adapted from https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py"""
     caplog.set_level(logging.INFO)
 
+    container_uri = os.getenv("TRAINING_DLC", None)
+    if container_uri is None or container_uri == "":
+        assert False, "TRAINING_DLC environment variable is not set"
+
     client = docker.from_env()
 
     logging.info("Running the container for TRL...")
     container = client.containers.run(
-        os.getenv(
-            "TRAINING_DLC",
-            "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-training-cu121.2-3.transformers.4-42.ubuntu2204.py310",
-        ),
+        container_uri,
         command=[
             "trl",
             "sft",
@@ -100,7 +101,6 @@ def test_trl_peft(caplog: pytest.LogCaptureFixture, tmp_path: PosixPath) -> None:
             "--gradient_accumulation_steps=1",
             "--output_dir=/opt/huggingface/trained_model",
             "--logging_steps=1",
-            "--num_train_epochs=-1",
             "--max_steps=10",
             "--gradient_checkpointing",
             "--use_peft",
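
Dropping `--num_train_epochs=-1` is safe because `max_steps` takes precedence:
in `transformers`' `TrainingArguments` (which the `trl sft` script builds its
configuration on), a positive `max_steps` overrides `num_train_epochs` entirely.
A small sketch of that behavior (the `output_dir` value is arbitrary):

    from transformers import TrainingArguments

    # With max_steps > 0, training stops after exactly that many optimizer steps;
    # num_train_epochs keeps its default (3.0) but is ignored, so -1 was redundant.
    args = TrainingArguments(output_dir="/tmp/trained_model", max_steps=10)
    print(args.max_steps)  # 10
    print(args.num_train_epochs)  # 3.0 (unused once max_steps is set)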
11 changes: 5 additions & 6 deletions tests/tei/test_tei.py
@@ -33,18 +33,17 @@ def test_text_embeddings_inference(
 ) -> None:
     caplog.set_level(logging.INFO)
 
+    container_uri = os.getenv("TEI_DLC", None)
+    if container_uri is None or container_uri == "":
+        assert False, "TEI_DLC environment variable is not set"
+
     client = docker.from_env()
 
     logging.info(
         f"Starting container for {text_embeddings_router_kwargs.get('MODEL_ID', None)}..."
     )
     container = client.containers.run(
-        os.getenv(
-            "TEI_DLC",
-            "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-text-embeddings-inference-cpu.1-2"
-            if not CUDA_AVAILABLE
-            else "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-text-embeddings-inference-cu122.1-4.ubuntu2204",
-        ),
+        container_uri,
         # TODO: update once the TEI DLC is updated, as the current one is still on revision:
         # https://github.com/huggingface/Google-Cloud-Containers/blob/517b8728725f6249774dcd46ee8d7ede8d95bb70/containers/tei/cpu/1.2.2/Dockerfile
         # and it exposes port 80 and uses the /data directory instead of /tmp
9 changes: 5 additions & 4 deletions tests/tgi/test_tgi.py
@@ -42,6 +42,10 @@ def test_text_generation_inference(
 ) -> None:
     caplog.set_level(logging.INFO)
 
+    container_uri = os.getenv("TGI_DLC", None)
+    if container_uri is None or container_uri == "":
+        assert False, "TGI_DLC environment variable is not set"
+
     client = docker.from_env()
 
     # If the GPU compute capability is lower than 8.0 (Ampere), then set `USE_FLASH_ATTENTION=false`
@@ -56,10 +60,7 @@
         f"Starting container for {text_generation_launcher_kwargs.get('MODEL_ID', None)}..."
     )
     container = client.containers.run(
-        os.getenv(
-            "TGI_DLC",
-            "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-text-generation-inference-cu121.2-2.ubuntu2204.py310",
-        ),
+        container_uri,
         ports={8080: 8080},
         environment=text_generation_launcher_kwargs,
         healthcheck={
