-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
54 lines (43 loc) · 1.44 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os

import tensorflow as tf
from fastapi import FastAPI
from pydantic import BaseModel
# FastAPI application object. The route handlers in this module register on it
# via @app.get / @app.post, and the __main__ block serves it with uvicorn.
app = FastAPI()
# Request-body schema for POST /predict.
class DataInput(BaseModel):
    # One user's features. Pydantic coerces/validates each field to the
    # annotated type; the POST handler then wraps every field in a
    # one-element tensor before calling the model.
    age: int
    level: int
    # Free-form string. The sample query uses "Laki-laki" (Indonesian for
    # "male"); accepted values presumably depend on the vocabulary the model
    # was trained on — TODO confirm.
    gender: str
    user_id: int
# Load the exported TensorFlow SavedModel once at import time so every request
# reuses the same object. The directory defaults to "model_id" (resolved
# against the current working directory) and can be overridden with the
# MODEL_PATH environment variable without editing the code.
path = os.environ.get("MODEL_PATH", "model_id")
model = tf.saved_model.load(path)

# Fixed sample request served by GET /predict: a batch of one user, using the
# same feature names and dtypes the POST handler builds from the request body
# (int64 for numeric features, string for gender).
age_get = tf.constant([18], dtype=tf.int64)
level_get = tf.constant([1], dtype=tf.int64)
user_id_get = tf.constant([14], dtype=tf.int64)
gender_get = tf.constant(["Laki-laki"], dtype=tf.string)
query = {
    "age": age_get,
    "gender": gender_get,
    "level": level_get,
    "user_id": user_id_get,
}
@app.get("/")
def hello():
    # Root endpoint: always replies with the same fixed banner, which makes it
    # usable as a trivial "is the service up?" probe.
    # (Comments rather than a docstring keep the OpenAPI description unchanged.)
    return dict(message="FastAPI TensorFlow Deployment")
@app.get("/predict")
def predict():
    # Demo endpoint: run the model on the fixed module-level sample ``query``
    # and return the first three titles of row 0 (the batch has one row).
    _, ranked = model(query)
    first_three = ranked[0][:3]
    return {"res": first_three.numpy().tolist()}
@app.post("/predict")
def predict(data: DataInput):
    # Score the single user described by the request body and return the
    # first three titles the model yields for it, under the key "res".
    #
    # Each feature becomes a one-element tensor with the same dtypes as the
    # module-level sample ``query``: int64 for numbers, string for gender.
    input_data = {
        "age": tf.constant([data.age], dtype=tf.int64),
        "level": tf.constant([data.level], dtype=tf.int64),
        "gender": tf.constant([data.gender], dtype=tf.string),
        "user_id": tf.constant([data.user_id], dtype=tf.int64),
    }
    # Echo the request payload to stdout. Pydantic v2 deprecates ``.dict()``
    # (emitting a deprecation warning on every call) in favour of
    # ``.model_dump()``; fall back to ``.dict()`` so Pydantic v1 keeps working.
    model_dump = getattr(data, "model_dump", None)
    print(model_dump() if model_dump is not None else data.dict())
    _scores, titles = model(input_data)
    # Row 0 is the only row in this batch of one; keep its first three titles.
    # NOTE(review): assumes the model returns titles ordered best-first — confirm.
    top_titles = titles[0][:3]
    return {"res": top_titles.numpy().tolist()}
if __name__ == '__main__':
    # Running ``python main.py`` starts the server directly. The previous
    # version used ``uvicorn`` and ``port`` without defining either, so it
    # failed with NameError before serving anything.
    # uvicorn is imported here so merely importing this module (e.g. by an
    # external ASGI server) does not require it.
    import uvicorn

    # The port is taken from the PORT environment variable when set.
    # NOTE(review): 8080 is an assumed default — confirm what the deployment expects.
    port = int(os.environ.get("PORT", "8080"))
    uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=1200)