diff --git a/flask_lambda.py b/flask_lambda.py index 2e23a16..3422365 100644 --- a/flask_lambda.py +++ b/flask_lambda.py @@ -14,9 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. -import sys -import datetime -from decimal import Decimal +from sys import stderr try: from urllib import urlencode @@ -24,106 +22,141 @@ from urllib.parse import urlencode from flask import Flask +import logging -try: - from cStringIO import StringIO -except ImportError: - try: - from StringIO import StringIO - except ImportError: - from io import StringIO - +import datetime import six -from werkzeug.wrappers import BaseRequest, BaseResponse - - -__version__ = '0.0.4' +import time +import os +import sys +from werkzeug import urls +from werkzeug.wsgi import ClosingIterator +from werkzeug.wrappers import BaseResponse, Response +__version__ = '0.0.5' -class Promise(object): - """ - This is just a base class for the proxy class created in - the closure of the lazy function. It can be used to recognize - promises in code. - """ - pass +def init_logger(lg): + try: + lg.setLevel(logging.getLevelName(str(os.environ.get("LAMBDA_LOG_LEVEL", "INFO")).upper())) + console = logging.StreamHandler(sys.stdout) + fmt = "%(asctime)-15s %(levelname)s %(process)d %(message)s" + datefmt = "%Y-%m-%d %H:%M:%S %z" + formatter = logging.Formatter(fmt, datefmt) + console.setFormatter(formatter) + lg.addHandler(console) + except Exception as e: + print(e) -class _UnicodeDecodeError(UnicodeDecodeError): - def __init__(self, obj, *args): - self.obj = obj - UnicodeDecodeError.__init__(self, *args) - def __str__(self): - original = UnicodeDecodeError.__str__(self) - return '%s. You passed in %r (%s)' % (original, self.obj, - type(self.obj)) +logger = logging.getLogger("lambda") +init_logger(logger) -_PROTECTED_TYPES = six.integer_types + (type(None), float, Decimal, - datetime.datetime, datetime.date, - datetime.time) +def log_get_duration_ms(start): + return int((time.time() - start) * 1000) -def is_protected_type(obj): - """Determine if the object instance is of a protected type. - Objects of protected types are preserved as-is when passed to - force_text(strings_only=True). +def log_lambda(environ, response_status, start_stamp): + try: + status_code = response_status.split(" ")[0] + duration = log_get_duration_ms(start_stamp) + logger.info("method='{}' path='{}' query='{}' remote_addr='{}' status='{}' input='{}' duration={}ms".format( + environ["REQUEST_METHOD"], + environ["PATH_INFO"], + environ["QUERY_STRING"], + environ["REMOTE_ADDR"], + status_code, + datetime.datetime.fromtimestamp(start_stamp).isoformat(), + duration, + )) + except Exception as e: + logger.error(e) + + +def titlecase_keys(d): + """ + Takes a dict with keys of type str and returns a new dict with all keys titlecased. """ - return isinstance(obj, _PROTECTED_TYPES) + return {k.title(): v for k, v in d.items()} -def smart_text(s, encoding='utf-8', strings_only=False, errors='strict'): +def get_wsgi_string(string, encoding='utf-8'): """ - Returns a text object representing 's' -- unicode on Python 2 and str on - Python 3. Treats bytestrings using the 'encoding' codec. - If strings_only is True, don't convert (some) non-string-like objects. + Returns wsgi-compatible string """ - if isinstance(s, Promise): - # The input is the result of a gettext_lazy() call. - return s - return force_text(s, encoding, strings_only, errors) + return string.encode(encoding).decode('iso-8859-1') -def force_text(s, encoding='utf-8', strings_only=False, errors='strict'): +def all_casings(input_string): """ - Similar to smart_text, except that lazy instances are resolved to - strings, rather than kept as lazy objects. - If strings_only is True, don't convert (some) non-string-like objects. + Permute all casings of a given string. + + A pretty algorithm, via @Amber + http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python """ - # Handle the common case first for performance reasons. - if issubclass(type(s), six.text_type): - return s - if strings_only and is_protected_type(s): - return s - try: - if not issubclass(type(s), six.string_types): - if six.PY3: - if isinstance(s, bytes): - s = six.text_type(s, encoding, errors) - else: - s = six.text_type(s) - elif hasattr(s, '__unicode__'): - s = six.text_type(s) - else: - s = six.text_type(bytes(s), encoding, errors) - else: - # Note: We use .decode() here, instead of six.text_type(s, encoding, - # errors), so that if s is a SafeBytes, it ends up being a - # SafeText at the end. - s = s.decode(encoding, errors) - except UnicodeDecodeError as e: - if not isinstance(s, Exception): - raise _UnicodeDecodeError(s, *e.args) + if not input_string: + yield "" + else: + first = input_string[:1] + if first.lower() == first.upper(): + for sub_casing in all_casings(input_string[1:]): + yield first + sub_casing else: - # If we get to here, the caller has passed in an Exception - # subclass populated with non-ASCII bytestring data without a - # working unicode method. Try to handle this without raising a - # further exception by individually forcing the exception args - # to unicode. - s = ' '.join(force_text(arg, encoding, strings_only, errors) - for arg in s) - return s + for sub_casing in all_casings(input_string[1:]): + yield first.lower() + sub_casing + yield first.upper() + sub_casing + + +class WSGIMiddleware(object): + + def __init__(self, application): + self.application = application + + def __call__(self, environ, start_response): + """ + We must case-mangle the Set-Cookie header name or AWS will use only a + single one of these headers. + """ + + def encode_response(status, headers, exc_info=None): + """ + Create an APIGW-acceptable version of our cookies. + + We have to use a bizarre hack that turns multiple Set-Cookie headers into + their case-permutated format, ex: + + Set-cookie: + sEt-cookie: + seT-cookie: + + To get around an API Gateway limitation. + + This is weird, but better than our previous hack of creating a Base58-encoded + supercookie. + """ + + # All the non-cookie headers should be sent unharmed. + + # The main app can send 'set-cookie' headers in any casing + # Related: https://github.com/Miserlou/Zappa/issues/990 + new_headers = [header for header in headers + if ((type(header[0]) != str) or (header[0].lower() != 'set-cookie'))] + cookie_headers = [header for header in headers + if ((type(header[0]) == str) and (header[0].lower() == "set-cookie"))] + for header, new_name in zip(cookie_headers, + all_casings("Set-Cookie")): + new_headers.append((new_name, header[1])) + ret = start_response(status, new_headers, exc_info) + log_lambda(environ, status, start_stamp) + return ret + + start_stamp = time.time() + # Call the application with our modifier + response = self.application(environ, encode_response) + + # Return the response as a WSGI-safe iterator + return ClosingIterator(response) class HealthCheckMiddleware(object): @@ -138,84 +171,135 @@ def __call__(self, environ, start_response): return self.app(environ, start_response) -def make_environ(event): - environ = {} +def create_wsgi_request(event_info): + """ + Given some event_info via API Gateway, + create and return a valid WSGI request environ. + """ + method = event_info['httpMethod'] + params = event_info.get('pathParameters') - for hdr_name, hdr_value in event['headers'].items(): - hdr_name = hdr_name.replace('-', '_').upper() - if hdr_name in ['CONTENT_TYPE', 'CONTENT_LENGTH']: - environ[hdr_name] = hdr_value - continue + """ + API Gateway and ALB both started allowing for multi-value querystring + params in Nov. 2018. If there aren't multi-value params present, then + it acts identically to 'queryStringParameters', so we can use it as a + drop-in replacement. + The one caveat here is that ALB will only include _one_ of + queryStringParameters _or_ multiValueQueryStringParameters, which means + we have to check for the existence of one and then fall back to the + other. + """ + if 'multiValueQueryStringParameters' in event_info: + query = event_info['multiValueQueryStringParameters'] + query_string = urlencode(query, doseq=True) if query else '' + else: + query = event_info.get('queryStringParameters', {}) + query_string = urlencode(query) if query else '' + headers = event_info['headers'] or {} + + # Extract remote user from context if Authorizer is enabled + remote_user = None + if event_info['requestContext'].get('authorizer'): + remote_user = event_info['requestContext']['authorizer'].get('principalId') + elif event_info['requestContext'].get('identity'): + remote_user = event_info['requestContext']['identity'].get('userArn') + + body = event_info['body'] + if isinstance(body, six.string_types): + body = body.encode("utf-8") + + # Make header names canonical, e.g. content-type => Content-Type + # https://github.com/Miserlou/Zappa/issues/1188 + headers = titlecase_keys(headers) + + path = urls.url_unquote(event_info['path']) + + x_forwarded_for = headers.get('X-Forwarded-For', '') + if ',' in x_forwarded_for: + # The last one is the cloudfront proxy ip. The second to last is the real client ip. + # Everything else is user supplied and untrustworthy. + remote_addr = x_forwarded_for.split(', ')[-2] + else: + remote_addr = '127.0.0.1' + + environ = { + 'PATH_INFO': get_wsgi_string(path), + 'QUERY_STRING': get_wsgi_string(query_string), + 'REMOTE_ADDR': remote_addr, + 'REQUEST_METHOD': method, + 'SCRIPT_NAME': '', + 'SERVER_NAME': '', + 'SERVER_PORT': headers.get('X-Forwarded-Port', '80'), + 'SERVER_PROTOCOL': str('HTTP/1.1'), + 'wsgi.version': (1, 0), + 'wsgi.url_scheme': headers.get('X-Forwarded-Proto', 'http'), + 'wsgi.input': body, + 'wsgi.errors': stderr, + 'wsgi.multiprocess': False, + 'wsgi.multithread': False, + 'wsgi.run_once': False, + } - http_hdr_name = 'HTTP_%s' % hdr_name - environ[http_hdr_name] = hdr_value + # Input processing + if method in ["POST", "PUT", "PATCH", "DELETE"]: + if 'Content-Type' in headers: + environ['CONTENT_TYPE'] = headers['Content-Type'] - qs = event['queryStringParameters'] + # This must be Bytes or None + environ['wsgi.input'] = six.BytesIO(body) + if body: + environ['CONTENT_LENGTH'] = str(len(body)) + else: + environ['CONTENT_LENGTH'] = '0' - environ['REQUEST_METHOD'] = event['httpMethod'] - environ['PATH_INFO'] = event['path'] - environ['QUERY_STRING'] = urlencode(qs) if qs else '' - environ['REMOTE_ADDR'] = event['requestContext']['identity']['sourceIp'] - environ['HOST'] = '%(HTTP_HOST)s:%(HTTP_X_FORWARDED_PORT)s' % environ - environ['SCRIPT_NAME'] = '' + for header in headers: + wsgi_name = "HTTP_" + header.upper().replace('-', '_') + environ[wsgi_name] = str(headers[header]) - environ['SERVER_PORT'] = environ['HTTP_X_FORWARDED_PORT'] - environ['SERVER_PROTOCOL'] = 'HTTP/1.1' + if remote_user: + environ['REMOTE_USER'] = remote_user - body = event.get('body', '') - body = smart_text(body) - environ['CONTENT_LENGTH'] = str( - len(body) if body else '' - ) + if event_info['requestContext'].get('authorizer'): + environ['API_GATEWAY_AUTHORIZER'] = event_info['requestContext']['authorizer'] - environ['wsgi.url_scheme'] = environ['HTTP_X_FORWARDED_PROTO'] - environ['wsgi.input'] = StringIO(body or '') - environ['wsgi.version'] = (1, 0) - environ['wsgi.errors'] = sys.stderr - environ['wsgi.multithread'] = False - environ['wsgi.run_once'] = True - environ['wsgi.multiprocess'] = False + return environ - BaseRequest(environ) - return environ +def _call(self, event, context): + # This is a normal HTTP request + self.wsgi_app_wrapped = WSGIMiddleware(HealthCheckMiddleware(self.wsgi_app)) + if not event.get('httpMethod', None): + return self.wsgi_app_wrapped(event, context) -class LambdaResponse(object): - def __init__(self): - self.status = None - self.response_headers = None + # Create the environment for WSGI and handle the request + environ = create_wsgi_request(event) - def start_response(self, status, response_headers, exc_info=None): - self.status = int(status[:3]) - self.response_headers = dict(response_headers) + # We are always on https on Lambda, so tell our wsgi app that. + environ['HTTPS'] = 'on' + environ['wsgi.url_scheme'] = 'https' + environ['lambda.context'] = context + environ['lambda.event'] = event + # Execute the application + with Response.from_app(self.wsgi_app_wrapped, environ) as response: + # This is the object we're going to return. + # Pack the WSGI response into our special dictionary. + returndict = dict() -def _call(self, event, context): - self.wsgi_app = HealthCheckMiddleware(self.wsgi_app) - - if 'httpMethod' not in event: - # In this "context" `event` is `environ` and - # `context` is `start_response`, meaning the request didn't - # occur via API Gateway and Lambda - return self.wsgi_app(event, context) - - response = LambdaResponse() - - body = next(self.wsgi_app( - make_environ(event), - response.start_response - )) - body = smart_text(body) - - return { - 'statusCode': response.status, - 'headers': response.response_headers, - 'body': body - } + if response.data: + returndict['body'] = response.get_data(as_text=True) + + returndict['statusCode'] = response.status_code + returndict['headers'] = {} + for key, value in response.headers: + returndict['headers'][key] = value + + return returndict class FlaskLambda(Flask): + def __call__(self, event, context): return _call(self, event, context) @@ -226,4 +310,4 @@ def __init__(self, app): self.wsgi_app = app def __call__(self, event, context): - return _call(self, event, context) \ No newline at end of file + return _call(self, event, context)