# Helpers reconstructed from a whitespace-collapsed patch ("add param
# modification logic and tests") targeting gen3cirrus/aws/utils.py.
# They let extra, non-S3 parameters travel inside a presigned URL even though
# the S3 client's parameter validation would otherwise reject them.
from urllib.parse import urlencode

# Keys that are not S3 API parameters but must still end up in the query
# string of the presigned URL.
custom_params = ["user_id", "username", "client_id", "x-amz-request-payer"]

# Prefix of the ids used when registering the handlers below; one id per event
# keeps registration idempotent on a client that is reused across calls.
_HANDLER_ID_PREFIX = "gen3cirrus-custom-params"


def is_custom_params(param_key):
    """
    Tell whether a param key should be skipped from client-side validation.

    Args:
        param_key (string): a key of a param

    Returns:
        bool: True when the key is listed in ``custom_params``
    """
    return param_key in custom_params


def client_param_handler(*, params, context, **_kw):
    """
    Strip customized params from the client params before they are validated.

    Args:
        params (dict): parameters passed to the client call
        context (dict): per-request storage; the removed parameters are saved
            under ``context["custom_params"]`` for later event handlers

    Returns:
        dict: ``params`` without the customized keys
    """
    kept = {}
    removed = {}
    for key, value in params.items():
        # Validation would fail on custom keys, so they are set aside here.
        target = removed if is_custom_params(key) else kept
        target[key] = value
    context["custom_params"] = removed
    return kept


def request_param_injector(*, request, **_kw):
    """
    Append the customized params to the request url before it is signed.

    Args:
        request (request): request for presigned url; its ``context`` may hold
            the ``custom_params`` dict stored by ``client_param_handler``
    """
    # ``.get`` keeps this handler harmless when no params were stored.
    extra = request.context.get("custom_params")
    if not extra:
        return
    separator = "&" if "?" in request.url else "?"
    request.url += separator + urlencode(extra)


def customize_s3_client_param_events(s3_client):
    """
    Register the handlers that carry customized params into signed urls.

    This is needed because some customized params have to be included in the
    signed url, but boto3 won't allow them to exist out of the box.
    See https://stackoverflow.com/a/59057975

    Each handler is registered with a ``unique_id`` so that calling this
    function repeatedly on the same client does not stack duplicate handlers
    (a duplicated before-sign handler would append the params again).

    Args:
        s3_client (S3.Client): boto3 S3 client

    Returns:
        S3.Client: the same client, with the handlers registered
    """
    handlers = (
        ("provide-client-params.s3", client_param_handler),
        ("before-sign.s3", request_param_injector),
    )
    events = s3_client.meta.events
    for operation in ("GetObject", "PutObject"):
        for event_prefix, handler in handlers:
            event_name = f"{event_prefix}.{operation}"
            events.register(
                event_name,
                handler,
                unique_id=f"{_HANDLER_ID_PREFIX}:{event_name}",
            )
    return s3_client