feat: use jumpstart deployment config image as default optimization image #4992

Merged (6 commits) · Jan 29, 2025
113 changes: 111 additions & 2 deletions src/sagemaker/serve/builder/jumpstart_builder.py
@@ -17,7 +17,7 @@
 import re
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
-from typing import Type, Any, List, Dict, Optional
+from typing import Type, Any, List, Dict, Optional, Tuple
 import logging
 
 from botocore.exceptions import ClientError
@@ -82,6 +82,7 @@
     ModelServer.DJL_SERVING,
     ModelServer.TGI,
 }
+_JS_MINIMUM_VERSION_IMAGE = "{}:0.31.0-lmi13.0.0-cu124"
 
 logger = logging.getLogger(__name__)
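
For orientation: the `{}` placeholder in `_JS_MINIMUM_VERSION_IMAGE` is later filled with the repository half of the JumpStart default image URI (everything before the `:`, see `_get_default_vllm_image` below), so the constant pins only the tag. A quick illustration with a made-up repository name:

    >>> _JS_MINIMUM_VERSION_IMAGE.format("<account>.dkr.ecr.us-west-2.amazonaws.com/djl-inference")
    '<account>.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124'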

@@ -829,7 +830,13 @@
             self.pysdk_model._enable_network_isolation = False
 
         if quantization_config or sharding_config or is_compilation:
-            return create_optimization_job_args
+            # Only apply the default image for vLLM use cases.
+            # vLLM does not support compilation for now, so skip on compilation.
+            return (
+                create_optimization_job_args
+                if is_compilation
+                else self._set_optimization_image_default(create_optimization_job_args)
+            )
         return None
 
     def _is_gated_model(self, model=None) -> bool:
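
Note that the conditional can read inverted at first glance: when `is_compilation` is true the request passes through untouched, and only quantization/sharding requests get the image default applied. A schematic restatement (hypothetical free-standing function, not SDK code; `builder` stands in for the JumpStart builder instance):

    def route_optimization_args(builder, create_optimization_job_args, is_compilation):
        if is_compilation:
            # compilation: keep whatever image the request already carries
            return create_optimization_job_args
        # quantization / sharding: apply the vLLM image default
        return builder._set_optimization_image_default(create_optimization_job_args)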
@@ -986,3 +993,105 @@
             )
             return job_model.env
         return None
+
+    def _set_optimization_image_default(
+        self, create_optimization_job_args: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Defaults the optimization image to the JumpStart deployment config default
+
+        Args:
+            create_optimization_job_args (Dict[str, Any]): create optimization job request
+
+        Returns:
+            Dict[str, Any]: create optimization job request with image uri default
+        """
+        default_image = self._get_default_vllm_image(self.pysdk_model.init_kwargs["image_uri"])
+
+        # find the latest vLLM image version
+        for optimization_config in create_optimization_job_args.get("OptimizationConfigs"):
+            if optimization_config.get("ModelQuantizationConfig"):
+                model_quantization_config = optimization_config.get("ModelQuantizationConfig")
+                provided_image = model_quantization_config.get("Image")
+                if provided_image and self._get_latest_lmi_version_from_list(
+                    default_image, provided_image
+                ):
+                    default_image = provided_image
+            if optimization_config.get("ModelShardingConfig"):
+                model_sharding_config = optimization_config.get("ModelShardingConfig")
+                provided_image = model_sharding_config.get("Image")
+                if provided_image and self._get_latest_lmi_version_from_list(
+                    default_image, provided_image
+                ):
+                    default_image = provided_image
+
+        # default to latest vLLM version
+        for optimization_config in create_optimization_job_args.get("OptimizationConfigs"):
+            if optimization_config.get("ModelQuantizationConfig") is not None:
+                optimization_config.get("ModelQuantizationConfig")["Image"] = default_image
+            if optimization_config.get("ModelShardingConfig") is not None:
+                optimization_config.get("ModelShardingConfig")["Image"] = default_image
+
+        logger.info("Defaulting to %s image for optimization job", default_image)
+
+        return create_optimization_job_args
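
To make the two passes concrete, here is a dry run with hypothetical image tags (the real default comes from the JumpStart deployment config):

    # Hypothetical request: sharding pinned to a newer LMI image, quantization unset.
    create_args = {
        "OptimizationConfigs": [
            {"ModelShardingConfig": {"Image": "djl-inference:0.32.0-lmi14.0.0-cu124"}},
            {"ModelQuantizationConfig": {}},
        ]
    }
    # Pass 1: the provided lmi14 tag out-versions the lmi13 default, so it wins.
    # Pass 2: every quantization/sharding config is stamped with the winner, so
    # both entries end up on "djl-inference:0.32.0-lmi14.0.0-cu124".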

+    def _get_default_vllm_image(self, image: str) -> str:
+        """Ensures the image meets the minimum working version for vLLM-enabled optimization techniques
+
+        Args:
+            image (str): JumpStart provided default image
+
+        Returns:
+            str: image at or above the minimum working version
+        """
+        dlc_name, _ = image.split(":")
+        major_version_number, _, _ = self._parse_lmi_version(image)
+
+        if major_version_number < self._parse_lmi_version(_JS_MINIMUM_VERSION_IMAGE)[0]:
+            minimum_version_default = _JS_MINIMUM_VERSION_IMAGE.format(dlc_name)
+            return minimum_version_default
+        return image
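
Note that the floor compares only the major version: an lmi12 tag is lifted to the lmi13.0.0 minimum (keeping the caller's repository name), while lmi13-or-newer images pass through unchanged. For example (repository names made up):

    # "djl-inference:0.30.0-lmi12.0.0-cu124" -> "djl-inference:0.31.0-lmi13.0.0-cu124"
    # "djl-inference:0.33.0-lmi15.0.0-cu125" -> returned unchanged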

+    def _get_latest_lmi_version_from_list(self, version: str, version_to_compare: str) -> bool:
+        """LMI version comparator
+
+        Args:
+            version (str): current version
+            version_to_compare (str): version to compare against
+
+        Returns:
+            bool: True if version_to_compare is greater than or equal to version
+        """
+        parse_lmi_version = self._parse_lmi_version(version)
+        parse_lmi_version_to_compare = self._parse_lmi_version(version_to_compare)
+
+        # Check major version
+        if parse_lmi_version_to_compare[0] > parse_lmi_version[0]:
+            return True
+        # Check minor version
+        if parse_lmi_version_to_compare[0] == parse_lmi_version[0]:
+            if parse_lmi_version_to_compare[1] > parse_lmi_version[1]:
+                return True
+            if parse_lmi_version_to_compare[1] == parse_lmi_version[1]:
+                # Check patch version
+                if parse_lmi_version_to_compare[2] >= parse_lmi_version[2]:
+                    return True
+                return False
+            return False
+        return False

+    def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]:
+        """Parse out LMI version
+
+        Args:
+            image (str): image to parse version out of
+
+        Returns:
+            Tuple[int, int, int]: LMI version split into major, minor, patch
+        """
+        _, dlc_tag = image.split(":")
+        _, lmi_version, _ = dlc_tag.split("-")
+        major_version, minor_version, patch_version = lmi_version.split(".")
+        # strip the "lmi" prefix, e.g. "lmi13" -> "13"
+        major_version_number = major_version[3:]
+
+        return (int(major_version_number), int(minor_version), int(patch_version))
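
Putting the helpers together, a self-contained sketch of the tag handling, runnable without the SDK; the image strings are invented but follow the "{repo}:{dlc}-lmi{major}.{minor}.{patch}-cu{cuda}" shape the parser assumes:

    from typing import Tuple

    def parse_lmi_version(image: str) -> Tuple[int, int, int]:
        _, dlc_tag = image.split(":")            # "0.31.0-lmi13.0.0-cu124"
        _, lmi_version, _ = dlc_tag.split("-")   # "lmi13.0.0"
        major, minor, patch = lmi_version.split(".")
        return (int(major[3:]), int(minor), int(patch))  # drop the "lmi" prefix

    assert parse_lmi_version("djl-inference:0.31.0-lmi13.0.0-cu124") == (13, 0, 0)
    # The cascading major/minor/patch checks above reduce to a lexicographic
    # tuple comparison:
    assert parse_lmi_version("djl-inference:0.32.0-lmi14.1.2-cu124") >= (13, 0, 0)

One grounded caveat: the parser assumes exactly two "-" separators in the tag, so a tag of a different shape (say, with an extra suffix) would raise a ValueError from tuple unpacking.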
18 changes: 18 additions & 0 deletions tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py
@@ -32,6 +32,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
     iam_client = sagemaker_session.boto_session.client("iam")
     role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
 
+    sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()
+
     schema_builder = SchemaBuilder("test", "test")
     model_builder = ModelBuilder(
         model="meta-textgeneration-llama-3-1-8b-instruct",
@@ -50,6 +52,8 @@
         accept_eula=True,
     )
 
+    assert not sagemaker_session.sagemaker_client.create_optimization_job.called
+
     optimized_model.deploy()
 
     mock_create_model.assert_called_once_with(
@@ -126,6 +130,13 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_
         accept_eula=True,
    )
 
+    assert (
+        sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
+            "OptimizationConfigs"
+        ][0]["ModelShardingConfig"]["Image"]
+        is not None
+    )
+
     optimized_model.deploy(
         resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8})
     )
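
The indexing in the assertion above follows unittest.mock's call record: `call_args_list[0]` is the first recorded call, and `[1]` selects its keyword-argument dict. A minimal standalone demonstration:

    from unittest.mock import MagicMock

    client = MagicMock()
    client.create_optimization_job(OptimizationConfigs=[{"ModelShardingConfig": {"Image": "img"}}])
    kwargs = client.create_optimization_job.call_args_list[0][1]
    assert kwargs["OptimizationConfigs"][0]["ModelShardingConfig"]["Image"] == "img"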
@@ -206,6 +217,13 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are
         accept_eula=True,
     )
 
+    assert (
+        sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
+            "OptimizationConfigs"
+        ][0]["ModelQuantizationConfig"]["Image"]
+        is not None
+    )
+
     optimized_model.deploy()
 
     mock_create_model.assert_called_once_with(