From 61e8b515e09604d969a47ea6acd8ef109b541980 Mon Sep 17 00:00:00 2001 From: Silvio Vasiljevic Date: Thu, 2 Mar 2023 16:40:27 +0100 Subject: [PATCH] Update SageMaker inference example (#208) * Add serverless mode to SageMaker inference sample * Add to documentation --- sagemaker-inference/ReadMe.md | 6 ++++-- sagemaker-inference/main.py | 30 ++++++++++++++++++++++++------ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/sagemaker-inference/ReadMe.md b/sagemaker-inference/ReadMe.md index c3392cb..fcaffc3 100644 --- a/sagemaker-inference/ReadMe.md +++ b/sagemaker-inference/ReadMe.md @@ -47,7 +47,7 @@ And execute the example with: python main.py ``` -You should see an output like this: +You should see an output like this for each of the runs: ``` Creating bucket... Uploading model data to bucket... @@ -62,4 +62,6 @@ Invoking via boto... Predicted digits: [7, 3] Invoking endpoint directly... Predicted digits: [2, 6] -``` \ No newline at end of file +``` + +To try out the serverless run, uncomment the `run_serverless()` call in `main.py` and run the example again. 
\ No newline at end of file diff --git a/sagemaker-inference/main.py b/sagemaker-inference/main.py index cc7138e..81dc4ff 100644 --- a/sagemaker-inference/main.py +++ b/sagemaker-inference/main.py @@ -26,7 +26,7 @@ s3: S3Client = boto3.client("s3", endpoint_url=LOCALSTACK_ENDPOINT, region_name="us-east-1") -def deploy_model(run_id: str = "0"): +def deploy_model(run_id: str = "0", serverless=False): # Put the Model into the correct bucket print("Creating bucket...") s3.create_bucket(Bucket=f"{MODEL_BUCKET}-{run_id}") @@ -39,10 +39,15 @@ def deploy_model(run_id: str = "0"): PrimaryContainer={"Image": CONTAINER_IMAGE, "ModelDataUrl": f"s3://{MODEL_BUCKET}-{run_id}/{MODEL_NAME}.tar.gz"}) print("Adding endpoint configuration...") - sagemaker.create_endpoint_config(EndpointConfigName=f"{CONFIG_NAME}-{run_id}", ProductionVariants=[{ + production_variant = { "VariantName": f"var-{run_id}", "ModelName": f"{MODEL_NAME}-{run_id}", "InitialInstanceCount": 1, "InstanceType": "ml.m5.large" - }]) + } + if serverless: + production_variant |= {"ServerlessConfig": {"MaxConcurrency": 1, "MemorySizeInMB": 1024}} + + sagemaker.create_endpoint_config(EndpointConfigName=f"{CONFIG_NAME}-{run_id}", + ProductionVariants=[production_variant]) print("Creating endpoint...") sagemaker.create_endpoint(EndpointName=f"{ENDPOINT_NAME}-{run_id}", EndpointConfigName=f"{CONFIG_NAME}-{run_id}") @@ -83,7 +88,7 @@ def inference_model_container(run_id: str = "0"): tag_list = sagemaker.list_tags(ResourceArn=arn) port = "4510" for tag in tag_list["Tags"]: - if tag["Key"] == "_LS_ENDPOINT_PORT_": + if tag["Key"] == "_LS_REALTIMEENDPOINT_PORT_": port = tag["Value"] inputs = _get_input_dict() print("Invoking endpoint directly...") @@ -98,7 +103,7 @@ def inference_model_boto3(run_id: str = "0"): response = sagemaker_runtime.invoke_endpoint(EndpointName=f"{ENDPOINT_NAME}-{run_id}", Body=json.dumps(inputs), Accept="application/json", ContentType="application/json") - 
_show_predictions(json.loads(response["Body"].read())) + _show_predictions(json.loads(response["Body"].read().decode("utf-8"))) def _short_uid(): @@ -107,10 +112,23 @@ def _short_uid(): return str(uuid.uuid4())[:8] -if __name__ == '__main__': +def run_regular(): test_run = _short_uid() deploy_model(test_run) if not await_endpoint(test_run): exit(-1) inference_model_boto3(test_run) inference_model_container(test_run) + + +def run_serverless(): + test_run = _short_uid() + deploy_model(test_run, serverless=True) + # invoking a serverless endpoint is only supported by using boto3 + for _ in range(3): + inference_model_boto3(test_run) + + +if __name__ == '__main__': + run_regular() + # run_serverless()