Skip to content

Commit

Permalink
feat(AwsVllmComponent): Add AWS Cognito as authentication service for…
Browse files Browse the repository at this point in the history
… LLM applications
  • Loading branch information
bramelfrink committed Oct 4, 2024
1 parent 7fb92d3 commit 7173bb5
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 96 deletions.
148 changes: 63 additions & 85 deletions src/damavand/cloud/aws/resources/vllm_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class AwsVllmComponentArgs:
whether to deploy a public API for the model.
api_env_name : str
the name of the API environment.
cognito_user_pool_id : Optional[str]
the Cognito user pool ID for authentication.
"""

region: str = "us-west-2"
Expand All @@ -42,6 +44,7 @@ class AwsVllmComponentArgs:
instance_type: str = "ml.g4dn.xlarge"
public_internet_access: bool = False
api_env_name: str = "prod"
cognito_user_pool_id: Optional[str] = None


class AwsVllmComponent(PulumiComponentResource):
Expand Down Expand Up @@ -94,18 +97,24 @@ def __init__(
)

self.args = args

print(">>>> self.args: ", self.args)
_ = self.model
_ = self.endpoint_config
_ = self.endpoint

if self.args.public_internet_access:
_ = self.api
_ = self.api_resource
_ = self.api_method
_ = self.api_integration
_ = self.api_integration_response
_ = self.api_method_response
_ = self.api_deploy
_ = self.api
_ = self.api_resource

if not self.args.public_internet_access:
_ = self.api_authorizer

_ = self.api_method
_ = self.api_integration
_ = self.api_integration_response
_ = self.api_method_response
_ = self.api_deploy


def get_service_assume_policy(self, service: str) -> dict[str, Any]:
"""Return the assume role policy for the requested service.
Expand Down Expand Up @@ -233,17 +242,8 @@ def api(self) -> aws.apigateway.RestApi:
"""
Return a public API for the SageMaker endpoint.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api` is only available when public_internet_access is True"
)

return aws.apigateway.RestApi(
resource_name=f"{self._name}-api",
opts=ResourceOptions(parent=self),
Expand All @@ -258,17 +258,8 @@ def api_resource(self) -> aws.apigateway.Resource:
"""
Return a resource for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_resource`is only available when public_internet_access is True"
)

return aws.apigateway.Resource(
resource_name=f"{self._name}-api-resource",
opts=ResourceOptions(parent=self),
Expand All @@ -279,39 +270,71 @@ def api_resource(self) -> aws.apigateway.Resource:

@property
@cache
def api_method(self) -> aws.apigateway.Method:
def api_authorizer(self) -> aws.apigateway.Authorizer:
"""
Return a method for the API Gateway.
Return an authorizer for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
When public_internet_access is True.
AttributeError
When cognito_user_pool_id is not set.
"""

if not self.args.public_internet_access:
if self.args.public_internet_access:
raise AttributeError(
"`api_method`is only available when public_internet_access is True"
"`api_authorizer`is only available when public_internet_access is False"
)

return aws.apigateway.Method(
resource_name=f"{self._name}-api-method",
if not self.args.cognito_user_pool_id:
raise AttributeError(
"`api_authorizer` requires a cognito_user_pool_id to be set"
)


return aws.apigateway.Authorizer(
resource_name=f"{self._name}-api-authorizer",
opts=ResourceOptions(parent=self),
rest_api=self.api.id,
resource_id=self.api_resource.id,
http_method="POST",
authorization="NONE",
type="COGNITO_USER_POOLS",
provider_arns=[self.args.cognito_user_pool_id],
)

@property
@cache
def api_method(self) -> aws.apigateway.Method:
    """
    Return the POST method attached to the API Gateway resource.

    When public_internet_access is enabled the method is created with
    authorization "NONE". Otherwise it is created with authorization
    "COGNITO_USER_POOLS" and bound to the id of `api_authorizer`.
    """

    is_public = self.args.public_internet_access

    # Keyword arguments common to both the public and the protected method.
    method_kwargs = {
        "resource_name": f"{self._name}-api-method",
        "opts": ResourceOptions(parent=self),
        "rest_api": self.api.id,
        "resource_id": self.api_resource.id,
        "http_method": "POST",
    }

    if is_public:
        # Open endpoint: no authorizer is referenced at all.
        method_kwargs["authorization"] = "NONE"
        return aws.apigateway.Method(**method_kwargs)

    # Protected endpoint: the authorizer is only resolved on this path.
    method_kwargs["authorization"] = "COGNITO_USER_POOLS"
    method_kwargs["authorizer_id"] = self.api_authorizer.id
    return aws.apigateway.Method(**method_kwargs)

@property
def api_sagemaker_integration_uri(self) -> pulumi.Output[str]:
"""
Return the SageMaker model integration URI for the API Gateway
Raises
------
AttributeError
When public_internet_access is False.
"""

return self.endpoint.name.apply(
Expand All @@ -332,17 +355,8 @@ def api_access_sagemaker_role(self) -> aws.iam.Role:
"""
Return an execution role for APIGateway to access SageMaker endpoints.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_access_sagemaker_rol`is only available when public_internet_access is True"
)

return aws.iam.Role(
resource_name=f"{self._name}-api-sagemaker-access-role",
opts=ResourceOptions(parent=self),
Expand All @@ -358,17 +372,8 @@ def api_integration(self) -> aws.apigateway.Integration:
"""
Return a sagemaker integration for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_integration`is only available when public_internet_access is True"
)

return aws.apigateway.Integration(
resource_name=f"{self._name}-api-integration",
opts=ResourceOptions(parent=self),
Expand All @@ -387,17 +392,8 @@ def api_integration_response(self) -> aws.apigateway.IntegrationResponse:
"""
Return a sagemaker integration response for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_integration_response`is only available when public_internet_access is True"
)

return aws.apigateway.IntegrationResponse(
resource_name=f"{self._name}-api-integration-response",
opts=ResourceOptions(parent=self, depends_on=[self.api_integration]),
Expand All @@ -413,17 +409,8 @@ def api_method_response(self) -> aws.apigateway.MethodResponse:
"""
Return a sagemaker method response for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_method_response`is only available when public_internet_access is True"
)

return aws.apigateway.MethodResponse(
resource_name=f"{self._name}-api-method-response",
opts=ResourceOptions(parent=self),
Expand All @@ -439,17 +426,8 @@ def api_deploy(self) -> aws.apigateway.Deployment:
"""
Return an API deployment for the API Gateway.
Raises
------
AttributeError
When public_internet_access is False.
"""

if not self.args.public_internet_access:
raise AttributeError(
"`api_deploy`is only available when public_internet_access is True"
)

return aws.apigateway.Deployment(
resource_name=f"{self._name}-api-deploy",
opts=ResourceOptions(
Expand Down
25 changes: 14 additions & 11 deletions tests/clouds/aws/resources/test_vllm_component.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from typing import Optional, Tuple, List

import pulumi
Expand Down Expand Up @@ -30,18 +29,20 @@ def call(self, args: MockCallArgs) -> Tuple[dict, Optional[List[Tuple[str, str]]
def test_private_internet_access():
vllm = AwsVllmComponent(
name="test",
args=AwsVllmComponentArgs(),
args=AwsVllmComponentArgs(
cognito_user_pool_id="us-west-2_123456789",
),
)

with pytest.raises(AttributeError):
vllm.api
vllm.api_resource
vllm.api_method
vllm.api_access_sagemaker_role
vllm.api_integration
vllm.api_integration_response
vllm.api_method_response
vllm.api_deploy
assert isinstance(vllm.api, aws.apigateway.RestApi)
assert isinstance(vllm.api_resource, aws.apigateway.Resource)
assert isinstance(vllm.api_authorizer, aws.apigateway.Authorizer)
assert isinstance(vllm.api_method, aws.apigateway.Method)
assert isinstance(vllm.api_access_sagemaker_role, aws.iam.Role)
assert isinstance(vllm.api_integration, aws.apigateway.Integration)
assert isinstance(vllm.api_integration_response, aws.apigateway.IntegrationResponse)
assert isinstance(vllm.api_method_response, aws.apigateway.MethodResponse)
assert isinstance(vllm.api_deploy, aws.apigateway.Deployment)


def test_public_internet_access():
Expand All @@ -67,6 +68,7 @@ def test_model_image_version():
name="test",
args=AwsVllmComponentArgs(
model_image_version="0.29.0",
public_internet_access=True,
),
)

Expand All @@ -78,6 +80,7 @@ def test_model_image_config():
name="test",
args=AwsVllmComponentArgs(
model_name="microsoft/Phi-3-mini-4k-instruct",
public_internet_access=True,
),
)

Expand Down

0 comments on commit 7173bb5

Please sign in to comment.