Add validations to block compilation jobs using TRTLLM and Llama-3.1.
Joseph Zhang committed Sep 19, 2024
1 parent 54e995f commit 828ad60
Showing 2 changed files with 73 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/sagemaker/serve/builder/model_builder.py
@@ -1276,6 +1276,28 @@ def _model_builder_optimize_wrapper(
)

if input_args:
    optimization_instance_type = input_args["DeploymentInstanceType"]

    # Compilation using TRTLLM and Llama-3.1 is currently not supported.
    # TRTLLM is used by Neo if the following are provided:
    #   1) a GPU instance type
    #   2) a compilation config
    gpu_instance_families = ["g4", "g5", "p4d"]
    is_gpu_instance = optimization_instance_type and any(
        gpu_instance_family in optimization_instance_type
        for gpu_instance_family in gpu_instance_families
    )

    # HF model ID format: "meta-llama/Meta-Llama-3.1-8B"
    # JS model ID format: "meta-textgeneration-llama-3-1-8b"
    llama_3_1_keywords = ["llama-3.1", "llama-3-1"]
    is_llama_3_1 = self.model and any(
        keyword in self.model.lower() for keyword in llama_3_1_keywords
    )

    # is_llama_3_1 already implies self.model is set, so no separate check is needed.
    if is_gpu_instance and is_llama_3_1 and self.is_compiled:
        raise ValueError("Compilation is not supported for Llama-3.1 with a GPU instance.")

    self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
    job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
    return _generate_optimized_model(self.pysdk_model, job_status)
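The guard reduces to two substring checks combined with the compiled flag. A minimal standalone sketch of the same predicate (the helper name and the assert-based usage are illustrative, not part of this commit; in the SDK the check runs inline in _model_builder_optimize_wrapper):

# Minimal sketch of the validation above, pulled out as a standalone helper.
# The function name is hypothetical; in the SDK the check runs inline.
def _blocks_trtllm_llama_3_1(instance_type, model_id, is_compiled):
    gpu_instance_families = ["g4", "g5", "p4d"]
    is_gpu_instance = instance_type and any(
        family in instance_type for family in gpu_instance_families
    )
    llama_3_1_keywords = ["llama-3.1", "llama-3-1"]
    is_llama_3_1 = model_id and any(
        keyword in model_id.lower() for keyword in llama_3_1_keywords
    )
    return bool(is_gpu_instance and is_llama_3_1 and is_compiled)

# Both the HF and JumpStart ID formats trip the guard on a GPU instance:
assert _blocks_trtllm_llama_3_1("ml.g5.2xlarge", "meta-llama/Meta-Llama-3.1-8B", True)
assert _blocks_trtllm_llama_3_1("ml.g5.2xlarge", "meta-textgeneration-llama-3-1-8b", True)
# Llama 3 (not 3.1), a non-GPU instance, or a non-compilation job all pass:
assert not _blocks_trtllm_llama_3_1("ml.g5.2xlarge", "meta-llama/Meta-Llama-3-8B", True)
assert not _blocks_trtllm_llama_3_1("ml.trn1.2xlarge", "meta-llama/Meta-Llama-3.1-8B", True)
assert not _blocks_trtllm_llama_3_1("ml.g5.2xlarge", "meta-llama/Meta-Llama-3.1-8B", False)

Substring matching keeps the guard tolerant of instance-size suffixes (ml.g5.2xlarge, ml.g5.48xlarge) and of both model-ID naming schemes, at the cost of blocking any ID that merely contains one of the keywords.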
51 changes: 51 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -2840,3 +2840,54 @@ def test_optimize_for_hf_without_custom_s3_path(
"OutputConfig": {"S3OutputLocation": "s3://bucket/code/"},
},
)

    @patch.object(ModelBuilder, "_prepare_for_mode")
    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
    def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation(
        self,
        mock_get_serve_setting,
        mock_prepare_for_mode,
    ):
        mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
            {
                "S3DataSource": {
                    "CompressionType": "None",
                    "S3DataType": "S3Prefix",
                    "S3Uri": "s3://bucket/code/code/",
                }
            },
            {"DTYPE": "bfloat16"},
        )

        mock_pysdk_model = Mock()
        mock_pysdk_model.model_data = None
        mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-1-8B-Instruct"}

        sample_input = {"inputs": "dummy prompt", "parameters": {}}
        sample_output = [{"generated_text": "dummy response"}]
        dummy_schema_builder = SchemaBuilder(sample_input, sample_output)

        model_builder = ModelBuilder(
            model="meta-llama/Meta-Llama-3-1-8B-Instruct",
            schema_builder=dummy_schema_builder,
            env_vars={"HF_TOKEN": "token"},
            model_metadata={
                "CUSTOM_MODEL_PATH": "s3://bucket/path/",
            },
            role_arn="role-arn",
            instance_type="ml.g5.2xlarge",
        )
        model_builder.pysdk_model = mock_pysdk_model

        self.assertRaisesRegex(
            ValueError,
            "Compilation is not supported for Llama-3.1 with a GPU instance.",
            lambda: model_builder.optimize(
                job_name="job_name-123",
                compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
                output_path="s3://bucket/code/",
            ),
        )
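For contrast, a sketch of a call the guard lets through: a quantization-only job never sets is_compiled, so execution proceeds to create_optimization_job. The quantization_config parameter and the OPTION_QUANTIZE value are shown as plausible inputs, not taken from this commit, and a real unit test of this path would also need to mock the SageMaker session, since execution continues past the guard.

        # Sketch only: with a quantization config and no compilation config,
        # the Llama-3.1/TRTLLM guard does not fire and the Neo job is created.
        # quantization_config / OPTION_QUANTIZE are assumptions, not from this diff.
        model_builder.optimize(
            job_name="job_name-123",
            quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
            output_path="s3://bucket/code/",
        )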
