-
Notifications
You must be signed in to change notification settings - Fork 2
/
app.py
189 lines (174 loc) · 6.33 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import os
import uvicorn
from adapters.registry import get_adapter_registry
from fastapi import APIRouter as FastAPIRouter
from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from typing import Any, Callable
from utils import oauth2
from utils.registry import get_util_registry
from wrappers.registry import get_wrapper_registry
# Obtain the three registries once, at import time. Each one is iterated
# further down in this module to bind its members' methods as API routes
# (wrapper_registry also supplies the backend-status endpoint).
adapter_registry = get_adapter_registry()
util_registry = get_util_registry()
wrapper_registry = get_wrapper_registry()
class APIRouter(FastAPIRouter):
    """Router that exposes every route with and without a trailing slash.

    Each ``add_api_route`` call registers the endpoint twice: once under the
    path with a single trailing ``/`` removed, and once under that same path
    with ``/`` appended.
    """

    def add_api_route(
        self,
        path: str,
        endpoint: Callable,
        *,
        include_in_schema: bool = True,
        **kwargs: Any
    ) -> None:
        """Register ``endpoint`` under both slash variants of ``path``.

        Args:
            path: Route path; one trailing ``/`` is stripped if present.
            endpoint: Callable handling the request.
            include_in_schema: Applied to the slash-less route only; the
                trailing-slash twin is always hidden from the schema.
            **kwargs: Forwarded unchanged to FastAPI's ``add_api_route`` for
                both registrations.
        """
        canonical = path[:-1] if path.endswith("/") else path
        # Canonical route first, then its hidden trailing-slash twin.
        variants = (
            (canonical, include_in_schema),
            (canonical + "/", False),
        )
        for route_path, in_schema in variants:
            super().add_api_route(
                path=route_path,
                endpoint=endpoint,
                include_in_schema=in_schema,
                **kwargs
            )
# Application instance, with the bundled Swagger UI customised.
app = FastAPI(
    swagger_ui_parameters={
        # Swagger UI documents the string values "list", "full" and "none"
        # for docExpansion. Python None would be emitted as JSON null, which
        # is not one of them, so spell out the collapsed setting explicitly.
        "docExpansion": "none",
        "filter": True,
        "tagsSorter": "alpha",
    }
)

# NOTE(review): wildcard origins/methods/headers together with
# allow_credentials=True is the most permissive CORS configuration possible.
# Behaviour is left unchanged here; confirm that this is intended for
# production deployments, or restrict allow_origins to the known frontends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Overridden error-handling function for request validation failures
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Answer a failed request validation with a 422 JSON response.

    The payload holds the structured error list (``detail``), the body
    attached to the exception (``body``), and the exception rendered as a
    string (``string_error``). ``request`` is part of the handler signature
    but is not read.
    """
    payload = {
        "detail": exc.errors(),
        "body": exc.body,
        "string_error": str(exc),
    }
    return JSONResponse(
        content=jsonable_encoder(payload),
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
    )
# Administrative endpoints, all served under /api/admin
router = APIRouter(prefix="/api/admin")

# Backend status check
router.add_api_route(
    "/get-backend-status",
    wrapper_registry.get_backend_status,
    methods=["GET"],
    tags=["admin"],
)

# OAuth2 login; the response is shaped by oauth2.Token
router.add_api_route(
    "/token",
    oauth2.login_for_access_token,
    methods=["POST"],
    response_model=oauth2.Token,
    tags=["admin"],
)

app.include_router(router)
# Only these wrapper methods may be shown in the OpenAPI schema
_SCHEMA_METHODS = ("call_sync", "call_async", "call_sync_without_token")

# Prefixes containing any of these fragments are kept out of the schema
# (context_recommender and the like are a bit messy now)
_SCHEMA_HIDDEN_FRAGMENTS = (
    "context_recommender",
    "count_analogs",
    "evaluate_context",
    "fast_filter/with_threshold",
    "general_selectivity/gnn",
    "general_selectivity/qm_gnn",
    "general_selectivity/qm_gnn_no_reagent",
    "get_top_class_batch",
    "pmi_calculator",
    "tree_optimizer",
)

for wrapper in wrapper_registry:
    for i, prefix in enumerate(wrapper.prefixes):
        # use hyphens which is the standard
        prefix_with_hyphen = prefix.replace("_", "-")
        router = APIRouter(prefix=f"/api/{prefix_with_hyphen}")

        # Only the first prefix of a wrapper is eligible for the schema, and
        # only if it is neither exactly "descriptors" nor on the hidden list.
        # This depends on the prefix alone, so it is decided once per prefix.
        prefix_in_schema = (
            i == 0
            and prefix != "descriptors"
            and not any(frag in prefix for frag in _SCHEMA_HIDDEN_FRAGMENTS)
        )
        by_alias = "retro" not in prefix
        tag = prefix_with_hyphen.split("/")[0]

        # Bind specified wrapper.method to urls /{wrapper.prefixes}/method_name
        for method_name, bind_types in wrapper.methods_to_bind.items():
            router.add_api_route(
                f"/{method_name.replace('_', '-')}",
                getattr(wrapper, method_name),
                methods=bind_types,
                include_in_schema=(
                    prefix_in_schema and method_name in _SCHEMA_METHODS
                ),
                response_model_by_alias=by_alias,
                tags=[tag],
            )
        app.include_router(router)
# Bind every legacy adapter under /api/<prefix>, marked as deprecated.
for adapter in adapter_registry:
    # v1_descriptors stays unbound: it appears to be in conflict with
    # wrapper/legacy_descriptors. Skip it up front so that no router is
    # built for it and no schema flag is computed that could never be False.
    if adapter.name == "v1_descriptors":
        continue

    # Adapter should have a single prefix, by design
    # use hyphens which is the standard
    prefix_with_hyphen = adapter.prefix.replace("_", "-")
    router = APIRouter(prefix=f"/api/{prefix_with_hyphen}")

    # The celery task adapter takes the task id as a path parameter; every
    # other adapter is served at the prefix root.
    path = "/{task_id}/" if adapter.name == "v1_celery_task" else "/"

    # include_in_schema is left at its default (True): the only adapter that
    # was ever excluded from the schema is the one skipped above.
    router.add_api_route(
        path=path,
        endpoint=adapter.__call__,
        methods=adapter.methods,
        tags=["legacy"],
        deprecated=True,
    )
    app.include_router(router)
# Utils whose route path is hardcoded for legacy convention, keyed by name
_LEGACY_UTIL_PATHS = {
    "draw": "/",
    "selectivity_refs": "/{pk}",
}

# Top-level prefix segments that are grouped under the "admin" tag
_ADMIN_TAG_ROOTS = ("user", "api-logging", "status", "frontend-config")

for util in util_registry:
    for prefix in util.prefixes:
        # use hyphens which is the standard
        prefix_with_hyphen = prefix.replace("_", "-")
        router = APIRouter(prefix=f"/api/{prefix_with_hyphen}")

        # The tag depends on the prefix only, so work it out once per router
        tag_root = prefix_with_hyphen.split("/")[0]
        tag = "admin" if tag_root in _ADMIN_TAG_ROOTS else f"utils/{tag_root}"

        # Bind specified util.method to urls /{util.prefixes}/method_name
        for method_name, bind_types in util.methods_to_bind.items():
            # Legacy hardcoded paths take precedence over the method name
            path = _LEGACY_UTIL_PATHS.get(util.name)
            if path is None:
                if method_name == "__call__":
                    path = "/"
                else:
                    path = f"/{method_name.replace('_', '-')}"

            router.add_api_route(
                path,
                getattr(util, method_name),
                methods=bind_types,
                tags=[tag],
            )
        app.include_router(router)
if __name__ == "__main__":
    # TLS material is taken from the environment; a variable that is not set
    # is passed through to uvicorn as None.
    cert_file = os.environ.get("ASKCOS_SSL_CERT_FILE")
    key_file = os.environ.get("ASKCOS_SSL_KEY_FILE")

    # Listen on all interfaces, port 9100
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=9100,
        ssl_certfile=cert_file,
        ssl_keyfile=key_file,
    )