Skip to content
This repository was archived by the owner on Aug 14, 2024. It is now read-only.

Commit 5eab8f6

Browse files
authored
Support custom SQS queue attributes (localstack#1520)
1 parent 6f5f509 commit 5eab8f6

File tree

3 files changed

+197
-97
lines changed

3 files changed

+197
-97
lines changed

localstack/services/sqs/sqs_listener.py

+154-91
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
XMLNS_SQS = 'http://queue.amazonaws.com/doc/2012-11-05/'
1717

18-
1918
SUCCESSFUL_SEND_MESSAGE_XML_TEMPLATE = """
2019
<?xml version="1.0"?>
2120
<SendMessageResponse xmlns="%s">
@@ -30,32 +29,29 @@
3029
</SendMessageResponse>
3130
""".strip() % XMLNS_SQS
3231

32+
# list of valid attribute names, and names not supported by the backend (elasticmq)
33+
VALID_ATTRIBUTE_NAMES = ['DelaySeconds', 'MaximumMessageSize', 'MessageRetentionPeriod',
34+
'Policy', 'ReceiveMessageWaitTimeSeconds', 'RedrivePolicy', 'VisibilityTimeout']
35+
UNSUPPORTED_ATTRIBUTE_NAMES = ['MaximumMessageSize', 'MessageRetentionPeriod', 'Policy', 'RedrivePolicy']
36+
37+
# maps queue URLs to attributes set via the API
38+
QUEUE_ATTRIBUTES = {}
39+
3340

3441
class ProxyListenerSQS(ProxyListener):
3542

3643
def forward_request(self, method, path, data, headers):
3744
req_data = self.parse_request_data(method, path, data)
3845

3946
if req_data:
40-
if req_data.get('Action', [None])[0] == 'SendMessage':
41-
queue_url = req_data.get('QueueUrl', [path.partition('?')[0]])[0]
42-
queue_name = queue_url[queue_url.rindex('/') + 1:]
43-
message_body = req_data.get('MessageBody', [None])[0]
44-
message_attributes = self.format_message_attributes(req_data)
45-
region_name = extract_region_from_auth_header(headers)
46-
47-
process_result = lambda_api.process_sqs_message(message_body,
48-
message_attributes, queue_name, region_name=region_name)
49-
if process_result:
50-
# If a Lambda was listening, do not add the message to the queue
51-
new_response = Response()
52-
new_response._content = SUCCESSFUL_SEND_MESSAGE_XML_TEMPLATE.format(
53-
message_attr_hash=md5(data),
54-
message_body_hash=md5(message_body),
55-
message_id=str(uuid.uuid4())
56-
)
57-
new_response.status_code = 200
47+
action = req_data.get('Action', [None])[0]
48+
if action == 'SendMessage':
49+
new_response = self._send_message(path, data, req_data, headers)
50+
if new_response:
5851
return new_response
52+
elif action == 'SetQueueAttributes':
53+
self._set_queue_attributes(req_data)
54+
5955
if 'QueueName' in req_data:
6056
encoded_data = urlencode(req_data, doseq=True) if method == 'POST' else ''
6157
modified_url = None
@@ -76,6 +72,73 @@ def parse_request_data(self, method, path, data):
7672
return urlparse.parse_qs(parsed_path.query)
7773
return {}
7874

75+
def return_response(self, method, path, data, headers, response, request_handler):
76+
if method == 'OPTIONS' and path == '/':
77+
# Allow CORS preflight requests to succeed.
78+
return 200
79+
80+
if method == 'POST' and path == '/':
81+
region_name = extract_region_from_auth_header(headers)
82+
req_data = urlparse.parse_qs(to_str(data))
83+
action = req_data.get('Action', [None])[0]
84+
content_str = content_str_original = to_str(response.content)
85+
86+
self._fire_event(req_data, response)
87+
88+
# patch the response and add missing attributes
89+
if action == 'GetQueueAttributes':
90+
content_str = self._add_queue_attributes(req_data, content_str)
91+
92+
# patch the response and return the correct endpoint URLs / ARNs
93+
if action in ('CreateQueue', 'GetQueueUrl', 'ListQueues', 'GetQueueAttributes'):
94+
if config.USE_SSL and '<QueueUrl>http://' in content_str:
95+
# return https://... if we're supposed to use SSL
96+
content_str = re.sub(r'<QueueUrl>\s*http://', r'<QueueUrl>https://', content_str)
97+
# expose external hostname:port
98+
external_port = SQS_PORT_EXTERNAL or get_external_port(headers, request_handler)
99+
content_str = re.sub(r'<QueueUrl>\s*([a-z]+)://[^<]*:([0-9]+)/([^<]*)\s*</QueueUrl>',
100+
r'<QueueUrl>\1://%s:%s/\3</QueueUrl>' % (HOSTNAME_EXTERNAL, external_port), content_str)
101+
# fix queue ARN
102+
content_str = re.sub(r'<([a-zA-Z0-9]+)>\s*arn:aws:sqs:elasticmq:([^<]+)</([a-zA-Z0-9]+)>',
103+
r'<\1>arn:aws:sqs:%s:\2</\3>' % (region_name), content_str)
104+
105+
if content_str_original != content_str:
106+
# if changes have been made, return patched response
107+
new_response = Response()
108+
new_response.status_code = response.status_code
109+
new_response.headers = response.headers
110+
new_response._content = content_str
111+
new_response.headers['content-length'] = len(new_response._content)
112+
return new_response
113+
114+
# Since the following 2 API calls are not implemented in ElasticMQ, we're mocking them
115+
# and letting them to return an empty response
116+
if action == 'TagQueue':
117+
new_response = Response()
118+
new_response.status_code = 200
119+
new_response._content = ("""
120+
<?xml version="1.0"?>
121+
<TagQueueResponse>
122+
<ResponseMetadata>
123+
<RequestId>{}</RequestId>
124+
</ResponseMetadata>
125+
</TagQueueResponse>
126+
""").strip().format(uuid.uuid4())
127+
return new_response
128+
elif action == 'ListQueueTags':
129+
new_response = Response()
130+
new_response.status_code = 200
131+
new_response._content = ("""
132+
<?xml version="1.0"?>
133+
<ListQueueTagsResponse xmlns="{}">
134+
<ListQueueTagsResult/>
135+
<ResponseMetadata>
136+
<RequestId>{}</RequestId>
137+
</ResponseMetadata>
138+
</ListQueueTagsResponse>
139+
""").strip().format(XMLNS_SQS, uuid.uuid4())
140+
return new_response
141+
79142
# Format of the message Name attribute is MessageAttribute.<int id>.<field>
80143
# Format of the Value attributes is MessageAttribute.<int id>.Value.DataType
81144
# and MessageAttribute.<int id>.Value.<Type>Value
@@ -115,10 +178,10 @@ def parse_request_data(self, method, path, data):
115178
# dataType: 'String'
116179
# }
117180
# }
118-
119181
def format_message_attributes(self, data):
182+
prefix = 'MessageAttribute'
120183
names = []
121-
for (k, name) in [(k, data[k]) for k in data if k.startswith('MessageAttribute') and k.endswith('.Name')]:
184+
for (k, name) in [(k, data[k]) for k in data if k.startswith(prefix) and k.endswith('.Name')]:
122185
attr_name = name[0]
123186
k_id = k.split('.')[1]
124187
names.append((attr_name, k_id))
@@ -128,7 +191,7 @@ def format_message_attributes(self, data):
128191
msg_attrs[key_name] = {}
129192
# Find vals for each key_id
130193
attrs = [(k, data[k]) for k in data
131-
if k.startswith('MessageAttribute.{}.'.format(key_id)) and not k.endswith('.Name')]
194+
if k.startswith('{}.{}.'.format(prefix, key_id)) and not k.endswith('.Name')]
132195
for (attr_k, attr_v) in attrs:
133196
attr_name = attr_k.split('.')[3]
134197
msg_attrs[key_name][attr_name[0].lower() + attr_name[1:]] = attr_v[0]
@@ -141,78 +204,78 @@ def format_message_attributes(self, data):
141204

142205
return msg_attrs
143206

144-
def return_response(self, method, path, data, headers, response, request_handler):
145-
if method == 'OPTIONS' and path == '/':
146-
# Allow CORS preflight requests to succeed.
147-
return 200
207+
# Format attributes as dict. Example input:
208+
# {
209+
# 'Attribute.1.Name': ['Policy'],
210+
# 'Attribute.1.Value': ['...']
211+
# }
212+
def _format_attributes(self, req_data):
213+
result = {}
214+
for i in range(1, 500):
215+
key1 = 'Attribute.%s.Name' % i
216+
key2 = 'Attribute.%s.Value' % i
217+
if key1 not in req_data:
218+
break
219+
key_name = req_data[key1][0]
220+
key_value = req_data[key2][0]
221+
result[key_name] = key_value
222+
return result
148223

149-
if method == 'POST' and path == '/':
150-
region_name = extract_region_from_auth_header(headers)
151-
req_data = urlparse.parse_qs(to_str(data))
152-
action = req_data.get('Action', [None])[0]
153-
event_type = None
154-
queue_url = None
155-
if action == 'CreateQueue':
156-
event_type = event_publisher.EVENT_SQS_CREATE_QUEUE
157-
response_data = xmltodict.parse(response.content)
158-
if 'CreateQueueResponse' in response_data:
159-
queue_url = response_data['CreateQueueResponse']['CreateQueueResult']['QueueUrl']
160-
elif action == 'DeleteQueue':
161-
event_type = event_publisher.EVENT_SQS_DELETE_QUEUE
162-
queue_url = req_data.get('QueueUrl', [None])[0]
163-
164-
if event_type and queue_url:
165-
event_publisher.fire_event(event_type, payload={'u': event_publisher.get_hash(queue_url)})
224+
def _send_message(self, path, data, req_data, headers):
225+
queue_url = req_data.get('QueueUrl', [path.partition('?')[0]])[0]
226+
queue_name = queue_url[queue_url.rindex('/') + 1:]
227+
message_body = req_data.get('MessageBody', [None])[0]
228+
message_attributes = self.format_message_attributes(req_data)
229+
region_name = extract_region_from_auth_header(headers)
166230

167-
# patch the response and return the correct endpoint URLs / ARNs
168-
if action in ('CreateQueue', 'GetQueueUrl', 'ListQueues', 'GetQueueAttributes'):
169-
content_str = content_str_original = to_str(response.content)
170-
new_response = Response()
171-
new_response.status_code = response.status_code
172-
new_response.headers = response.headers
173-
if config.USE_SSL and '<QueueUrl>http://' in content_str:
174-
# return https://... if we're supposed to use SSL
175-
content_str = re.sub(r'<QueueUrl>\s*http://', r'<QueueUrl>https://', content_str)
176-
# expose external hostname:port
177-
external_port = SQS_PORT_EXTERNAL or get_external_port(headers, request_handler)
178-
content_str = re.sub(r'<QueueUrl>\s*([a-z]+)://[^<]*:([0-9]+)/([^<]*)\s*</QueueUrl>',
179-
r'<QueueUrl>\1://%s:%s/\3</QueueUrl>' % (HOSTNAME_EXTERNAL, external_port), content_str)
180-
# fix queue ARN
181-
content_str = re.sub(r'<([a-zA-Z0-9]+)>\s*arn:aws:sqs:elasticmq:([^<]+)</([a-zA-Z0-9]+)>',
182-
r'<\1>arn:aws:sqs:%s:\2</\3>' % (region_name), content_str)
183-
new_response._content = content_str
184-
if content_str_original != new_response._content:
185-
# if changes have been made, return patched response
186-
new_response.headers['content-length'] = len(new_response._content)
187-
return new_response
231+
process_result = lambda_api.process_sqs_message(message_body,
232+
message_attributes, queue_name, region_name=region_name)
233+
if process_result:
234+
# If a Lambda was listening, do not add the message to the queue
235+
new_response = Response()
236+
new_response._content = SUCCESSFUL_SEND_MESSAGE_XML_TEMPLATE.format(
237+
message_attr_hash=md5(data),
238+
message_body_hash=md5(message_body),
239+
message_id=str(uuid.uuid4())
240+
)
241+
new_response.status_code = 200
242+
return new_response
188243

189-
# Since the following 2 API calls are not implemented in ElasticMQ, we're mocking them
190-
# and letting them to return an empty response
191-
if action == 'TagQueue':
192-
new_response = Response()
193-
new_response.status_code = 200
194-
new_response._content = ("""
195-
<?xml version="1.0"?>
196-
<TagQueueResponse>
197-
<ResponseMetadata>
198-
<RequestId>{}</RequestId>
199-
</ResponseMetadata>
200-
</TagQueueResponse>
201-
""").strip().format(uuid.uuid4())
202-
return new_response
203-
elif action == 'ListQueueTags':
204-
new_response = Response()
205-
new_response.status_code = 200
206-
new_response._content = ("""
207-
<?xml version="1.0"?>
208-
<ListQueueTagsResponse xmlns="{}">
209-
<ListQueueTagsResult/>
210-
<ResponseMetadata>
211-
<RequestId>{}</RequestId>
212-
</ResponseMetadata>
213-
</ListQueueTagsResponse>
214-
""").strip().format(XMLNS_SQS, uuid.uuid4())
215-
return new_response
244+
def _set_queue_attributes(self, req_data):
245+
queue_url = req_data['QueueUrl'][0]
246+
attrs = self._format_attributes(req_data)
247+
# select only the attributes in UNSUPPORTED_ATTRIBUTE_NAMES
248+
attrs = dict([(k, v) for k, v in attrs.items() if k in UNSUPPORTED_ATTRIBUTE_NAMES])
249+
QUEUE_ATTRIBUTES[queue_url] = QUEUE_ATTRIBUTES.get(queue_url) or {}
250+
QUEUE_ATTRIBUTES[queue_url].update(attrs)
251+
252+
def _add_queue_attributes(self, req_data, content_str):
253+
flags = re.MULTILINE | re.DOTALL
254+
queue_url = req_data['QueueUrl'][0]
255+
regex = r'(.*<GetQueueAttributesResult>)(.*)(</GetQueueAttributesResult>.*)'
256+
attrs = re.sub(regex, r'\2', content_str, flags=flags)
257+
for key, value in QUEUE_ATTRIBUTES.get(queue_url, {}).items():
258+
if not re.match(r'<Name>\s*%s\s*</Name>' % key, attrs, flags=flags):
259+
attrs += '<Attribute><Name>%s</Name><Value>%s</Value></Attribute>' % (key, value)
260+
content_str = (re.sub(regex, r'\1', content_str, flags=flags) +
261+
attrs + re.sub(regex, r'\3', content_str, flags=flags))
262+
return content_str
263+
264+
def _fire_event(self, req_data, response):
265+
action = req_data.get('Action', [None])[0]
266+
event_type = None
267+
queue_url = None
268+
if action == 'CreateQueue':
269+
event_type = event_publisher.EVENT_SQS_CREATE_QUEUE
270+
response_data = xmltodict.parse(response.content)
271+
if 'CreateQueueResponse' in response_data:
272+
queue_url = response_data['CreateQueueResponse']['CreateQueueResult']['QueueUrl']
273+
elif action == 'DeleteQueue':
274+
event_type = event_publisher.EVENT_SQS_DELETE_QUEUE
275+
queue_url = req_data.get('QueueUrl', [None])[0]
276+
277+
if event_type and queue_url:
278+
event_publisher.fire_event(event_type, payload={'u': event_publisher.get_hash(queue_url)})
216279

217280

218281
# extract the external port used by the client to make the request

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jsonpath-rw==1.4.0
2828
localstack-ext[full]>=0.10.2
2929
localstack-ext>=0.10.2 #basic-lib
3030
localstack-client>=0.10 #basic-lib
31-
moto-ext==1.3.14.dev0
31+
moto-ext>=1.3.14.1
3232
nose>=1.3.7
3333
nose-timer>=0.7.5
3434
psutil==5.4.8

tests/integration/test_sqs.py

+42-5
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,64 @@
11
import unittest
22
from localstack.utils.aws import aws_stack
3+
from localstack.utils.common import short_uid
34

45

56
TEST_QUEUE_NAME = 'TestQueue'
67

8+
TEST_POLICY = """
9+
{
10+
"Version":"2012-10-17",
11+
"Statement":[
12+
{
13+
"Effect": "Allow",
14+
"Principal": { "AWS": "*" },
15+
"Action": "sqs:SendMessage",
16+
"Resource": "'$sqs_queue_arn'",
17+
"Condition":{
18+
"ArnEquals":{
19+
"aws:SourceArn":"'$sns_topic_arn'"
20+
}
21+
}
22+
}
23+
]
24+
}
25+
"""
26+
727

828
class SQSTest(unittest.TestCase):
29+
@classmethod
30+
def setUpClass(cls):
31+
cls.client = aws_stack.connect_to_service('sqs')
32+
933
def test_list_queue_tags(self):
1034
# Since this API call is not implemented in ElasticMQ, we're mocking it
1135
# and letting it return an empty response
12-
sqs_client = aws_stack.connect_to_service('sqs')
13-
queue_info = sqs_client.create_queue(QueueName=TEST_QUEUE_NAME)
36+
queue_info = self.client.create_queue(QueueName=TEST_QUEUE_NAME)
1437
queue_url = queue_info['QueueUrl']
15-
res = sqs_client.list_queue_tags(QueueUrl=queue_url)
38+
res = self.client.list_queue_tags(QueueUrl=queue_url)
1639

1740
# Apparently, if there are no tags, then `Tags` should NOT appear in the response.
1841
assert 'Tags' not in res
1942

2043
def test_create_fifo_queue(self):
2144
fifo_queue = 'my-queue.fifo'
22-
sqs_client = aws_stack.connect_to_service('sqs')
23-
queue_info = sqs_client.create_queue(QueueName=fifo_queue, Attributes={'FifoQueue': 'true'})
45+
queue_info = self.client.create_queue(QueueName=fifo_queue, Attributes={'FifoQueue': 'true'})
2446
queue_url = queue_info['QueueUrl']
2547

2648
# it should preserve .fifo in the queue name
2749
self.assertIn(fifo_queue, queue_url)
50+
51+
def test_set_queue_policy(self):
52+
fifo_queue = 'queue-%s' % short_uid()
53+
queue_info = self.client.create_queue(QueueName=fifo_queue)
54+
queue_url = queue_info['QueueUrl']
55+
56+
attributes = {
57+
'Policy': TEST_POLICY
58+
}
59+
self.client.set_queue_attributes(QueueUrl=queue_url, Attributes=attributes)
60+
61+
attrs = self.client.get_queue_attributes(QueueUrl=queue_url, AttributeNames=['All'])['Attributes']
62+
self.assertIn('sqs:SendMessage', attrs['Policy'])
63+
attrs = self.client.get_queue_attributes(QueueUrl=queue_url, AttributeNames=['Policy'])['Attributes']
64+
self.assertIn('sqs:SendMessage', attrs['Policy'])

0 commit comments

Comments
 (0)