|
| 1 | +import json |
| 2 | +from logging import getLogger |
| 3 | +from pathlib import Path |
| 4 | +from fastapi import FastAPI |
| 5 | +from typing import Dict, Optional |
| 6 | +from llama_agents import CallableMessageConsumer, QueueMessage |
| 7 | +from llama_agents.message_queues.base import BaseMessageQueue |
| 8 | +from llama_agents.message_consumers.base import BaseMessageQueueConsumer |
| 9 | +from llama_agents.message_consumers.remote import RemoteMessageConsumer |
| 10 | +from app.utils import load_from_env |
| 11 | +from app.core.message_queue import message_queue |
| 12 | + |
| 13 | + |
| 14 | +logger = getLogger(__name__) |
| 15 | + |
| 16 | + |
class TaskResultService:
    """Human-facing consumer that persists task results for inspection.

    Receives ``QueueMessage`` payloads (either directly, or over HTTP when
    registered as a remote consumer) and appends each one as a JSON line to
    ``<results_dir>/task_results.json``. A small FastAPI app exposes a
    health-check route (``/``) and the ingestion route (``/process_message``).
    """

    def __init__(
        self,
        message_queue: BaseMessageQueue,
        name: str = "human",
        host: str = "127.0.0.1",
        port: Optional[int] = 8002,
        results_dir: str = "task_results",
    ) -> None:
        """
        Args:
            message_queue: Queue this service registers itself with.
            name: Consumer name; also used as the message type routed here.
            host: Interface/hostname the FastAPI app is reachable on.
            port: Port for the app; ``None`` omits the port from consumer URLs.
            results_dir: Directory the JSONL result log is written into
                (defaults to the previous hard-coded ``"task_results"``).
        """
        self.name = name
        self.host = host
        self.port = port
        # Where completed task results are appended, one JSON object per line.
        self.results_dir = Path(results_dir)

        self._message_queue = message_queue

        # FastAPI app exposing this consumer over HTTP.
        self._app = FastAPI()
        self._app.add_api_route(
            "/", self.home, methods=["GET"], tags=["Human Consumer"]
        )
        self._app.add_api_route(
            "/process_message",
            self.process_message,
            methods=["POST"],
            tags=["Human Consumer"],
        )

    @property
    def message_queue(self) -> BaseMessageQueue:
        """The message queue this consumer is attached to."""
        return self._message_queue

    def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer:
        """Wrap this service as a message-queue consumer.

        Args:
            remote: If True, return a ``RemoteMessageConsumer`` that POSTs
                each message to our ``/process_message`` endpoint; otherwise
                a local ``CallableMessageConsumer`` invoking it in-process.

        Returns:
            A consumer whose ``message_type`` is this service's ``name``.
        """
        if remote:
            # Omit the port segment entirely when no port is configured.
            base_url = (
                f"http://{self.host}:{self.port}"
                if self.port
                else f"http://{self.host}"
            )
            return RemoteMessageConsumer(
                url=f"{base_url}/process_message",
                message_type=self.name,
            )

        return CallableMessageConsumer(
            message_type=self.name,
            handler=self.process_message,
        )

    async def process_message(self, message: QueueMessage) -> None:
        """Append *message* to the JSONL results file.

        The file is opened in plain append mode ("a" — the original "+a"
        requested read access that was never used) with explicit UTF-8
        encoding, and the directory is created on demand.
        """
        self.results_dir.mkdir(parents=True, exist_ok=True)
        with open(
            self.results_dir / "task_results.json", "a", encoding="utf-8"
        ) as f:
            json.dump(message.model_dump(), f)
            f.write("\n")  # JSONL: one record per line

    async def home(self) -> Dict[str, str]:
        """Health-check endpoint for the consumer app."""
        return {"message": "hello, human."}

    async def register_to_message_queue(self) -> None:
        """Register to the message queue."""
        await self.message_queue.register_consumer(self.as_consumer(remote=True))
| 75 | + |
| 76 | + |
# Host/port for the human-consumer app; environment variables override the
# defaults (load_from_env with throw_error=False returns None when unset).
human_consumer_host = (
    load_from_env("HUMAN_CONSUMER_HOST", throw_error=False) or "127.0.0.1"
)
human_consumer_port = load_from_env("HUMAN_CONSUMER_PORT", throw_error=False) or "8002"


# Module-level service instance wired to the shared message queue.
# NOTE: the `or "8002"` fallback above guarantees human_consumer_port is a
# non-empty string, so the former `if human_consumer_port else None` guard
# around int() was dead code and has been removed.
human_consumer_server = TaskResultService(
    message_queue=message_queue,
    host=human_consumer_host,
    port=int(human_consumer_port),
    name="human",
)
0 commit comments