Skip to content

Make batch deployment subclasses GA #40619

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

### Other Changes
- Hub and Project are officially GA'd and no longer experimental.
- PipelineComponentBatchDeployment, ModelBatchDeployment, and ModelBatchDeploymentSettings are officially GA'd and no longer experimental.

## 1.26.4 (2025-04-23)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@

from marshmallow import fields, post_load

from azure.ai.ml._schema._deployment.deployment import DeploymentSchema
from azure.ai.ml._schema.core.fields import ComputeField, NestedField, StringTransformedEnum
from azure.ai.ml._schema.job_resource_configuration import JobResourceConfigurationSchema
from azure.ai.ml._schema._deployment.deployment import DeploymentSchema
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
from azure.ai.ml.constants._deployment import BatchDeploymentType
from azure.ai.ml._schema import ExperimentalField
from .model_batch_deployment_settings import ModelBatchDeploymentSettingsSchema

from .model_batch_deployment_settings import ModelBatchDeploymentSettingsSchema

module_logger = logging.getLogger(__name__)

Expand All @@ -37,7 +36,7 @@ class ModelBatchDeploymentSchema(DeploymentSchema):
allowed_values=[BatchDeploymentType.PIPELINE, BatchDeploymentType.MODEL], required=False
)

settings = ExperimentalField(NestedField(ModelBatchDeploymentSettingsSchema))
settings = NestedField(ModelBatchDeploymentSettingsSchema)

@post_load
def make(self, data: Any, **kwargs: Any) -> Any:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@

from marshmallow import INCLUDE, fields, post_load

from azure.ai.ml._schema import (
ArmVersionedStr,
ArmStr,
UnionField,
RegistryStr,
NestedField,
from azure.ai.ml._schema import ArmStr, ArmVersionedStr, NestedField, RegistryStr, UnionField
from azure.ai.ml._schema.core.fields import (
PathAwareSchema,
PipelineNodeNameStr,
StringTransformedEnum,
TypeSensitiveUnionField,
)
from azure.ai.ml._schema.core.fields import PipelineNodeNameStr, TypeSensitiveUnionField, PathAwareSchema
from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
from azure.ai.ml.constants._common import AzureMLResourceType
from azure.ai.ml.constants._component import NodeType
from azure.ai.ml.constants._deployment import BatchDeploymentType

module_logger = logging.getLogger(__name__)

Expand All @@ -35,7 +35,9 @@ class PipelineComponentBatchDeploymentSchema(PathAwareSchema):
)
settings = fields.Dict()
name = fields.Str()
type = fields.Str()
type = StringTransformedEnum(
allowed_values=[BatchDeploymentType.PIPELINE, BatchDeploymentType.MODEL], required=False
)
job_definition = UnionField(
[
ArmStr(azureml_type=AzureMLResourceType.JOB),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# pylint: disable=protected-access

import logging
import warnings
from os import PathLike
from pathlib import Path
from typing import Any, Dict, Optional, Union
Expand All @@ -27,13 +28,29 @@

from .code_configuration import CodeConfiguration
from .deployment import Deployment
from .model_batch_deployment_settings import ModelBatchDeploymentSettings as BatchDeploymentSettings

module_logger = logging.getLogger(__name__)

# Legacy flat attributes that used to live directly on BatchDeployment and now
# live on the nested ModelBatchDeploymentSettings object (self._settings).
# __getattr__/__setattr__ forward reads and writes of these names to the
# settings object so existing callers keep working after the refactor.
SETTINGS_ATTRIBUTES = [
    "output_action",
    "output_file_name",
    "error_threshold",
    "retry_settings",
    "logging_level",
    "mini_batch_size",
    "max_concurrency_per_instance",
    "environment_variables",
]


class BatchDeployment(Deployment):
"""Batch endpoint deployment entity.

**Warning** This class should not be used directly.
Please use one of the child implementations, :class:`~azure.ai.ml.entities.ModelBatchDeployment` or
:class:`~azure.ai.ml.entities.PipelineComponentBatchDeployment`.

:param name: the name of the batch deployment
:type name: str
:param description: Description of the resource.
Expand Down Expand Up @@ -112,34 +129,61 @@ def __init__(
instance_count: Optional[int] = None, # promoted property from resources.instance_count
**kwargs: Any,
) -> None:
_type = kwargs.pop("_type", None)

# Suppresses deprecation warning when object is created from REST responses
# This is needed to avoid false deprecation warning on model batch deployment
if _type is None and not kwargs.pop("_from_rest", False):
warnings.warn(
"This class is intended as a base class and its direct usage is deprecated. "
"Use one of the concrete implementations instead:\n"
"* ModelBatchDeployment - For model-based batch deployments\n"
"* PipelineComponentBatchDeployment - For pipeline component-based batch deployments"
)
self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None)

settings = kwargs.pop("settings", None)
super(BatchDeployment, self).__init__(
name=name,
type=_type,
endpoint_name=endpoint_name,
properties=properties,
tags=tags,
description=description,
model=model,
code_configuration=code_configuration,
environment=environment,
environment_variables=environment_variables,
environment_variables=environment_variables, # needed, otherwise Deployment.__init__() will set it to {}
code_path=code_path,
scoring_script=scoring_script,
**kwargs,
)

self.compute = compute
self.resources = resources
self.output_action = output_action
self.output_file_name = output_file_name
self.error_threshold = error_threshold
self.retry_settings = retry_settings
self.logging_level = logging_level
self.mini_batch_size = mini_batch_size
self.max_concurrency_per_instance = max_concurrency_per_instance

if self.resources and instance_count:

self._settings = (
settings
if settings
else BatchDeploymentSettings(
mini_batch_size=mini_batch_size,
instance_count=instance_count,
max_concurrency_per_instance=max_concurrency_per_instance,
output_action=output_action,
output_file_name=output_file_name,
retry_settings=retry_settings,
environment_variables=environment_variables,
error_threshold=error_threshold,
logging_level=logging_level,
)
)

self._setup_instance_count()

def _setup_instance_count(
self,
) -> None: # No need to check instance_count here as it's already set in self._settings during initialization
if self.resources and self._settings.instance_count:
msg = "Can't set instance_count when resources is provided."
raise ValidationException(
message=msg,
Expand All @@ -149,8 +193,26 @@ def __init__(
error_type=ValidationErrorType.INVALID_VALUE,
)

if not self.resources and instance_count:
self.resources = ResourceConfiguration(instance_count=instance_count)
if not self.resources and self._settings.instance_count:
self.resources = ResourceConfiguration(instance_count=self._settings.instance_count)

def __getattr__(self, name: str) -> Optional[Any]:
    """Resolve legacy flat deployment attributes against the nested settings object.

    Only invoked when normal attribute lookup fails. Names listed in
    ``SETTINGS_ATTRIBUTES`` are forwarded to ``self._settings`` for backwards
    compatibility with the old flat ``BatchDeployment`` surface; everything
    else (and any forwarding failure) falls back to the default lookup.
    """
    if name not in SETTINGS_ATTRIBUTES:
        return super().__getattribute__(name)
    try:
        return getattr(self._settings, name)
    except AttributeError:
        # _settings may not exist yet (e.g. during construction); use default lookup.
        return super().__getattribute__(name)

def __setattr__(self, name, value):
    """Mirror writes of legacy flat attributes onto the nested settings object.

    Names in ``SETTINGS_ATTRIBUTES`` are written to ``self._settings`` (best
    effort — ignored if the settings object is not available yet) and then,
    like every other name, stored on the instance via the default mechanism.
    """
    if name not in SETTINGS_ATTRIBUTES:
        super().__setattr__(name, value)
        return
    try:
        setattr(self._settings, name, value)
    except AttributeError:
        # _settings not created yet (e.g. during construction); skip mirroring.
        pass
    super().__setattr__(name, value)

@property
def instance_count(self) -> Optional[int]:
Expand Down Expand Up @@ -195,7 +257,7 @@ def _yaml_output_action_to_rest_output_action(cls, yaml_output_action: Any) -> s
return output_switcher.get(yaml_output_action, yaml_output_action)

# pylint: disable=arguments-differ
def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore
def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore[override]
self._validate()
code_config = (
RestCodeConfiguration(
Expand All @@ -209,42 +271,28 @@ def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore
environment = self.environment

batch_deployment: RestBatchDeployment = None
if isinstance(self.output_action, str):
batch_deployment = RestBatchDeployment(
compute=self.compute,
description=self.description,
resources=self.resources._to_rest_object() if self.resources else None,
code_configuration=code_config,
environment_id=environment,
model=model,
output_file_name=self.output_file_name,
output_action=BatchDeployment._yaml_output_action_to_rest_output_action(self.output_action),
error_threshold=self.error_threshold,
retry_settings=self.retry_settings._to_rest_object() if self.retry_settings else None,
logging_level=self.logging_level,
mini_batch_size=self.mini_batch_size,
max_concurrency_per_instance=self.max_concurrency_per_instance,
environment_variables=self.environment_variables,
properties=self.properties,
)
else:
batch_deployment = RestBatchDeployment(
compute=self.compute,
description=self.description,
resources=self.resources._to_rest_object() if self.resources else None,
code_configuration=code_config,
environment_id=environment,
model=model,
output_file_name=self.output_file_name,
output_action=None,
error_threshold=self.error_threshold,
retry_settings=self.retry_settings._to_rest_object() if self.retry_settings else None,
logging_level=self.logging_level,
mini_batch_size=self.mini_batch_size,
max_concurrency_per_instance=self.max_concurrency_per_instance,
environment_variables=self.environment_variables,
properties=self.properties,
)
# Create base RestBatchDeployment object with common properties
batch_deployment = RestBatchDeployment(
compute=self.compute,
description=self.description,
resources=self.resources._to_rest_object() if self.resources else None,
code_configuration=code_config,
environment_id=environment,
model=model,
output_file_name=self.output_file_name,
output_action=(
BatchDeployment._yaml_output_action_to_rest_output_action(self.output_action)
if isinstance(self.output_action, str)
else None
),
error_threshold=self.error_threshold,
retry_settings=self.retry_settings._to_rest_object() if self.retry_settings else None,
logging_level=self.logging_level,
mini_batch_size=self.mini_batch_size,
max_concurrency_per_instance=self.max_concurrency_per_instance,
environment_variables=self.environment_variables,
properties=self.properties,
)

return BatchDeploymentData(location=location, properties=batch_deployment, tags=self.tags)

Expand Down Expand Up @@ -306,6 +354,7 @@ def _from_rest_object( # pylint: disable=arguments-renamed
properties=properties,
creation_context=SystemData._from_rest_object(deployment.system_data),
provisioning_state=deployment.properties.provisioning_state,
_from_rest=True,
)

return deployment
Expand Down
Loading