Skip to content

Commit

Permalink
fix: skip checks on non serverless api resources (#6471)
Browse files Browse the repository at this point in the history
* skip checks on non serverless api resources

* fix comment

* handle ApiGatewayV2

* add test

* update comment
  • Loading branch information
sidhujus authored Jan 3, 2024
1 parent b5503ae commit 48d8b2b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
6 changes: 6 additions & 0 deletions samcli/commands/deploy/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from samcli.commands.local.lib.swagger.reader import SwaggerReader
from samcli.lib.providers.provider import Stack
from samcli.lib.providers.sam_function_provider import SamFunctionProvider
from samcli.lib.utils.resources import AWS_APIGATEWAY_RESTAPI, AWS_APIGATEWAY_V2_API

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,6 +102,11 @@ def _auth_id(resources_dict, event_properties, identifier):
"""
resource_name = event_properties.get(identifier, "")
api_resource = resources_dict.get(resource_name, {})

# Auth does not apply to ApiGateway::RestApi or ApiGatwayV2::Api resources so return true and continue
if api_resource and (api_resource.get("Type") in [AWS_APIGATEWAY_RESTAPI, AWS_APIGATEWAY_V2_API]):
return True

return any(
[
api_resource.get("Properties", {}).get("Auth", False),
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/commands/deploy/test_auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,35 @@ def test_auth_per_resource_defined_on_api_resource(self):
_auth_per_resource = auth_per_resource([Stack("", "", "", {}, self.template_dict)])
self.assertEqual(_auth_per_resource, [("HelloWorldFunction", True)])

def test_auth_per_resource_on_non_serverless_restapi(self):
self.template_dict["Resources"]["HelloWorldApi"] = OrderedDict(
[
("Type", "AWS::ApiGateway::RestApi"),
("Properties", OrderedDict([("StageName", "Prod")])),
]
)
# setup the lambda function with a restapiId which has Auth defined.
self.template_dict["Resources"]["HelloWorldFunction"]["Properties"]["Events"]["HelloWorld"]["Properties"][
"RestApiId"
] = {"Ref": "HelloWorldApi"}
self.template_dict["Resources"]["HelloWorldFunction"]["Properties"]["Events"]["HelloWorld"]["Type"] = "Api"
_auth_per_resource = auth_per_resource([Stack("", "", "", {}, self.template_dict)])
self.assertEqual(_auth_per_resource, [("HelloWorldFunction", True)])

def test_auth_per_resource_on_non_serverless_httpapi(self):
self.template_dict["Resources"]["HelloWorldApi"] = OrderedDict(
[
("Type", "AWS::ApiGatewayV2::Api"),
]
)
# setup the lambda function with a restapiId which has Auth defined.
self.template_dict["Resources"]["HelloWorldFunction"]["Properties"]["Events"]["HelloWorld"]["Properties"][
"ApiId"
] = {"Ref": "HelloWorldApi"}
self.template_dict["Resources"]["HelloWorldFunction"]["Properties"]["Events"]["HelloWorld"]["Type"] = "HttpApi"
_auth_per_resource = auth_per_resource([Stack("", "", "", {}, self.template_dict)])
self.assertEqual(_auth_per_resource, [("HelloWorldFunction", True)])

def test_auth_supplied_via_definition_body_uri(self):
self.template_dict["Resources"]["HelloWorldApi"] = OrderedDict(
[
Expand Down

0 comments on commit 48d8b2b

Please sign in to comment.