Skip to content

Make batch deployment subclasses GA #40619

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@

from marshmallow import INCLUDE, fields, post_load

from azure.ai.ml._schema import (
ArmVersionedStr,
ArmStr,
UnionField,
RegistryStr,
NestedField,
from azure.ai.ml._schema import ArmStr, ArmVersionedStr, NestedField, RegistryStr, UnionField
from azure.ai.ml._schema.core.fields import (
PathAwareSchema,
PipelineNodeNameStr,
StringTransformedEnum,
TypeSensitiveUnionField,
)
from azure.ai.ml._schema.core.fields import PipelineNodeNameStr, TypeSensitiveUnionField, PathAwareSchema
from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
from azure.ai.ml.constants._common import AzureMLResourceType
from azure.ai.ml.constants._component import NodeType
from azure.ai.ml.constants._deployment import BatchDeploymentType

module_logger = logging.getLogger(__name__)

Expand All @@ -35,7 +35,9 @@ class PipelineComponentBatchDeploymentSchema(PathAwareSchema):
)
settings = fields.Dict()
name = fields.Str()
type = fields.Str()
type = StringTransformedEnum(
allowed_values=[BatchDeploymentType.PIPELINE, BatchDeploymentType.MODEL], required=False
)
job_definition = UnionField(
[
ArmStr(azureml_type=AzureMLResourceType.JOB),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# pylint: disable=protected-access

import logging
import warnings
from os import PathLike
from pathlib import Path
from typing import Any, Dict, Optional, Union
Expand All @@ -27,13 +28,29 @@

from .code_configuration import CodeConfiguration
from .deployment import Deployment
from .model_batch_deployment_settings import ModelBatchDeploymentSettings as BatchDeploymentSettings

module_logger = logging.getLogger(__name__)

# Legacy BatchDeployment attributes that now live on the nested
# ModelBatchDeploymentSettings object (self._settings). __getattr__ and
# __setattr__ below delegate reads/writes of these names to self._settings
# to stay backwards-compatible with code that accessed them directly on
# the deployment instance.
SETTINGS_ATTRIBUTES = [
    "output_action",
    "output_file_name",
    "error_threshold",
    "retry_settings",
    "logging_level",
    "mini_batch_size",
    "max_concurrency_per_instance",
    "environment_variables",
]


class BatchDeployment(Deployment):
"""Batch endpoint deployment entity.

**Warning** This class should not be used directly.
Please use one of the child implementations, :class:`~azure.ai.ml.entities.ModelBatchDeployment` or
    :class:`~azure.ai.ml.entities.PipelineComponentBatchDeployment`.

:param name: the name of the batch deployment
:type name: str
:param description: Description of the resource.
Expand Down Expand Up @@ -112,34 +129,58 @@ def __init__(
instance_count: Optional[int] = None, # promoted property from resources.instance_count
**kwargs: Any,
) -> None:
_type = kwargs.pop("_type", None)
if _type is None:
warnings.warn(
"This class is intended as a base class and it's direct usage is deprecated. "
"Use one of the concrete implementations instead:\n"
"* ModelBatchDeployment - For model-based batch deployments\n"
"* PipelineComponentBatchDeployment - For pipeline component-based batch deployments"
)
self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None)

settings = kwargs.pop("settings", None)
super(BatchDeployment, self).__init__(
name=name,
type=_type,
endpoint_name=endpoint_name,
properties=properties,
tags=tags,
description=description,
model=model,
code_configuration=code_configuration,
environment=environment,
environment_variables=environment_variables,
environment_variables=environment_variables, # needed, otherwise Deployment.__init__() will set it to {}
code_path=code_path,
scoring_script=scoring_script,
**kwargs,
)

self.compute = compute
self.resources = resources
self.output_action = output_action
self.output_file_name = output_file_name
self.error_threshold = error_threshold
self.retry_settings = retry_settings
self.logging_level = logging_level
self.mini_batch_size = mini_batch_size
self.max_concurrency_per_instance = max_concurrency_per_instance

if self.resources and instance_count:

self._settings = (
settings
if settings
else BatchDeploymentSettings(
mini_batch_size=mini_batch_size,
instance_count=instance_count,
max_concurrency_per_instance=max_concurrency_per_instance,
output_action=output_action,
output_file_name=output_file_name,
retry_settings=retry_settings,
environment_variables=environment_variables,
error_threshold=error_threshold,
logging_level=logging_level,
)
)

self._setup_instance_count()

def _setup_instance_count(
self,
) -> None: # No need to check instance_count here as it's already set in self._settings during initialization
if self.resources and self._settings.instance_count:
msg = "Can't set instance_count when resources is provided."
raise ValidationException(
message=msg,
Expand All @@ -149,8 +190,26 @@ def __init__(
error_type=ValidationErrorType.INVALID_VALUE,
)

if not self.resources and instance_count:
self.resources = ResourceConfiguration(instance_count=instance_count)
if not self.resources and self._settings.instance_count:
self.resources = ResourceConfiguration(instance_count=self._settings.instance_count)

def __getattr__(self, name: str) -> Optional[Any]:
# Support backwards compatibility with old BatchDeployment properties.
if name in SETTINGS_ATTRIBUTES:
try:
return getattr(self._settings, name)
except AttributeError:
pass
return super().__getattribute__(name)

def __setattr__(self, name, value):
# Support backwards compatibility with old BatchDeployment properties.
if name in SETTINGS_ATTRIBUTES:
try:
setattr(self._settings, name, value)
except AttributeError:
pass
super().__setattr__(name, value)

@property
def instance_count(self) -> Optional[int]:
Expand Down Expand Up @@ -195,7 +254,7 @@ def _yaml_output_action_to_rest_output_action(cls, yaml_output_action: Any) -> s
return output_switcher.get(yaml_output_action, yaml_output_action)

# pylint: disable=arguments-differ
def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore
def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore[override]
self._validate()
code_config = (
RestCodeConfiguration(
Expand All @@ -209,42 +268,28 @@ def _to_rest_object(self, location: str) -> BatchDeploymentData: # type: ignore
environment = self.environment

batch_deployment: RestBatchDeployment = None
if isinstance(self.output_action, str):
batch_deployment = RestBatchDeployment(
compute=self.compute,
description=self.description,
resources=self.resources._to_rest_object() if self.resources else None,
code_configuration=code_config,
environment_id=environment,
model=model,
output_file_name=self.output_file_name,
output_action=BatchDeployment._yaml_output_action_to_rest_output_action(self.output_action),
error_threshold=self.error_threshold,
retry_settings=self.retry_settings._to_rest_object() if self.retry_settings else None,
logging_level=self.logging_level,
mini_batch_size=self.mini_batch_size,
max_concurrency_per_instance=self.max_concurrency_per_instance,
environment_variables=self.environment_variables,
properties=self.properties,
)
else:
batch_deployment = RestBatchDeployment(
compute=self.compute,
description=self.description,
resources=self.resources._to_rest_object() if self.resources else None,
code_configuration=code_config,
environment_id=environment,
model=model,
output_file_name=self.output_file_name,
output_action=None,
error_threshold=self.error_threshold,
retry_settings=self.retry_settings._to_rest_object() if self.retry_settings else None,
logging_level=self.logging_level,
mini_batch_size=self.mini_batch_size,
max_concurrency_per_instance=self.max_concurrency_per_instance,
environment_variables=self.environment_variables,
properties=self.properties,
)
# Create base RestBatchDeployment object with common properties
batch_deployment = RestBatchDeployment(
compute=self.compute,
description=self.description,
resources=self.resources._to_rest_object() if self.resources else None,
code_configuration=code_config,
environment_id=environment,
model=model,
output_file_name=self.output_file_name,
output_action=(
BatchDeployment._yaml_output_action_to_rest_output_action(self.output_action)
if isinstance(self.output_action, str)
else None
),
error_threshold=self.error_threshold,
retry_settings=self.retry_settings._to_rest_object() if self.retry_settings else None,
logging_level=self.logging_level,
mini_batch_size=self.mini_batch_size,
max_concurrency_per_instance=self.max_concurrency_per_instance,
environment_variables=self.environment_variables,
properties=self.properties,
)

return BatchDeploymentData(location=location, properties=batch_deployment, tags=self.tags)

Expand Down
Loading