diff --git a/localstack/services/sns/sns_listener.py b/localstack/services/sns/sns_listener.py index 8fa3107e7f68e..90cde0192ee45 100644 --- a/localstack/services/sns/sns_listener.py +++ b/localstack/services/sns/sns_listener.py @@ -20,27 +20,23 @@ from localstack.utils.aws.dead_letter_queue import sns_error_to_dead_letter_queue from localstack.utils.common import parse_request_data, timestamp_millis, short_uid, to_str, to_bytes, start_thread from localstack.utils.persistence import PersistingProxyListener +from localstack.services.generic_proxy import RegionBackend +from moto.sns.models import SNSBackend as MotoSNSBackend +from moto.sns.exceptions import DuplicateSnsEndpointError # set up logger LOG = logging.getLogger(__name__) -# mappings for SNS topic subscriptions -SNS_SUBSCRIPTIONS = {} -# mappings for subscription status -SUBSCRIPTION_STATUS = {} - -# mappings for SNS tags -SNS_TAGS = {} - -# cache of platform endpoint messages (used primarily for testing) -PLATFORM_ENDPOINT_MESSAGES = {} - -# maps phone numbers to list of sent messages -SMS_MESSAGES = [] - -# actions to be skipped from persistence -SKIP_PERSISTENCE_ACTIONS = ['Subscribe', 'ConfirmSubscription', 'Unsubscribe'] +class SNSBackend(RegionBackend): + def __init__(self): + self.sns_subscriptions = {} # mappings for SNS topic subscriptions + self.subscription_status = {} # mappings for subscription status + self.sns_tags = {} # mappings for SNS tags + self.platform_endpoint_messages = {} # cache of platform endpoint messages (used primarily for testing) + self.sms_messages = [] # maps phone numbers to list of sent messages + # actions to be skipped from persistence + self.skip_persistence_actions = ['Subscribe', 'ConfirmSubscription', 'Unsubscribe'] class ProxyListenerSNS(PersistingProxyListener): @@ -72,7 +68,6 @@ def forward_request(self, method, path, data, headers): if topic_arn: topic_arn = topic_arn[0] topic_arn = aws_stack.fix_account_id_in_arns(topic_arn) - if req_action == 'SetSubscriptionAttributes': sub = get_subscription_by_arn(req_data['SubscriptionArn'][0]) if not sub: @@ -121,11 +116,11 @@ def forward_request(self, method, path, data, headers): elif req_action == 'Publish': if req_data.get('Subject') == ['']: return make_error(code=400, code_string='InvalidParameter', message='Subject') - + sns_backend = SNSBackend.get() # No need to create a topic to send SMS or single push notifications with SNS # but we can't mock a sending so we only return that it went well if 'PhoneNumber' not in req_data and 'TargetArn' not in req_data: - if topic_arn not in SNS_SUBSCRIPTIONS: + if topic_arn not in sns_backend.sns_subscriptions: return make_error(code=404, code_string='NotFound', message='Topic does not exist') message_id = publish_message(topic_arn, req_data, headers) @@ -147,16 +142,18 @@ def forward_request(self, method, path, data, headers): return make_response(req_action, content=content) elif req_action == 'CreateTopic': + sns_backend = SNSBackend.get() topic_arn = aws_stack.sns_topic_arn(req_data['Name'][0]) - tag_resource_success = self._extract_tags(topic_arn, req_data, True) - SNS_SUBSCRIPTIONS[topic_arn] = SNS_SUBSCRIPTIONS.get(topic_arn) or [] + tag_resource_success = self._extract_tags(topic_arn, req_data, True, sns_backend) + sns_backend.sns_subscriptions[topic_arn] = sns_backend.sns_subscriptions.get(topic_arn) or [] # in case if there is an error it returns an error , other wise it will continue as expected. if not tag_resource_success: return make_error(code=400, code_string='InvalidParameter', message='Topic already exists with different tags') elif req_action == 'TagResource': - self._extract_tags(topic_arn, req_data, False) + sns_backend = SNSBackend.get() + self._extract_tags(topic_arn, req_data, False, sns_backend) return make_response(req_action) elif req_action == 'UntagResource': @@ -174,10 +171,10 @@ def forward_request(self, method, path, data, headers): return True @staticmethod - def _extract_tags(topic_arn, req_data, is_create_topic_request): + def _extract_tags(topic_arn, req_data, is_create_topic_request, sns_backend): tags = [] req_tags = {k: v for k, v in req_data.items() if k.startswith('Tags.member.')} - existing_tags = SNS_TAGS.get(topic_arn, None) + existing_tags = sns_backend.sns_tags.get(topic_arn, None) # TODO: use aws_responses.extract_tags(...) here! for i in range(int(len(req_tags.keys()) / 2)): key = req_tags['Tags.member.' + str(i + 1) + '.Key'][0] @@ -248,21 +245,40 @@ def return_response(self, method, path, data, headers, response): ) def should_persist(self, method, path, data, headers, response): + sns_backend = SNSBackend.get() req_params = parse_request_data(method, path, data) action = req_params.get('Action', '') - if action in SKIP_PERSISTENCE_ACTIONS: + if action in sns_backend.skip_persistence_actions: return False return super(ProxyListenerSNS, self).should_persist(method, path, data, headers, response) +def patch_moto(): + def patch_create_platform_endpoint(self, *args): + try: + return create_platform_endpoint_orig(self, *args) + except DuplicateSnsEndpointError: + custom_user_data, token = args[2], args[3] + for endpoint in self.platform_endpoints.values(): + if endpoint.token == token: + if custom_user_data and custom_user_data != endpoint.custom_user_data: + raise DuplicateSnsEndpointError('Endpoint already exist for token: %s with different attributes' + % token) + return endpoint + create_platform_endpoint_orig = MotoSNSBackend.create_platform_endpoint + MotoSNSBackend.create_platform_endpoint = patch_create_platform_endpoint + + +patch_moto() # instantiate listener UPDATE_SNS = ProxyListenerSNS() def unsubscribe_sqs_queue(queue_url): """ Called upon deletion of an SQS queue, to remove the queue from subscriptions """ - for topic_arn, subscriptions in SNS_SUBSCRIPTIONS.items(): - subscriptions = SNS_SUBSCRIPTIONS.get(topic_arn, []) + sns_backend = SNSBackend.get() + for topic_arn, subscriptions in sns_backend.sns_subscriptions.items(): + subscriptions = sns_backend.sns_subscriptions.get(topic_arn, []) for subscriber in list(subscriptions): sub_url = subscriber.get('sqs_queue_url') or subscriber['Endpoint'] if queue_url == sub_url: @@ -270,8 +286,8 @@ def unsubscribe_sqs_queue(queue_url): def message_to_subscribers(message_id, message, topic_arn, req_data, headers, subscription_arn=None, skip_checks=False): - - subscriptions = SNS_SUBSCRIPTIONS.get(topic_arn, []) + sns_backend = SNSBackend.get() + subscriptions = sns_backend.sns_subscriptions.get(topic_arn, []) for subscriber in list(subscriptions): if subscription_arn not in [None, subscriber['SubscriptionArn']]: continue @@ -288,7 +304,7 @@ def message_to_subscribers(message_id, message, topic_arn, req_data, headers, su 'endpoint': subscriber['Endpoint'], 'message_content': req_data['Message'][0] } - SMS_MESSAGES.append(event) + sns_backend.sms_messages.append(event) LOG.info('Delivering SMS message to %s: %s', subscriber['Endpoint'], req_data['Message'][0]) elif subscriber['Protocol'] == 'sqs': @@ -397,12 +413,14 @@ def message_to_subscribers(message_id, message, topic_arn, req_data, headers, su def publish_message(topic_arn, req_data, headers, subscription_arn=None, skip_checks=False): + sns_backend = SNSBackend.get() message = req_data['Message'][0] message_id = str(uuid.uuid4()) if topic_arn and ':endpoint/' in topic_arn: # cache messages published to platform endpoints - cache = PLATFORM_ENDPOINT_MESSAGES[topic_arn] = PLATFORM_ENDPOINT_MESSAGES.get(topic_arn) or [] + cache = sns_backend.platform_endpoint_messages[topic_arn] = sns_backend. \ + platform_endpoint_messages.get(topic_arn) or [] cache.append(req_data) LOG.debug('Publishing message to TopicArn: %s | Message: %s' % (topic_arn, message)) @@ -413,19 +431,21 @@ def publish_message(topic_arn, req_data, headers, subscription_arn=None, skip_ch def do_delete_topic(topic_arn): - SNS_SUBSCRIPTIONS.pop(topic_arn, None) - SNS_TAGS.pop(topic_arn, None) + sns_backend = SNSBackend.get() + sns_backend.sns_subscriptions.pop(topic_arn, None) + sns_backend.sns_tags.pop(topic_arn, None) def do_confirm_subscription(topic_arn, token): - for k, v in SUBSCRIPTION_STATUS.items(): + sns_backend = SNSBackend.get() + for k, v in sns_backend.subscription_status.items(): if v['Token'] == token and v['TopicArn'] == topic_arn: v['Status'] = 'Subscribed' def do_subscribe(topic_arn, endpoint, protocol, subscription_arn, attributes, filter_policy=None): - topic_subs = SNS_SUBSCRIPTIONS[topic_arn] = SNS_SUBSCRIPTIONS.get(topic_arn) or [] - + sns_backend = SNSBackend.get() + topic_subs = sns_backend.sns_subscriptions[topic_arn] = sns_backend.sns_subscriptions.get(topic_arn) or [] # An endpoint may only be subscribed to a topic once. Subsequent # subscribe calls do nothing (subscribe is idempotent). for existing_topic_subscription in topic_subs: @@ -443,10 +463,10 @@ def do_subscribe(topic_arn, endpoint, protocol, subscription_arn, attributes, fi subscription.update(attributes) topic_subs.append(subscription) - if subscription_arn not in SUBSCRIPTION_STATUS: - SUBSCRIPTION_STATUS[subscription_arn] = {} + if subscription_arn not in sns_backend.subscription_status: + sns_backend.subscription_status[subscription_arn] = {} - SUBSCRIPTION_STATUS[subscription_arn].update( + sns_backend.subscription_status[subscription_arn].update( { 'TopicArn': topic_arn, 'Token': short_uid(), @@ -468,18 +488,20 @@ def do_subscribe(topic_arn, endpoint, protocol, subscription_arn, attributes, fi def do_unsubscribe(subscription_arn): - for topic_arn, existing_subs in SNS_SUBSCRIPTIONS.items(): - SNS_SUBSCRIPTIONS[topic_arn] = [ + sns_backend = SNSBackend.get() + for topic_arn, existing_subs in sns_backend.sns_subscriptions.items(): + sns_backend.sns_subscriptions[topic_arn] = [ sub for sub in existing_subs if sub['SubscriptionArn'] != subscription_arn ] def _get_tags(topic_arn): - if topic_arn not in SNS_TAGS: - SNS_TAGS[topic_arn] = [] + sns_backend = SNSBackend.get() + if topic_arn not in sns_backend.sns_tags: + sns_backend.sns_tags[topic_arn] = [] - return SNS_TAGS[topic_arn] + return sns_backend.sns_tags[topic_arn] def do_list_tags_for_resource(topic_arn): @@ -487,7 +509,8 @@ def do_list_tags_for_resource(topic_arn): def do_tag_resource(topic_arn, tags): - existing_tags = SNS_TAGS.get(topic_arn, []) + sns_backend = SNSBackend.get() + existing_tags = sns_backend.sns_tags.get(topic_arn, []) tags = [ tag for idx, tag in enumerate(tags) if tag not in tags[:idx] @@ -506,11 +529,12 @@ def existing_tag_index(item): else: existing_tags[existing_index] = item - SNS_TAGS[topic_arn] = existing_tags + sns_backend.sns_tags[topic_arn] = existing_tags def do_untag_resource(topic_arn, tag_keys): - SNS_TAGS[topic_arn] = [t for t in _get_tags(topic_arn) if t['Key'] not in tag_keys] + sns_backend = SNSBackend.get() + sns_backend.sns_tags[topic_arn] = [t for t in _get_tags(topic_arn) if t['Key'] not in tag_keys] # --------------- @@ -519,8 +543,9 @@ def do_untag_resource(topic_arn, tag_keys): def get_subscription_by_arn(sub_arn): + sns_backend = SNSBackend.get() # TODO maintain separate map instead of traversing all items - for key, subscriptions in SNS_SUBSCRIPTIONS.items(): + for key, subscriptions in sns_backend.sns_subscriptions.items(): for sub in subscriptions: if sub['SubscriptionArn'] == sub_arn: return sub diff --git a/tests/integration/test_sns.py b/tests/integration/test_sns.py index b0712a0ab155d..88b0f6b3bb072 100644 --- a/tests/integration/test_sns.py +++ b/tests/integration/test_sns.py @@ -17,7 +17,7 @@ from localstack.utils.testutil import check_expected_lambda_log_events_length from localstack.services.infra import start_proxy from localstack.services.generic_proxy import ProxyListener -from localstack.services.sns import sns_listener +from localstack.services.sns.sns_listener import SNSBackend from .lambdas import lambda_integration from .test_lambda import TEST_LAMBDA_PYTHON, LAMBDA_RUNTIME_PYTHON36, TEST_LAMBDA_LIBS from localstack.services.install import SQS_BACKEND_IMPL @@ -279,13 +279,14 @@ def check_message(): def test_subscribe_platform_endpoint(self): sns = self.sns_client + sns_backend = SNSBackend.get() app_arn = sns.create_platform_application(Name='app1', Platform='p1', Attributes={})['PlatformApplicationArn'] platform_arn = sns.create_platform_endpoint(PlatformApplicationArn=app_arn, Token='token_1')['EndpointArn'] subscription = self._publish_sns_message_with_attrs(platform_arn, 'application') # assert that message has been received def check_message(): - self.assertGreater(len(sns_listener.PLATFORM_ENDPOINT_MESSAGES[platform_arn]), 0) + self.assertGreater(len(sns_backend.platform_endpoint_messages[platform_arn]), 0) retry(check_message, retries=PUBLICATION_RETRIES, sleep=PUBLICATION_TIMEOUT) # clean up @@ -394,10 +395,11 @@ def test_topic_subscription(self): Protocol='email', Endpoint='localstack@yopmail.com' ) + sns_backend = SNSBackend.get() def check_subscription(): subscription_arn = subscription['SubscriptionArn'] - subscription_obj = sns_listener.SUBSCRIPTION_STATUS[subscription_arn] + subscription_obj = sns_backend.subscription_status[subscription_arn] self.assertEqual(subscription_obj['Status'], 'Not Subscribed') _token = subscription_obj['Token'] @@ -694,6 +696,27 @@ def test_create_duplicate_topic_check_idempotentness(self): # clean up self.sns_client.delete_topic(TopicArn=responses[0]['TopicArn']) + def test_create_platform_endpoint_check_idempotentness(self): + response = self.sns_client.create_platform_application( + Name='test-%s' % short_uid(), Platform='GCM', Attributes={'PlatformCredential': '123'} + ) + kwargs_list = [{'Token': 'test1', 'CustomUserData': 'test-data'}, + {'Token': 'test1', 'CustomUserData': 'test-data'}, + {'Token': 'test1'}, {'Token': 'test1'} + ] + platform_arn = response['PlatformApplicationArn'] + responses = [] + for kwargs in kwargs_list: + responses.append(self.sns_client.create_platform_endpoint(PlatformApplicationArn=platform_arn, + **kwargs)) + # Assert endpointarn is returned in every call create platform call + for i in range(len(responses)): + self.assertIn('EndpointArn', responses[i]) + endpoint_arn = responses[0]['EndpointArn'] + # clean up + self.sns_client.delete_endpoint(EndpointArn=endpoint_arn) + self.sns_client.delete_platform_application(PlatformApplicationArn=platform_arn) + def test_publish_by_path_parameters(self): topic_name = 'topic-{}'.format(short_uid()) queue_name = 'queue-{}'.format(short_uid()) @@ -765,6 +788,10 @@ def _create_queue(self): return queue_name, queue_arn, queue_url def test_publish_sms_endpoint(self): + def check_messages(): + sns_backend = SNSBackend.get() + self.assertEqual(len(list_of_contacts), len(sns_backend.sms_messages)) + list_of_contacts = [ '+10123456789', '+10000000000', @@ -780,9 +807,6 @@ def test_publish_sms_endpoint(self): ) # Publish a message. self.sns_client.publish(Message=message, TopicArn=self.topic_arn) - - def check_messages(): - self.assertEqual(len(list_of_contacts), len(sns_listener.SMS_MESSAGES)) retry(check_messages, retries=3, sleep=0.5) def test_publish_sqs_from_sns(self): diff --git a/tests/unit/test_sns.py b/tests/unit/test_sns.py index c79e733cf3d57..f36e8718e6b8a 100644 --- a/tests/unit/test_sns.py +++ b/tests/unit/test_sns.py @@ -5,6 +5,7 @@ import dateutil.parser import re from localstack.services.sns import sns_listener +from localstack.services.sns.sns_listener import SNSBackend class SNSTests(unittest.TestCase): @@ -14,8 +15,6 @@ def setUp(self): 'RawMessageDelivery': 'false', 'TopicArn': 'arn', } - # Reset subscriptions - sns_listener.SNS_SUBSCRIPTIONS = {} def test_unsubscribe_without_arn_should_error(self): sns = sns_listener.ProxyListenerSNS() @@ -222,7 +221,7 @@ def test_create_sns_message_timestamp_millis(self): def test_only_one_subscription_per_topic_per_endpoint(self): sub_arn = 'arn:aws:sns:us-east-1:000000000000:test-topic:45e61c7f-dca5-4fcd-be2b-4e1b0d6eef72' topic_arn = 'arn:aws:sns:us-east-1:000000000000:test-topic' - + sns_backend = SNSBackend().get() for i in [1, 2]: sns_listener.do_subscribe( topic_arn, @@ -231,7 +230,7 @@ def test_only_one_subscription_per_topic_per_endpoint(self): sub_arn, {} ) - self.assertEqual(len(sns_listener.SNS_SUBSCRIPTIONS[topic_arn]), 1) + self.assertEqual(len(sns_backend.sns_subscriptions[topic_arn]), 1) def test_filter_policy(self): test_data = [