Skip to content
This repository was archived by the owner on Aug 14, 2024. It is now read-only.

Commit

Permalink
refactor utils for resource tagging
Browse files Browse the repository at this point in the history
  • Loading branch information
whummer committed Apr 6, 2021
1 parent 1fbf658 commit 5200867
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 30 deletions.
4 changes: 2 additions & 2 deletions localstack/services/cloudformation/cloudformation_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import xmltodict
from flask import Flask, request
from requests.models import Response
from localstack.utils.aws import aws_stack
from localstack.utils.aws import aws_stack, aws_responses
from localstack.utils.common import (
parse_request_data, short_uid, long_uid, clone, clone_safe, select_attributes,
timestamp_millis, recurse_object)
Expand Down Expand Up @@ -177,7 +177,7 @@ def template_resources(self):

@property
def tags(self):
return aws_stack.extract_tags(self.metadata)
return aws_responses.extract_tags(self.metadata)

@property
def imports(self):
Expand Down
7 changes: 3 additions & 4 deletions localstack/services/cloudwatch/cloudwatch_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def forward_request(self, method, path, data, headers):
action = req_data.get('Action')
if action == 'TagResource':
arn = req_data.get('ResourceARN')
tags = aws_stack.extract_tags(req_data)
tags = aws_responses.extract_tags(req_data)
TAGS.tag_resource(arn, tags)
return aws_responses.requests_response_xml(action, {}, xmlns=XMLNS_CLOUDWATCH)
if action == 'UntagResource':
Expand Down Expand Up @@ -53,9 +53,8 @@ def return_response(self, method, path, data, headers, response):
cloudwatch_backends[aws_stack.get_region()].alarms[name].treat_missing_data = treat_missing_data
# record tags
arn = aws_stack.cloudwatch_alarm_arn(name)
tags = aws_stack.extract_tags(req_data)
if tags:
TAGS.tag_resource(arn, tags)
tags = aws_responses.extract_tags(req_data)
TAGS.tag_resource(arn, tags)

# Fix Incorrect date format to the correct format
# the dictionary contains the tag as the key and the value is a
Expand Down
1 change: 1 addition & 0 deletions localstack/services/sns/sns_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def forward_request(self, method, path, data, headers):
def _extract_tags(topic_arn, req_data, is_create_topic_request):
tags = []
req_tags = {k: v for k, v in req_data.items() if k.startswith('Tags.member.')}
# TODO: use aws_responses.extract_tags(..) here!
for i in range(int(len(req_tags.keys()) / 2)):
key = req_tags['Tags.member.' + str(i + 1) + '.Key'][0]
value = req_tags['Tags.member.' + str(i + 1) + '.Value'][0]
Expand Down
24 changes: 24 additions & 0 deletions localstack/utils/aws/aws_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,30 @@ def make_error(*args, **kwargs):
return flask_error_response_xml(*args, **kwargs)


def extract_tags(req_data):
keys = []
values = []
for param_name in ['Tag', 'member']:
keys = extract_url_encoded_param_list(req_data, 'Tags.{}.%s.Key'.format(param_name))
values = extract_url_encoded_param_list(req_data, 'Tags.{}.%s.Value'.format(param_name))
if keys and values:
break
entries = zip(keys, values)
tags = [{'Key': entry[0], 'Value': entry[1]} for entry in entries]
return tags


def extract_url_encoded_param_list(req_data, pattern):
result = []
for i in range(1, 200):
key = pattern % i
value = req_data.get(key)
if value is None:
break
result.append(value)
return result


def calculate_crc32(content):
return crc32(to_bytes(content)) & 0xffffffff

Expand Down
14 changes: 0 additions & 14 deletions localstack/utils/aws/aws_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,17 +972,3 @@ def check_stack():
def await_stack_completion(stack_name, retries=20, sleep=2, statuses=None):
statuses = statuses or ['CREATE_COMPLETE', 'UPDATE_COMPLETE']
return await_stack_status(stack_name, statuses, retries=retries, sleep=sleep)


# TODO: move to aws_responses.py?
def extract_tags(req_data):
tags = []
for i in range(1, 200):
k1 = 'Tags.member.%s.Key' % i
k2 = 'Tags.member.%s.Value' % i
key = req_data.get(k1)
value = req_data.get(k2, '')
if key is None:
break
tags.append({'Key': key, 'Value': value})
return tags
7 changes: 5 additions & 2 deletions localstack/utils/tagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ class TaggingService(object):
def __init__(self):
self.tags = {}

def list_tags_for_resource(self, arn):
def list_tags_for_resource(self, arn, root_name=None):
root_name = root_name or 'Tags'
result = []
if arn in self.tags:
for k, v in self.tags[arn].items():
result.append({'Key': k, 'Value': v})
return {'Tags': result}
return {root_name: result}

def tag_resource(self, arn, tags):
if not tags:
return
if arn not in self.tags:
self.tags[arn] = {}
for t in tags:
Expand Down
10 changes: 2 additions & 8 deletions tests/integration/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,12 @@
import unittest
import requests
import datetime

from botocore.exceptions import ClientError
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
from botocore.auth import SigV4Auth, SIGV4_TIMESTAMP
from localstack.constants import (
TEST_AWS_ACCOUNT_ID,
TEST_AWS_ACCESS_KEY_ID,
TEST_AWS_SECRET_ACCESS_KEY
)
from localstack.constants import TEST_AWS_ACCOUNT_ID, TEST_AWS_ACCESS_KEY_ID, TEST_AWS_SECRET_ACCESS_KEY
from six.moves.urllib.parse import urlencode

from localstack import config
from localstack.utils import testutil
from localstack.utils.aws import aws_stack
Expand Down Expand Up @@ -867,7 +861,7 @@ def test_create_queue_with_slashes(self):
self.client.delete_queue(QueueUrl=queue_url.get('QueueUrl'))

result = self.client.list_queues()
self.assertNotIn(queue_url.get('QueueUrl'), result.get('QueueUrls'))
self.assertNotIn(queue_url.get('QueueUrl'), result.get('QueueUrls', []))

def list_queues_with_auth_in_presigned_url(self, method):
base_url = '{}://{}:{}'.format(get_service_protocol(), config.LOCALSTACK_HOSTNAME, config.PORT_SQS)
Expand Down

0 comments on commit 5200867

Please sign in to comment.