Update SageMaker inference example (#208)
* Add serverless mode to SageMaker inference sample

* Add to documentation
silv-io authored Mar 2, 2023
1 parent da056db commit 61e8b51
Showing 2 changed files with 28 additions and 8 deletions.
sagemaker-inference/ReadMe.md (4 additions & 2 deletions)

@@ -47,7 +47,7 @@ And execute the example with:
 python main.py
 ```
 
-You should see an output like this:
+You should see an output like this for each of the runs:
 ```
 Creating bucket...
 Uploading model data to bucket...
@@ -62,4 +62,6 @@ Invoking via boto...
 Predicted digits: [7, 3]
 Invoking endpoint directly...
 Predicted digits: [2, 6]
 ```
+
+To try out the serverless run, remove the comment in the `main.py` file and run the example again.
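For context, the serverless path is driven by the `__main__` block added to `main.py` below; with the comment removed, the entry point would read (a sketch, assuming both runs are kept):

```python
if __name__ == '__main__':
    run_regular()
    run_serverless()  # previously commented out
```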
sagemaker-inference/main.py (24 additions & 6 deletions)

@@ -26,7 +26,7 @@
 s3: S3Client = boto3.client("s3", endpoint_url=LOCALSTACK_ENDPOINT, region_name="us-east-1")
 
 
-def deploy_model(run_id: str = "0"):
+def deploy_model(run_id: str = "0", serverless=False):
     # Put the Model into the correct bucket
     print("Creating bucket...")
     s3.create_bucket(Bucket=f"{MODEL_BUCKET}-{run_id}")
@@ -39,10 +39,15 @@ def deploy_model(run_id: str = "0"):
                            PrimaryContainer={"Image": CONTAINER_IMAGE,
                                              "ModelDataUrl": f"s3://{MODEL_BUCKET}-{run_id}/{MODEL_NAME}.tar.gz"})
     print("Adding endpoint configuration...")
-    sagemaker.create_endpoint_config(EndpointConfigName=f"{CONFIG_NAME}-{run_id}", ProductionVariants=[{
+    production_variant = {
         "VariantName": f"var-{run_id}", "ModelName": f"{MODEL_NAME}-{run_id}", "InitialInstanceCount": 1,
         "InstanceType": "ml.m5.large"
-    }])
+    }
+    if serverless:
+        production_variant |= {"ServerlessConfig": {"MaxConcurrency": 1, "MemorySizeInMB": 1024}}
+
+    sagemaker.create_endpoint_config(EndpointConfigName=f"{CONFIG_NAME}-{run_id}",
+                                     ProductionVariants=[production_variant])
     print("Creating endpoint...")
     sagemaker.create_endpoint(EndpointName=f"{ENDPOINT_NAME}-{run_id}", EndpointConfigName=f"{CONFIG_NAME}-{run_id}")

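With `serverless=True`, the variant passed to `create_endpoint_config` ends up merged like this (a sketch built from the values above; the model name and run id are hypothetical). Note that the `|=` dict merge requires Python 3.9+:

```python
production_variant = {
    "VariantName": "var-12ab34cd",           # f"var-{run_id}"
    "ModelName": "some-model-12ab34cd",      # hypothetical; f"{MODEL_NAME}-{run_id}"
    "InitialInstanceCount": 1,
    "InstanceType": "ml.m5.large",
    # Merged in only for the serverless run:
    "ServerlessConfig": {"MaxConcurrency": 1, "MemorySizeInMB": 1024},
}
```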
@@ -83,7 +88,7 @@ def inference_model_container(run_id: str = "0"):
     tag_list = sagemaker.list_tags(ResourceArn=arn)
     port = "4510"
     for tag in tag_list["Tags"]:
-        if tag["Key"] == "_LS_ENDPOINT_PORT_":
+        if tag["Key"] == "_LS_REALTIMEENDPOINT_PORT_":
             port = tag["Value"]
     inputs = _get_input_dict()
     print("Invoking endpoint directly...")
@@ -98,7 +103,7 @@ def inference_model_boto3(run_id: str = "0"):
     response = sagemaker_runtime.invoke_endpoint(EndpointName=f"{ENDPOINT_NAME}-{run_id}", Body=json.dumps(inputs),
                                                  Accept="application/json",
                                                  ContentType="application/json")
-    _show_predictions(json.loads(response["Body"].read()))
+    _show_predictions(json.loads(response["Body"].read().decode("utf-8")))


def _short_uid():
@@ -107,10 +112,23 @@ def _short_uid():
     return str(uuid.uuid4())[:8]


-if __name__ == '__main__':
+def run_regular():
     test_run = _short_uid()
     deploy_model(test_run)
     if not await_endpoint(test_run):
         exit(-1)
     inference_model_boto3(test_run)
     inference_model_container(test_run)
+
+
+def run_serverless():
+    test_run = _short_uid()
+    deploy_model(test_run, serverless=True)
+    # invoking a serverless endpoint is only supported by using boto3
+    for _ in range(3):
+        inference_model_boto3(test_run)
+
+
+if __name__ == '__main__':
+    run_regular()
+    # run_serverless()
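Both helpers ultimately rely on `invoke_endpoint`; a standalone boto3 call against a LocalStack endpoint (serverless or regular) would look roughly like this. The endpoint URL, endpoint name, and payload shape are assumptions:

```python
import json

import boto3

# Assumed default LocalStack edge URL; main.py reads it from LOCALSTACK_ENDPOINT.
sagemaker_runtime = boto3.client("sagemaker-runtime",
                                 endpoint_url="http://localhost:4566",
                                 region_name="us-east-1")

response = sagemaker_runtime.invoke_endpoint(
    EndpointName="some-endpoint-12ab34cd",          # hypothetical; f"{ENDPOINT_NAME}-{run_id}"
    Body=json.dumps({"instances": [[0.0] * 784]}),  # hypothetical payload shape
    Accept="application/json",
    ContentType="application/json",
)
print(json.loads(response["Body"].read().decode("utf-8")))
```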
