-
Notifications
You must be signed in to change notification settings - Fork 8
/
handler.py
103 lines (85 loc) · 3.8 KB
/
handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
'''
Contains the handler for the AWS Lambda function. First, the pre-compiled dependencies (such as
TensorFlow) are added to the import path. Then, following the AWS Lambda best practices
(see http://docs.aws.amazon.com/lambda/latest/dg/best-practices.html), the saved model is
downloaded from the S3 bucket and the GAN model instance is created once, when the module is
loaded, so that it is reused across requests. After all those actions, the Lambda function is
ready to make predictions.
'''
import os
import sys
import json
# Make the pre-compiled dependencies bundled in the 'vendored' folder (located next
# to this file) importable when the script runs on AWS Lambda. This must run before
# any module that relies on those dependencies is imported.
current_location = os.path.dirname(os.path.realpath(__file__))
_vendored_path = os.path.join(current_location, 'vendored')
sys.path.append(_vendored_path)
'''
The following imports must come after the 'vendored' folder has been added to sys.path,
so that the pre-compiled dependencies can be found
'''
from gan_model import GANModel
import utils
# Global objects created once, when the module is loaded, and reused across requests:
# a local model directory is created, the saved model is downloaded into it from the
# S3 bucket, and the GAN model is built from that directory. The download must finish
# before GANModel is constructed.
model_dir = utils.create_model_dir()
utils.download_model_from_bucket(model_dir)
gan_model = GANModel(model_dir)
def get_param_from_url(event, param_name):
    '''
    Retrieve a query-string parameter from a Lambda call. The parameters arrive in the
    event dictionary under 'queryStringParameters'; the bucket name and the key are
    passed in the query string.
    When the request has no query string, API Gateway's proxy integration sends
    'queryStringParameters' as None, and a direct invocation may omit the entry
    entirely. Both cases are treated as "no parameters" so the caller can answer with
    a 400 instead of failing with TypeError/KeyError.
    :param event: the event dictionary passed to the Lambda function
    :param param_name: the name of the parameter in the query string
    :return: parameter value, or None if the parameter is not present
    '''
    params = event.get('queryStringParameters') or {}
    return params.get(param_name)
def lambda_gateway_response(code, body):
    '''
    Build the response object returned to the caller of the endpoint: an HTTP
    status code, a JSON content-type header and a JSON-serialized body.
    :param code: HTTP status code, must be an integer
    :param body: JSON-serializable object; it is serialized with json.dumps
    :return: dict with the 'statusCode', 'headers' and 'body' keys
    '''
    json_headers = {"Content-Type": "application/json"}
    serialized = json.dumps(body)
    return dict(statusCode=code, headers=json_headers, body=serialized)
def predict(event, context):
    '''
    Lambda entry point that runs the model on an image stored in S3. It is reached through:
    {LambdaURL}/{stage}/predict?bucket=vb-tf-aws-lambda-images&key=image_3.jpg
    where {LambdaURL} is the Lambda URL returned by the serverless installation and {stage}
    is set in the serverless.yml file.
    :param event: AWS Lambda uses this parameter to pass in event data to the handler.
                  We are expecting a Python dict whose query string carries 'bucket' and 'key'.
    :param context: AWS Lambda uses this parameter to provide runtime information
                    to the handler (LambdaContext). It is not used here.
    :return: API Gateway response dict: 200 with {'prediction_result': [...]} on success,
             400 when the bucket name or key is missing, 503 on any other failure
    '''
    import traceback  # only needed on the error path

    try:
        # The S3 bucket name and key identify the image to run the prediction on.
        bucket_name = get_param_from_url(event, 'bucket')
        key = get_param_from_url(event, 'key')
        # Single-argument print(...) behaves the same as the print statement on
        # Python 2 and is also valid Python 3.
        print('Predict function called! Bucket/key is {}/{}'.format(bucket_name, key))
        if not (bucket_name and key):
            message = 'Input parameters are invalid: bucket name and key must be specified'
            return lambda_gateway_response(400, {'message': message})
        # Download the image from the S3 bucket and call prediction on it.
        image = utils.download_image_from_bucket(bucket_name, key)
        results = gan_model.predict(image)
        # Each result is indexed as (digit, probability); both are stringified
        # for the JSON response.
        results_json = [{'digit': str(res[0]),
                         'probability': str(res[1])} for res in results]
        print('Results retrieved: {}'.format(results_json))
    except Exception as exception:
        # Previously the failure was swallowed silently; write the full traceback
        # to the function's log output so the error can be diagnosed.
        traceback.print_exc()
        # NOTE(review): 500 is the conventional status for an unexpected error, and
        # 'stack_trace' holds only the exception message (which is exposed to the
        # client). Both are kept unchanged for compatibility with existing clients.
        error_response = {
            'error_message': "Unexpected error",
            'stack_trace': str(exception)
        }
        return lambda_gateway_response(503, error_response)
    return lambda_gateway_response(200, {'prediction_result': results_json})