use jumpstart deployment config image as default optimization image
gwang111 committed Jan 16, 2025
1 parent a58654e commit 1104baf
Showing 4 changed files with 332 additions and 7 deletions.
113 changes: 111 additions & 2 deletions src/sagemaker/serve/builder/jumpstart_builder.py
@@ -17,7 +17,7 @@
import re
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Type, Any, List, Dict, Optional
from typing import Type, Any, List, Dict, Optional, Tuple
import logging

from botocore.exceptions import ClientError
@@ -82,6 +82,7 @@
ModelServer.DJL_SERVING,
ModelServer.TGI,
}
_JS_MINIMUM_VERSION_IMAGE = "{}:0.31.0-lmi13.0.0-cu124"

logger = logging.getLogger(__name__)
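
The new _JS_MINIMUM_VERSION_IMAGE constant is a template: its placeholder is filled with the DLC repository URI at runtime (see _get_default_vllm_image below). A minimal sketch of that formatting, using a hypothetical repository URI:

# Hypothetical repository URI for illustration; at runtime the repository
# portion comes from self.pysdk_model.init_kwargs["image_uri"].
repo = "123456789012.dkr.ecr.us-west-2.amazonaws.com/djl-inference"
minimum_image = "{}:0.31.0-lmi13.0.0-cu124".format(repo)
# -> ".../djl-inference:0.31.0-lmi13.0.0-cu124"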

@@ -829,7 +830,13 @@ def _optimize_for_jumpstart(
self.pysdk_model._enable_network_isolation = False

if quantization_config or sharding_config or is_compilation:
return create_optimization_job_args
# only apply the default optimization image for vLLM use cases;
# vLLM does not support compilation yet, so skip compilation jobs
return (
create_optimization_job_args
if is_compilation
else self._set_optimization_image_default(create_optimization_job_args)
)
return None

def _is_gated_model(self, model=None) -> bool:
@@ -986,3 +993,105 @@ def _get_neuron_model_env_vars(
)
return job_model.env
return None

def _set_optimization_image_default(
self, create_optimization_job_args: Dict[str, Any]
) -> Dict[str, Any]:
"""Defaults the optimization image to the JumpStart deployment config default
Args:
create_optimization_job_args (Dict[str, Any]): create optimization job request
Returns:
Dict[str, Any]: create optimization job request with the image URI defaulted
"""
default_image = self._get_default_vllm_image(self.pysdk_model.init_kwargs["image_uri"])

# find the latest LMI image version among any user-provided images
for optimization_config in create_optimization_job_args.get("OptimizationConfigs"):
if optimization_config.get("ModelQuantizationConfig"):
model_quantization_config = optimization_config.get("ModelQuantizationConfig")
provided_image = model_quantization_config.get("Image")
if provided_image and self._get_latest_lmi_version_from_list(
default_image, provided_image
):
default_image = provided_image
if optimization_config.get("ModelShardingConfig"):
model_sharding_config = optimization_config.get("ModelShardingConfig")
provided_image = model_sharding_config.get("Image")
if provided_image and self._get_latest_lmi_version_from_list(
default_image, provided_image
):
default_image = provided_image

# apply the latest image version found to every config
for optimization_config in create_optimization_job_args.get("OptimizationConfigs"):
if optimization_config.get("ModelQuantizationConfig") is not None:
optimization_config.get("ModelQuantizationConfig")["Image"] = default_image
if optimization_config.get("ModelShardingConfig") is not None:
optimization_config.get("ModelShardingConfig")["Image"] = default_image

logger.info("Defaulting to %s image for optimization job", default_image)

return create_optimization_job_args
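
To make the defaulting pass concrete, the following hypothetical request fragment shows its effect; the keys mirror the OptimizationConfigs shape used above, while the image URI and environment values are illustrative only:

# Hypothetical create-optimization-job fragment, for illustration only.
create_args = {
    "OptimizationConfigs": [
        {"ModelQuantizationConfig": {"OverrideEnvironment": {"OPTION_QUANTIZE": "fp8"}}},
        {"ModelShardingConfig": {"Image": "repo/djl-inference:0.32.0-lmi14.0.0-cu124"}},
    ]
}
# After _set_optimization_image_default(create_args), both configs carry the
# same Image: the newest LMI version among the provided images (lmi14 here),
# since it compares greater than or equal to the lmi13 minimum default.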

def _get_default_vllm_image(self, image: str) -> str:
"""Ensures the minimum working image version for vLLM-enabled optimization techniques
Args:
image (str): JumpStart provided default image
Returns:
str: minimum working image version
"""
dlc_name, _ = image.split(":")
major_version_number, _, _ = self._parse_lmi_version(image)

if major_version_number < self._parse_lmi_version(_JS_MINIMUM_VERSION_IMAGE)[0]:
minimum_version_default = _JS_MINIMUM_VERSION_IMAGE.format(dlc_name)
return minimum_version_default
return image

def _get_latest_lmi_version_from_list(self, version: str, version_to_compare: str) -> bool:
"""LMI version comparator
Args:
version (str): current version
version_to_compare (str): version to compare to
Returns:
bool: True if version_to_compare is greater than or equal to version
"""
parse_lmi_version = self._parse_lmi_version(version)
parse_lmi_version_to_compare = self._parse_lmi_version(version_to_compare)

# Check major version
if parse_lmi_version_to_compare[0] > parse_lmi_version[0]:
return True
# Check minor version
if parse_lmi_version_to_compare[0] == parse_lmi_version[0]:
if parse_lmi_version_to_compare[1] > parse_lmi_version[1]:
return True
if parse_lmi_version_to_compare[1] == parse_lmi_version[1]:
# Check patch version
if parse_lmi_version_to_compare[2] >= parse_lmi_version[2]:
return True
return False
return False
return False

def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]:
"""Parse out LMI version
Args:
image (str): image to parse version out of
Returns:
Tuple[int, int, int]: LMI version split into major, minor, patch
"""
_, dlc_tag = image.split(":")  # e.g. "0.31.0-lmi13.0.0-cu124"
_, lmi_version, _ = dlc_tag.split("-")  # e.g. "lmi13.0.0"
major_version, minor_version, patch_version = lmi_version.split(".")
major_version_number = major_version[3:]  # strip the "lmi" prefix

return (int(major_version_number), int(minor_version), int(patch_version))
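
Taken together, the two helpers above implement a lexicographic comparison over (major, minor, patch). A standalone sketch of the expected behavior, using hypothetical image URIs:

# Standalone re-implementation of the parsing logic, for illustration only.
def parse_lmi_version(image):
    _, dlc_tag = image.split(":")  # e.g. "0.31.0-lmi13.0.0-cu124"
    _, lmi_version, _ = dlc_tag.split("-")  # e.g. "lmi13.0.0"
    major, minor, patch = lmi_version.split(".")
    return (int(major[3:]), int(minor), int(patch))  # strip the "lmi" prefix

newer = parse_lmi_version("repo/djl-inference:0.32.0-lmi14.1.0-cu124")  # (14, 1, 0)
older = parse_lmi_version("repo/djl-inference:0.31.0-lmi13.0.0-cu124")  # (13, 0, 0)
# Python's tuple ordering matches _get_latest_lmi_version_from_list:
assert newer >= older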
18 changes: 18 additions & 0 deletions tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py
@@ -32,6 +32,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
iam_client = sagemaker_session.boto_session.client("iam")
role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]

sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()

schema_builder = SchemaBuilder("test", "test")
model_builder = ModelBuilder(
model="meta-textgeneration-llama-3-1-8b-instruct",
@@ -50,6 +52,8 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
accept_eula=True,
)

assert not sagemaker_session.sagemaker_client.create_optimization_job.called

optimized_model.deploy()

mock_create_model.assert_called_once_with(
@@ -126,6 +130,13 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_
accept_eula=True,
)

assert (
sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
"OptimizationConfigs"
][0]["ModelShardingConfig"]["Image"]
is not None
)

optimized_model.deploy(
resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8})
)
@@ -206,6 +217,13 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are
accept_eula=True,
)

assert (
sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][
"OptimizationConfigs"
][0]["ModelQuantizationConfig"]["Image"]
is not None
)

optimized_model.deploy()

mock_create_model.assert_called_once_with(