From 7173bb5a729c1999932dcbe9b9d026eeb14dcfc9 Mon Sep 17 00:00:00 2001 From: Bram Elfrink Date: Fri, 4 Oct 2024 14:08:06 +0200 Subject: [PATCH] feat(AwsVllmComponent): Add AWS Cognito as authentication service for LLM applications --- .../cloud/aws/resources/vllm_component.py | 148 ++++++++---------- .../aws/resources/test_vllm_component.py | 25 +-- 2 files changed, 77 insertions(+), 96 deletions(-) diff --git a/src/damavand/cloud/aws/resources/vllm_component.py b/src/damavand/cloud/aws/resources/vllm_component.py index 42dc185..6b6bd5d 100644 --- a/src/damavand/cloud/aws/resources/vllm_component.py +++ b/src/damavand/cloud/aws/resources/vllm_component.py @@ -33,6 +33,8 @@ class AwsVllmComponentArgs: whether to deploy a public API for the model. api_env_name : str the name of the API environment. + cognito_user_pool_id : Optional[str] + the Cognito user pool ID for authentication. """ region: str = "us-west-2" @@ -42,6 +44,7 @@ class AwsVllmComponentArgs: instance_type: str = "ml.g4dn.xlarge" public_internet_access: bool = False api_env_name: str = "prod" + cognito_user_pool_id: Optional[str] = None class AwsVllmComponent(PulumiComponentResource): @@ -94,18 +97,24 @@ def __init__( ) self.args = args + + print(">>>> self.args: ", self.args) _ = self.model _ = self.endpoint_config _ = self.endpoint - if self.args.public_internet_access: - _ = self.api - _ = self.api_resource - _ = self.api_method - _ = self.api_integration - _ = self.api_integration_response - _ = self.api_method_response - _ = self.api_deploy + _ = self.api + _ = self.api_resource + + if not self.args.public_internet_access: + _ = self.api_authorizer + + _ = self.api_method + _ = self.api_integration + _ = self.api_integration_response + _ = self.api_method_response + _ = self.api_deploy + def get_service_assume_policy(self, service: str) -> dict[str, Any]: """Return the assume role policy for the requested service. @@ -233,17 +242,8 @@ def api(self) -> aws.apigateway.RestApi: """ Return a public API for the SageMaker endpoint. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api` is only available when public_internet_access is True" - ) - return aws.apigateway.RestApi( resource_name=f"{self._name}-api", opts=ResourceOptions(parent=self), @@ -258,17 +258,8 @@ def api_resource(self) -> aws.apigateway.Resource: """ Return a resource for the API Gateway. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_resource`is only available when public_internet_access is True" - ) - return aws.apigateway.Resource( resource_name=f"{self._name}-api-resource", opts=ResourceOptions(parent=self), @@ -279,39 +270,71 @@ def api_resource(self) -> aws.apigateway.Resource: @property @cache - def api_method(self) -> aws.apigateway.Method: + def api_authorizer(self) -> aws.apigateway.Authorizer: """ - Return a method for the API Gateway. + Return an authorizer for the API Gateway. Raises ------ AttributeError - When public_internet_access is False. + When public_internet_access is True. + + AttributeError + When cognito_user_pool_id is not set. """ - if not self.args.public_internet_access: + if self.args.public_internet_access: raise AttributeError( - "`api_method`is only available when public_internet_access is True" + "`api_authorizer`is only available when public_internet_access is False" ) - return aws.apigateway.Method( - resource_name=f"{self._name}-api-method", + if not self.args.cognito_user_pool_id: + raise AttributeError( + "`api_authorizer` requires a cognito_user_pool_id to be set" + ) + + + return aws.apigateway.Authorizer( + resource_name=f"{self._name}-api-authorizer", opts=ResourceOptions(parent=self), rest_api=self.api.id, - resource_id=self.api_resource.id, - http_method="POST", - authorization="NONE", + type="COGNITO_USER_POOLS", + provider_arns=[self.args.cognito_user_pool_id], ) + @property + @cache + def api_method(self) -> aws.apigateway.Method: + """ + Return a method for the API Gateway. + + """ + + if self.args.public_internet_access: + return aws.apigateway.Method( + resource_name=f"{self._name}-api-method", + opts=ResourceOptions(parent=self), + rest_api=self.api.id, + resource_id=self.api_resource.id, + http_method="POST", + authorization="NONE", + ) + else: + return aws.apigateway.Method( + resource_name=f"{self._name}-api-method", + opts=ResourceOptions(parent=self), + rest_api=self.api.id, + resource_id=self.api_resource.id, + http_method="POST", + authorization="COGNITO_USER_POOLS", + authorizer_id=self.api_authorizer.id, + ) + @property def api_sagemaker_integration_uri(self) -> pulumi.Output[str]: """ Return the SageMaker model integration URI for the API Gateway - Raises - ------ - AttributeError - When public_internet_access is False. """ return self.endpoint.name.apply( @@ -332,17 +355,8 @@ def api_access_sagemaker_role(self) -> aws.iam.Role: """ Return an execution role for APIGateway to access SageMaker endpoints. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_access_sagemaker_rol`is only available when public_internet_access is True" - ) - return aws.iam.Role( resource_name=f"{self._name}-api-sagemaker-access-role", opts=ResourceOptions(parent=self), @@ -358,17 +372,8 @@ def api_integration(self) -> aws.apigateway.Integration: """ Return a sagemaker integration for the API Gateway. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_integration`is only available when public_internet_access is True" - ) - return aws.apigateway.Integration( resource_name=f"{self._name}-api-integration", opts=ResourceOptions(parent=self), @@ -387,17 +392,8 @@ def api_integration_response(self) -> aws.apigateway.IntegrationResponse: """ Return a sagemaker integration response for the API Gateway. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_integration_response`is only available when public_internet_access is True" - ) - return aws.apigateway.IntegrationResponse( resource_name=f"{self._name}-api-integration-response", opts=ResourceOptions(parent=self, depends_on=[self.api_integration]), @@ -413,17 +409,8 @@ def api_method_response(self) -> aws.apigateway.MethodResponse: """ Return a sagemaker method response for the API Gateway. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_method_response`is only available when public_internet_access is True" - ) - return aws.apigateway.MethodResponse( resource_name=f"{self._name}-api-method-response", opts=ResourceOptions(parent=self), @@ -439,17 +426,8 @@ def api_deploy(self) -> aws.apigateway.Deployment: """ Return an API deployment for the API Gateway. - Raises - ------ - AttributeError - When public_internet_access is False. """ - if not self.args.public_internet_access: - raise AttributeError( - "`api_deploy`is only available when public_internet_access is True" - ) - return aws.apigateway.Deployment( resource_name=f"{self._name}-api-deploy", opts=ResourceOptions( diff --git a/tests/clouds/aws/resources/test_vllm_component.py b/tests/clouds/aws/resources/test_vllm_component.py index e9fcdd0..17528bc 100644 --- a/tests/clouds/aws/resources/test_vllm_component.py +++ b/tests/clouds/aws/resources/test_vllm_component.py @@ -1,4 +1,3 @@ -import pytest from typing import Optional, Tuple, List import pulumi @@ -30,18 +29,20 @@ def call(self, args: MockCallArgs) -> Tuple[dict, Optional[List[Tuple[str, str]] def test_private_internet_access(): vllm = AwsVllmComponent( name="test", - args=AwsVllmComponentArgs(), + args=AwsVllmComponentArgs( + cognito_user_pool_id="us-west-2_123456789", + ), ) - with pytest.raises(AttributeError): - vllm.api - vllm.api_resource - vllm.api_method - vllm.api_access_sagemaker_role - vllm.api_integration - vllm.api_integration_response - vllm.api_method_response - vllm.api_deploy + assert isinstance(vllm.api, aws.apigateway.RestApi) + assert isinstance(vllm.api_resource, aws.apigateway.Resource) + assert isinstance(vllm.api_authorizer, aws.apigateway.Authorizer) + assert isinstance(vllm.api_method, aws.apigateway.Method) + assert isinstance(vllm.api_access_sagemaker_role, aws.iam.Role) + assert isinstance(vllm.api_integration, aws.apigateway.Integration) + assert isinstance(vllm.api_integration_response, aws.apigateway.IntegrationResponse) + assert isinstance(vllm.api_method_response, aws.apigateway.MethodResponse) + assert isinstance(vllm.api_deploy, aws.apigateway.Deployment) def test_public_internet_access(): @@ -67,6 +68,7 @@ def test_model_image_version(): name="test", args=AwsVllmComponentArgs( model_image_version="0.29.0", + public_internet_access=True, ), ) @@ -78,6 +80,7 @@ def test_model_image_config(): name="test", args=AwsVllmComponentArgs( model_name="microsoft/Phi-3-mini-4k-instruct", + public_internet_access=True, ), )