Skip to content
This repository was archived by the owner on Aug 14, 2024. It is now read-only.

Commit

Permalink
fix idempotence for SNS CreatePlatformEndpoint API calls (localstack#…
Browse files Browse the repository at this point in the history
  • Loading branch information
KrishnanRanjithkumar authored May 15, 2021
1 parent 0691439 commit c236bf2
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 58 deletions.
121 changes: 73 additions & 48 deletions localstack/services/sns/sns_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,23 @@
from localstack.utils.aws.dead_letter_queue import sns_error_to_dead_letter_queue
from localstack.utils.common import parse_request_data, timestamp_millis, short_uid, to_str, to_bytes, start_thread
from localstack.utils.persistence import PersistingProxyListener
from localstack.services.generic_proxy import RegionBackend
from moto.sns.models import SNSBackend as MotoSNSBackend
from moto.sns.exceptions import DuplicateSnsEndpointError

# set up logger
LOG = logging.getLogger(__name__)

# mappings for SNS topic subscriptions
SNS_SUBSCRIPTIONS = {}

# mappings for subscription status
SUBSCRIPTION_STATUS = {}

# mappings for SNS tags
SNS_TAGS = {}

# cache of platform endpoint messages (used primarily for testing)
PLATFORM_ENDPOINT_MESSAGES = {}

# maps phone numbers to list of sent messages
SMS_MESSAGES = []

# actions to be skipped from persistence
SKIP_PERSISTENCE_ACTIONS = ['Subscribe', 'ConfirmSubscription', 'Unsubscribe']
class SNSBackend(RegionBackend):
def __init__(self):
self.sns_subscriptions = {} # mappings for SNS topic subscriptions
self.subscription_status = {} # mappings for subscription status
self.sns_tags = {} # mappings for SNS tags
self.platform_endpoint_messages = {} # cache of platform endpoint messages (used primarily for testing)
self.sms_messages = [] # maps phone numbers to list of sent messages
# actions to be skipped from persistence
self.skip_persistence_actions = ['Subscribe', 'ConfirmSubscription', 'Unsubscribe']


class ProxyListenerSNS(PersistingProxyListener):
Expand Down Expand Up @@ -72,7 +68,6 @@ def forward_request(self, method, path, data, headers):
if topic_arn:
topic_arn = topic_arn[0]
topic_arn = aws_stack.fix_account_id_in_arns(topic_arn)

if req_action == 'SetSubscriptionAttributes':
sub = get_subscription_by_arn(req_data['SubscriptionArn'][0])
if not sub:
Expand Down Expand Up @@ -121,11 +116,11 @@ def forward_request(self, method, path, data, headers):
elif req_action == 'Publish':
if req_data.get('Subject') == ['']:
return make_error(code=400, code_string='InvalidParameter', message='Subject')

sns_backend = SNSBackend.get()
# No need to create a topic to send SMS or single push notifications with SNS
# but we can't mock a sending so we only return that it went well
if 'PhoneNumber' not in req_data and 'TargetArn' not in req_data:
if topic_arn not in SNS_SUBSCRIPTIONS:
if topic_arn not in sns_backend.sns_subscriptions:
return make_error(code=404, code_string='NotFound', message='Topic does not exist')

message_id = publish_message(topic_arn, req_data, headers)
Expand All @@ -147,16 +142,18 @@ def forward_request(self, method, path, data, headers):
return make_response(req_action, content=content)

elif req_action == 'CreateTopic':
sns_backend = SNSBackend.get()
topic_arn = aws_stack.sns_topic_arn(req_data['Name'][0])
tag_resource_success = self._extract_tags(topic_arn, req_data, True)
SNS_SUBSCRIPTIONS[topic_arn] = SNS_SUBSCRIPTIONS.get(topic_arn) or []
tag_resource_success = self._extract_tags(topic_arn, req_data, True, sns_backend)
sns_backend.sns_subscriptions[topic_arn] = sns_backend.sns_subscriptions.get(topic_arn) or []
# in case if there is an error it returns an error , other wise it will continue as expected.
if not tag_resource_success:
return make_error(code=400, code_string='InvalidParameter',
message='Topic already exists with different tags')

elif req_action == 'TagResource':
self._extract_tags(topic_arn, req_data, False)
sns_backend = SNSBackend.get()
self._extract_tags(topic_arn, req_data, False, sns_backend)
return make_response(req_action)

elif req_action == 'UntagResource':
Expand All @@ -174,10 +171,10 @@ def forward_request(self, method, path, data, headers):
return True

@staticmethod
def _extract_tags(topic_arn, req_data, is_create_topic_request):
def _extract_tags(topic_arn, req_data, is_create_topic_request, sns_backend):
tags = []
req_tags = {k: v for k, v in req_data.items() if k.startswith('Tags.member.')}
existing_tags = SNS_TAGS.get(topic_arn, None)
existing_tags = sns_backend.sns_tags.get(topic_arn, None)
# TODO: use aws_responses.extract_tags(...) here!
for i in range(int(len(req_tags.keys()) / 2)):
key = req_tags['Tags.member.' + str(i + 1) + '.Key'][0]
Expand Down Expand Up @@ -248,30 +245,49 @@ def return_response(self, method, path, data, headers, response):
)

def should_persist(self, method, path, data, headers, response):
sns_backend = SNSBackend.get()
req_params = parse_request_data(method, path, data)
action = req_params.get('Action', '')
if action in SKIP_PERSISTENCE_ACTIONS:
if action in sns_backend.skip_persistence_actions:
return False
return super(ProxyListenerSNS, self).should_persist(method, path, data, headers, response)


def patch_moto():
def patch_create_platform_endpoint(self, *args):
try:
return create_platform_endpoint_orig(self, *args)
except DuplicateSnsEndpointError:
custom_user_data, token = args[2], args[3]
for endpoint in self.platform_endpoints.values():
if endpoint.token == token:
if custom_user_data and custom_user_data != endpoint.custom_user_data:
raise DuplicateSnsEndpointError('Endpoint already exist for token: %s with different attributes'
% token)
return endpoint
create_platform_endpoint_orig = MotoSNSBackend.create_platform_endpoint
MotoSNSBackend.create_platform_endpoint = patch_create_platform_endpoint


patch_moto()
# instantiate listener
UPDATE_SNS = ProxyListenerSNS()


def unsubscribe_sqs_queue(queue_url):
""" Called upon deletion of an SQS queue, to remove the queue from subscriptions """
for topic_arn, subscriptions in SNS_SUBSCRIPTIONS.items():
subscriptions = SNS_SUBSCRIPTIONS.get(topic_arn, [])
sns_backend = SNSBackend.get()
for topic_arn, subscriptions in sns_backend.sns_subscriptions.items():
subscriptions = sns_backend.sns_subscriptions.get(topic_arn, [])
for subscriber in list(subscriptions):
sub_url = subscriber.get('sqs_queue_url') or subscriber['Endpoint']
if queue_url == sub_url:
subscriptions.remove(subscriber)


def message_to_subscribers(message_id, message, topic_arn, req_data, headers, subscription_arn=None, skip_checks=False):

subscriptions = SNS_SUBSCRIPTIONS.get(topic_arn, [])
sns_backend = SNSBackend.get()
subscriptions = sns_backend.sns_subscriptions.get(topic_arn, [])
for subscriber in list(subscriptions):
if subscription_arn not in [None, subscriber['SubscriptionArn']]:
continue
Expand All @@ -288,7 +304,7 @@ def message_to_subscribers(message_id, message, topic_arn, req_data, headers, su
'endpoint': subscriber['Endpoint'],
'message_content': req_data['Message'][0]
}
SMS_MESSAGES.append(event)
sns_backend.sms_messages.append(event)
LOG.info('Delivering SMS message to %s: %s', subscriber['Endpoint'], req_data['Message'][0])

elif subscriber['Protocol'] == 'sqs':
Expand Down Expand Up @@ -397,12 +413,14 @@ def message_to_subscribers(message_id, message, topic_arn, req_data, headers, su


def publish_message(topic_arn, req_data, headers, subscription_arn=None, skip_checks=False):
sns_backend = SNSBackend.get()
message = req_data['Message'][0]
message_id = str(uuid.uuid4())

if topic_arn and ':endpoint/' in topic_arn:
# cache messages published to platform endpoints
cache = PLATFORM_ENDPOINT_MESSAGES[topic_arn] = PLATFORM_ENDPOINT_MESSAGES.get(topic_arn) or []
cache = sns_backend.platform_endpoint_messages[topic_arn] = sns_backend. \
platform_endpoint_messages.get(topic_arn) or []
cache.append(req_data)

LOG.debug('Publishing message to TopicArn: %s | Message: %s' % (topic_arn, message))
Expand All @@ -413,19 +431,21 @@ def publish_message(topic_arn, req_data, headers, subscription_arn=None, skip_ch


def do_delete_topic(topic_arn):
SNS_SUBSCRIPTIONS.pop(topic_arn, None)
SNS_TAGS.pop(topic_arn, None)
sns_backend = SNSBackend.get()
sns_backend.sns_subscriptions.pop(topic_arn, None)
sns_backend.sns_tags.pop(topic_arn, None)


def do_confirm_subscription(topic_arn, token):
for k, v in SUBSCRIPTION_STATUS.items():
sns_backend = SNSBackend.get()
for k, v in sns_backend.subscription_status.items():
if v['Token'] == token and v['TopicArn'] == topic_arn:
v['Status'] = 'Subscribed'


def do_subscribe(topic_arn, endpoint, protocol, subscription_arn, attributes, filter_policy=None):
topic_subs = SNS_SUBSCRIPTIONS[topic_arn] = SNS_SUBSCRIPTIONS.get(topic_arn) or []

sns_backend = SNSBackend.get()
topic_subs = sns_backend.sns_subscriptions[topic_arn] = sns_backend.sns_subscriptions.get(topic_arn) or []
# An endpoint may only be subscribed to a topic once. Subsequent
# subscribe calls do nothing (subscribe is idempotent).
for existing_topic_subscription in topic_subs:
Expand All @@ -443,10 +463,10 @@ def do_subscribe(topic_arn, endpoint, protocol, subscription_arn, attributes, fi
subscription.update(attributes)
topic_subs.append(subscription)

if subscription_arn not in SUBSCRIPTION_STATUS:
SUBSCRIPTION_STATUS[subscription_arn] = {}
if subscription_arn not in sns_backend.subscription_status:
sns_backend.subscription_status[subscription_arn] = {}

SUBSCRIPTION_STATUS[subscription_arn].update(
sns_backend.subscription_status[subscription_arn].update(
{
'TopicArn': topic_arn,
'Token': short_uid(),
Expand All @@ -468,26 +488,29 @@ def do_subscribe(topic_arn, endpoint, protocol, subscription_arn, attributes, fi


def do_unsubscribe(subscription_arn):
for topic_arn, existing_subs in SNS_SUBSCRIPTIONS.items():
SNS_SUBSCRIPTIONS[topic_arn] = [
sns_backend = SNSBackend.get()
for topic_arn, existing_subs in sns_backend.sns_subscriptions.items():
sns_backend.sns_subscriptions[topic_arn] = [
sub for sub in existing_subs
if sub['SubscriptionArn'] != subscription_arn
]


def _get_tags(topic_arn):
if topic_arn not in SNS_TAGS:
SNS_TAGS[topic_arn] = []
sns_backend = SNSBackend.get()
if topic_arn not in sns_backend.sns_tags:
sns_backend.sns_tags[topic_arn] = []

return SNS_TAGS[topic_arn]
return sns_backend.sns_tags[topic_arn]


def do_list_tags_for_resource(topic_arn):
return _get_tags(topic_arn)


def do_tag_resource(topic_arn, tags):
existing_tags = SNS_TAGS.get(topic_arn, [])
sns_backend = SNSBackend.get()
existing_tags = sns_backend.sns_tags.get(topic_arn, [])
tags = [
tag for idx, tag in enumerate(tags)
if tag not in tags[:idx]
Expand All @@ -506,11 +529,12 @@ def existing_tag_index(item):
else:
existing_tags[existing_index] = item

SNS_TAGS[topic_arn] = existing_tags
sns_backend.sns_tags[topic_arn] = existing_tags


def do_untag_resource(topic_arn, tag_keys):
SNS_TAGS[topic_arn] = [t for t in _get_tags(topic_arn) if t['Key'] not in tag_keys]
sns_backend = SNSBackend.get()
sns_backend.sns_tags[topic_arn] = [t for t in _get_tags(topic_arn) if t['Key'] not in tag_keys]


# ---------------
Expand All @@ -519,8 +543,9 @@ def do_untag_resource(topic_arn, tag_keys):


def get_subscription_by_arn(sub_arn):
sns_backend = SNSBackend.get()
# TODO maintain separate map instead of traversing all items
for key, subscriptions in SNS_SUBSCRIPTIONS.items():
for key, subscriptions in sns_backend.sns_subscriptions.items():
for sub in subscriptions:
if sub['SubscriptionArn'] == sub_arn:
return sub
Expand Down
36 changes: 30 additions & 6 deletions tests/integration/test_sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from localstack.utils.testutil import check_expected_lambda_log_events_length
from localstack.services.infra import start_proxy
from localstack.services.generic_proxy import ProxyListener
from localstack.services.sns import sns_listener
from localstack.services.sns.sns_listener import SNSBackend
from .lambdas import lambda_integration
from .test_lambda import TEST_LAMBDA_PYTHON, LAMBDA_RUNTIME_PYTHON36, TEST_LAMBDA_LIBS
from localstack.services.install import SQS_BACKEND_IMPL
Expand Down Expand Up @@ -279,13 +279,14 @@ def check_message():

def test_subscribe_platform_endpoint(self):
sns = self.sns_client
sns_backend = SNSBackend.get()
app_arn = sns.create_platform_application(Name='app1', Platform='p1', Attributes={})['PlatformApplicationArn']
platform_arn = sns.create_platform_endpoint(PlatformApplicationArn=app_arn, Token='token_1')['EndpointArn']
subscription = self._publish_sns_message_with_attrs(platform_arn, 'application')

# assert that message has been received
def check_message():
self.assertGreater(len(sns_listener.PLATFORM_ENDPOINT_MESSAGES[platform_arn]), 0)
self.assertGreater(len(sns_backend.platform_endpoint_messages[platform_arn]), 0)
retry(check_message, retries=PUBLICATION_RETRIES, sleep=PUBLICATION_TIMEOUT)

# clean up
Expand Down Expand Up @@ -394,10 +395,11 @@ def test_topic_subscription(self):
Protocol='email',
Endpoint='[email protected]'
)
sns_backend = SNSBackend.get()

def check_subscription():
subscription_arn = subscription['SubscriptionArn']
subscription_obj = sns_listener.SUBSCRIPTION_STATUS[subscription_arn]
subscription_obj = sns_backend.subscription_status[subscription_arn]
self.assertEqual(subscription_obj['Status'], 'Not Subscribed')

_token = subscription_obj['Token']
Expand Down Expand Up @@ -694,6 +696,27 @@ def test_create_duplicate_topic_check_idempotentness(self):
# clean up
self.sns_client.delete_topic(TopicArn=responses[0]['TopicArn'])

def test_create_platform_endpoint_check_idempotentness(self):
response = self.sns_client.create_platform_application(
Name='test-%s' % short_uid(), Platform='GCM', Attributes={'PlatformCredential': '123'}
)
kwargs_list = [{'Token': 'test1', 'CustomUserData': 'test-data'},
{'Token': 'test1', 'CustomUserData': 'test-data'},
{'Token': 'test1'}, {'Token': 'test1'}
]
platform_arn = response['PlatformApplicationArn']
responses = []
for kwargs in kwargs_list:
responses.append(self.sns_client.create_platform_endpoint(PlatformApplicationArn=platform_arn,
**kwargs))
# Assert endpointarn is returned in every call create platform call
for i in range(len(responses)):
self.assertIn('EndpointArn', responses[i])
endpoint_arn = responses[0]['EndpointArn']
# clean up
self.sns_client.delete_endpoint(EndpointArn=endpoint_arn)
self.sns_client.delete_platform_application(PlatformApplicationArn=platform_arn)

def test_publish_by_path_parameters(self):
topic_name = 'topic-{}'.format(short_uid())
queue_name = 'queue-{}'.format(short_uid())
Expand Down Expand Up @@ -765,6 +788,10 @@ def _create_queue(self):
return queue_name, queue_arn, queue_url

def test_publish_sms_endpoint(self):
def check_messages():
sns_backend = SNSBackend.get()
self.assertEqual(len(list_of_contacts), len(sns_backend.sms_messages))

list_of_contacts = [
'+10123456789',
'+10000000000',
Expand All @@ -780,9 +807,6 @@ def test_publish_sms_endpoint(self):
)
# Publish a message.
self.sns_client.publish(Message=message, TopicArn=self.topic_arn)

def check_messages():
self.assertEqual(len(list_of_contacts), len(sns_listener.SMS_MESSAGES))
retry(check_messages, retries=3, sleep=0.5)

def test_publish_sqs_from_sns(self):
Expand Down
Loading

0 comments on commit c236bf2

Please sign in to comment.