inference.py
import json
import logging
import time

import torch
from transformers import AutoTokenizer, AutoModel

# from sentence_transformers import models, losses, SentenceTransformer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Mean pooling - take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask
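
# Shape walkthrough (for illustration): with a batch of 2 sentences padded to
# 8 tokens and a hidden size of 768, token_embeddings is (2, 8, 768) and
# attention_mask is (2, 8). The expanded mask is (2, 8, 768), so padding
# positions contribute zero to the sum and sum_mask counts only real tokens,
# giving a (2, 768) tensor of mean-pooled sentence embeddings.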

def embed_tformer(model, tokenizer, sentences):
    encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=256, return_tensors='pt')
    # Move the encoded batch to the model's device, so inference also works
    # when model_fn placed the model on GPU
    encoded_input = encoded_input.to(next(model.parameters()).device)
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    return sentence_embeddings
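
# Illustrative call: with bundle = model_fn('/opt/ml/model'),
# embed_tformer(bundle['model'], bundle['tokenizer'], ['hello world'])
# returns a (1, 768) tensor, since bert-base models have hidden size 768.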

def model_fn(model_dir):
    logger.info('model_fn')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(model_dir)
    # Note: this version pulls the weights from the Hugging Face Hub, so the
    # artifacts in model_dir are not used
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens")
    nlp_model = AutoModel.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens")
    nlp_model.to(device)
    model = {'model': nlp_model, 'tokenizer': tokenizer}
    # model = SentenceTransformer(model_dir + '/transformer/')
    # logger.info(model)
    return model

# Deserialize the Invoke request body into an object we can perform prediction on
def input_fn(serialized_input_data, content_type='text/plain'):
    logger.info('Deserializing the input data.')
    if content_type == 'text/plain':
        # e.g. b'Hello world' -> ['Hello world'] (a single-sentence batch)
        return [serialized_input_data.decode('utf-8')]
    raise ValueError('Requested unsupported ContentType in content_type: {}'.format(content_type))

# Perform prediction on the deserialized object, with the loaded model
def predict_fn(input_object, model):
    logger.info("Calling model")
    start_time = time.time()
    sentence_embeddings = embed_tformer(model['model'], model['tokenizer'], input_object)
    logger.info("--- Inference time: %s seconds ---", time.time() - start_time)
    # Return the embedding of the first (and only) sentence as a plain list
    response = sentence_embeddings[0].tolist()
    return response

# Serialize the prediction result into the desired response content type
def output_fn(prediction, accept):
    logger.info('Serializing the generated output.')
    if accept == 'application/json':
        return json.dumps(prediction)
    raise ValueError('Requested unsupported ContentType in Accept: {}'.format(accept))
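
# --- Local smoke test (a minimal sketch, not part of the SageMaker handler
# contract): runs the four handlers end to end outside the container. The
# model_dir argument is a placeholder, since model_fn above downloads its
# weights from the Hub and ignores it.
if __name__ == '__main__':
    bundle = model_fn('/opt/ml/model')
    batch = input_fn(b'This is a test sentence.', 'text/plain')
    embedding = predict_fn(batch, bundle)
    print(output_fn(embedding, 'application/json')[:80])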