|
1 | 1 | from abc import ABC
|
| 2 | + |
| 3 | +from semantic_router import Route |
| 4 | + |
| 5 | +from semantic_routing_core import SemanticRouter |
2 | 6 | import litserve as ls
|
| 7 | +import os |
3 | 8 |
|
4 | 9 |
|
5 | 10 | class SemanticRoutingAPI(ls.LitAPI, ABC):
|
6 | 11 | def __init__(self):
|
7 |
| - pass |
| 12 | + self.semantic_routing_core = None |
| 13 | + # Define routes |
| 14 | + politics = Route( |
| 15 | + name="politics", |
| 16 | + utterances=[ |
| 17 | + "isn't politics the best thing ever", |
| 18 | + "why don't you tell me about your political opinions", |
| 19 | + "don't you just love the president", |
| 20 | + "they're going to destroy this country!", |
| 21 | + "they will save the country!", |
| 22 | + ], |
| 23 | + ) |
| 24 | + |
| 25 | + chitchat = Route( |
| 26 | + name="chitchat", |
| 27 | + utterances=[ |
| 28 | + "how's the weather today?", |
| 29 | + "how are things going?", |
| 30 | + "lovely weather today", |
| 31 | + "the weather is horrendous", |
| 32 | + "let's go to the chippy", |
| 33 | + ], |
| 34 | + ) |
| 35 | + |
| 36 | + self.routes = [politics, chitchat] |
8 | 37 |
|
9 | 38 | def setup(self, device):
|
10 |
| - pass |
| 39 | + self.semantic_routing_core = SemanticRouter() |
| 40 | + # Set up routes |
| 41 | + self.semantic_routing_core.setup_routes(self.routes) |
11 | 42 |
|
12 | 43 | def decode_request(self, request, **kwargs):
|
13 |
| - pass |
| 44 | + return request['question'] |
14 | 45 |
|
15 |
| - def predict(self, x, **kwargs): |
16 |
| - pass |
| 46 | + def predict(self, query, **kwargs): |
| 47 | + return self.semantic_routing_core.route_query(query=query) |
17 | 48 |
|
18 | 49 | def encode_response(self, output, **kwargs):
|
19 |
| - pass |
| 50 | + return {'response': output} |
| 51 | + |
| 52 | + |
| 53 | +if __name__ == '__main__': |
| 54 | + api = SemanticRoutingAPI() |
| 55 | + server = ls.LitServer(lit_api=api, api_path='/api/v1/chat-completion', |
| 56 | + workers_per_device=int(os.environ.get('LIT_SERVER_WORKERS_PER_DEVICE'))) |
| 57 | + server.run(port=os.environ.get('LIT_SERVER_PORT')) |
0 commit comments