Skip to content

Commit a5515c5

Browse files
feat(ingestion/SageMaker): Remove deprecated apis and add stateful ingestion capability (datahub-project#10573)
1 parent d559656 commit a5515c5

File tree

6 files changed

+112
-362
lines changed

6 files changed

+112
-362
lines changed

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py

+18-4
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from collections import defaultdict
2-
from typing import DefaultDict, Dict, Iterable
2+
from typing import DefaultDict, Dict, Iterable, List, Optional
33

44
from datahub.ingestion.api.common import PipelineContext
55
from datahub.ingestion.api.decorators import (
@@ -10,7 +10,7 @@
1010
platform_name,
1111
support_status,
1212
)
13-
from datahub.ingestion.api.source import Source
13+
from datahub.ingestion.api.source import MetadataWorkUnitProcessor
1414
from datahub.ingestion.api.workunit import MetadataWorkUnit
1515
from datahub.ingestion.source.aws.sagemaker_processors.common import (
1616
SagemakerSourceConfig,
@@ -26,13 +26,19 @@
2626
)
2727
from datahub.ingestion.source.aws.sagemaker_processors.lineage import LineageProcessor
2828
from datahub.ingestion.source.aws.sagemaker_processors.models import ModelProcessor
29+
from datahub.ingestion.source.state.stale_entity_removal_handler import (
30+
StaleEntityRemovalHandler,
31+
)
32+
from datahub.ingestion.source.state.stateful_ingestion_base import (
33+
StatefulIngestionSourceBase,
34+
)
2935

3036

3137
@platform_name("SageMaker")
3238
@config_class(SagemakerSourceConfig)
3339
@support_status(SupportStatus.CERTIFIED)
3440
@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
35-
class SagemakerSource(Source):
41+
class SagemakerSource(StatefulIngestionSourceBase):
3642
"""
3743
This plugin extracts the following:
3844
@@ -45,7 +51,7 @@ class SagemakerSource(Source):
4551
report = SagemakerSourceReport()
4652

4753
def __init__(self, config: SagemakerSourceConfig, ctx: PipelineContext):
48-
super().__init__(ctx)
54+
super().__init__(config, ctx)
4955
self.source_config = config
5056
self.report = SagemakerSourceReport()
5157
self.sagemaker_client = config.sagemaker_client
@@ -56,6 +62,14 @@ def create(cls, config_dict, ctx):
5662
config = SagemakerSourceConfig.parse_obj(config_dict)
5763
return cls(config, ctx)
5864

65+
def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
66+
return [
67+
*super().get_workunit_processors(),
68+
StaleEntityRemovalHandler.create(
69+
self, self.source_config, self.ctx
70+
).workunit_processor,
71+
]
72+
5973
def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
6074
# get common lineage graph
6175
lineage_processor = LineageProcessor(

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py

+20-5
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,22 @@
1-
from dataclasses import dataclass
2-
from typing import Dict, Optional, Union
1+
from dataclasses import dataclass, field
2+
from typing import Dict, List, Optional, Union
33

44
from pydantic.fields import Field
55

6-
from datahub.ingestion.api.source import SourceReport
6+
from datahub.configuration.source_common import PlatformInstanceConfigMixin
77
from datahub.ingestion.source.aws.aws_common import AwsSourceConfig
8+
from datahub.ingestion.source.state.stale_entity_removal_handler import (
9+
StaleEntityRemovalSourceReport,
10+
StatefulIngestionConfigBase,
11+
StatefulStaleMetadataRemovalConfig,
12+
)
813

914

10-
class SagemakerSourceConfig(AwsSourceConfig):
15+
class SagemakerSourceConfig(
16+
AwsSourceConfig,
17+
PlatformInstanceConfigMixin,
18+
StatefulIngestionConfigBase,
19+
):
1120
extract_feature_groups: Optional[bool] = Field(
1221
default=True, description="Whether to extract feature groups."
1322
)
@@ -17,21 +26,24 @@ class SagemakerSourceConfig(AwsSourceConfig):
1726
extract_jobs: Optional[Union[Dict[str, str], bool]] = Field(
1827
default=True, description="Whether to extract AutoML jobs."
1928
)
29+
# Custom Stateful Ingestion settings
30+
stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None
2031

2132
@property
2233
def sagemaker_client(self):
2334
return self.get_sagemaker_client()
2435

2536

2637
@dataclass
27-
class SagemakerSourceReport(SourceReport):
38+
class SagemakerSourceReport(StaleEntityRemovalSourceReport):
2839
feature_groups_scanned = 0
2940
features_scanned = 0
3041
endpoints_scanned = 0
3142
groups_scanned = 0
3243
models_scanned = 0
3344
jobs_scanned = 0
3445
datasets_scanned = 0
46+
filtered: List[str] = field(default_factory=list)
3547

3648
def report_feature_group_scanned(self) -> None:
3749
self.feature_groups_scanned += 1
@@ -53,3 +65,6 @@ def report_job_scanned(self) -> None:
5365

5466
def report_dataset_scanned(self) -> None:
5567
self.datasets_scanned += 1
68+
69+
def report_dropped(self, name: str) -> None:
70+
self.filtered.append(name)

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/job_classes.py

+6-25
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,6 @@
1-
from typing import Dict, Final
1+
from typing import Dict
2+
3+
from typing_extensions import Final
24

35
from datahub.metadata.schema_classes import JobStatusClass
46

@@ -63,8 +65,9 @@ class AutoMlJobInfo(SageMakerJobInfo):
6365
list_key: Final = "AutoMLJobSummaries"
6466
list_name_key: Final = "AutoMLJobName"
6567
list_arn_key: Final = "AutoMLJobArn"
66-
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
67-
describe_command: Final = "describe_auto_ml_job"
68+
# DescribeAutoMLJobV2 is the new version of DescribeAutoMLJob and offers backward compatibility.
69+
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeAutoMLJobV2.html
70+
describe_command: Final = "describe_auto_ml_job_v2"
6871
describe_name_key: Final = "AutoMLJobName"
6972
describe_arn_key: Final = "AutoMLJobArn"
7073
describe_status_key: Final = "AutoMLJobStatus"
@@ -101,28 +104,6 @@ class CompilationJobInfo(SageMakerJobInfo):
101104
processor = "process_compilation_job"
102105

103106

104-
class EdgePackagingJobInfo(SageMakerJobInfo):
105-
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_edge_packaging_jobs
106-
list_command: Final = "list_edge_packaging_jobs"
107-
list_key: Final = "EdgePackagingJobSummaries"
108-
list_name_key: Final = "EdgePackagingJobName"
109-
list_arn_key: Final = "EdgePackagingJobArn"
110-
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_edge_packaging_job
111-
describe_command: Final = "describe_edge_packaging_job"
112-
describe_name_key: Final = "EdgePackagingJobName"
113-
describe_arn_key: Final = "EdgePackagingJobArn"
114-
describe_status_key: Final = "EdgePackagingJobStatus"
115-
status_map = {
116-
"INPROGRESS": JobStatusClass.IN_PROGRESS,
117-
"COMPLETED": JobStatusClass.COMPLETED,
118-
"FAILED": JobStatusClass.FAILED,
119-
"STARTING": JobStatusClass.STARTING,
120-
"STOPPING": JobStatusClass.STOPPING,
121-
"STOPPED": JobStatusClass.STOPPED,
122-
}
123-
processor = "process_edge_packaging_job"
124-
125-
126107
class HyperParameterTuningJobInfo(SageMakerJobInfo):
127108
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_hyper_parameter_tuning_jobs
128109
list_command: Final = "list_hyper_parameter_tuning_jobs"

metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py

+15-87
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@
2626
from datahub.ingestion.source.aws.sagemaker_processors.job_classes import (
2727
AutoMlJobInfo,
2828
CompilationJobInfo,
29-
EdgePackagingJobInfo,
3029
HyperParameterTuningJobInfo,
3130
LabelingJobInfo,
3231
ProcessingJobInfo,
@@ -53,7 +52,6 @@
5352
"JobInfo",
5453
AutoMlJobInfo,
5554
CompilationJobInfo,
56-
EdgePackagingJobInfo,
5755
HyperParameterTuningJobInfo,
5856
LabelingJobInfo,
5957
ProcessingJobInfo,
@@ -65,7 +63,6 @@
6563
class JobType(Enum):
6664
AUTO_ML = "auto_ml"
6765
COMPILATION = "compilation"
68-
EDGE_PACKAGING = "edge_packaging"
6966
HYPER_PARAMETER_TUNING = "hyper_parameter_tuning"
7067
LABELING = "labeling"
7168
PROCESSING = "processing"
@@ -78,7 +75,6 @@ class JobType(Enum):
7875
job_type_to_info: Mapping[JobType, Any] = {
7976
JobType.AUTO_ML: AutoMlJobInfo(),
8077
JobType.COMPILATION: CompilationJobInfo(),
81-
JobType.EDGE_PACKAGING: EdgePackagingJobInfo(),
8278
JobType.HYPER_PARAMETER_TUNING: HyperParameterTuningJobInfo(),
8379
JobType.LABELING: LabelingJobInfo(),
8480
JobType.PROCESSING: ProcessingJobInfo(),
@@ -416,23 +412,20 @@ def process_auto_ml_job(self, job: Dict[str, Any]) -> SageMakerJob:
416412
"""
417413
Process outputs from Boto3 describe_auto_ml_job_v2()
418414
419-
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
415+
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/describe_auto_ml_job_v2.html
420416
"""
421417

422418
JOB_TYPE = JobType.AUTO_ML
423419

424420
input_datasets = {}
425-
426-
for input_config in job.get("InputDataConfig", []):
421+
for input_config in job.get("AutoMLJobInputDataConfig", []):
427422
input_data = input_config.get("DataSource", {}).get("S3DataSource")
428-
429423
if input_data is not None and "S3Uri" in input_data:
430424
input_datasets[make_s3_urn(input_data["S3Uri"], self.env)] = {
431425
"dataset_type": "s3",
432426
"uri": input_data["S3Uri"],
433427
"datatype": input_data.get("S3DataType"),
434428
}
435-
436429
output_datasets = {}
437430

438431
output_s3_path = job.get("OutputDataConfig", {}).get("S3OutputPath")
@@ -448,6 +441,18 @@ def process_auto_ml_job(self, job: Dict[str, Any]) -> SageMakerJob:
448441
JOB_TYPE,
449442
)
450443

444+
metrics: Dict[str, Any] = {}
445+
# Get job metrics from CandidateMetrics
446+
candidate_metrics = (
447+
job.get("BestCandidate", {})
448+
.get("CandidateProperties", {})
449+
.get("CandidateMetrics", [])
450+
)
451+
if candidate_metrics:
452+
metrics = {
453+
metric["MetricName"]: metric["Value"] for metric in candidate_metrics
454+
}
455+
451456
model_containers = job.get("BestCandidate", {}).get("InferenceContainers", [])
452457

453458
for model_container in model_containers:
@@ -456,7 +461,7 @@ def process_auto_ml_job(self, job: Dict[str, Any]) -> SageMakerJob:
456461
if model_data_url is not None:
457462
job_key = JobKey(job_snapshot.urn, JobDirection.TRAINING)
458463

459-
self.update_model_image_jobs(model_data_url, job_key)
464+
self.update_model_image_jobs(model_data_url, job_key, metrics=metrics)
460465

461466
return SageMakerJob(
462467
job_name=job_name,
@@ -515,83 +520,6 @@ def process_compilation_job(self, job: Dict[str, Any]) -> SageMakerJob:
515520
output_datasets=output_datasets,
516521
)
517522

518-
def process_edge_packaging_job(
519-
self,
520-
job: Dict[str, Any],
521-
) -> SageMakerJob:
522-
"""
523-
Process outputs from Boto3 describe_edge_packaging_job()
524-
525-
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_edge_packaging_job
526-
"""
527-
528-
JOB_TYPE = JobType.EDGE_PACKAGING
529-
530-
name: str = job["EdgePackagingJobName"]
531-
arn: str = job["EdgePackagingJobArn"]
532-
533-
output_datasets = {}
534-
535-
model_artifact_s3_uri: Optional[str] = job.get("ModelArtifact")
536-
output_s3_uri: Optional[str] = job.get("OutputConfig", {}).get(
537-
"S3OutputLocation"
538-
)
539-
540-
if model_artifact_s3_uri is not None:
541-
output_datasets[make_s3_urn(model_artifact_s3_uri, self.env)] = {
542-
"dataset_type": "s3",
543-
"uri": model_artifact_s3_uri,
544-
}
545-
546-
if output_s3_uri is not None:
547-
output_datasets[make_s3_urn(output_s3_uri, self.env)] = {
548-
"dataset_type": "s3",
549-
"uri": output_s3_uri,
550-
}
551-
552-
# from docs: "The name of the SageMaker Neo compilation job that is used to locate model artifacts that are being packaged."
553-
compilation_job_name: Optional[str] = job.get("CompilationJobName")
554-
555-
output_jobs = set()
556-
if compilation_job_name is not None:
557-
# globally unique job name
558-
full_job_name = ("compilation", compilation_job_name)
559-
560-
if full_job_name in self.name_to_arn:
561-
output_jobs.add(
562-
make_sagemaker_job_urn(
563-
"compilation",
564-
compilation_job_name,
565-
self.name_to_arn[full_job_name],
566-
self.env,
567-
)
568-
)
569-
else:
570-
self.report.report_warning(
571-
name,
572-
f"Unable to find ARN for compilation job {compilation_job_name} produced by edge packaging job {arn}",
573-
)
574-
575-
job_snapshot, job_name, job_arn = self.create_common_job_snapshot(
576-
job,
577-
JOB_TYPE,
578-
f"https://{self.aws_region}.console.aws.amazon.com/sagemaker/home?region={self.aws_region}#/edge-packaging-jobs/{job['EdgePackagingJobName']}",
579-
)
580-
581-
if job.get("ModelName") is not None:
582-
job_key = JobKey(job_snapshot.urn, JobDirection.DOWNSTREAM)
583-
584-
self.update_model_name_jobs(job["ModelName"], job_key)
585-
586-
return SageMakerJob(
587-
job_name=job_name,
588-
job_arn=job_arn,
589-
job_type=JOB_TYPE,
590-
job_snapshot=job_snapshot,
591-
output_datasets=output_datasets,
592-
output_jobs=output_jobs,
593-
)
594-
595523
def process_hyper_parameter_tuning_job(
596524
self,
597525
job: Dict[str, Any],

0 commit comments

Comments
 (0)