Commit 2bb514c

wip

gromdimon committed Jan 12, 2025
1 parent 5acd25d commit 2bb514c
Showing 13 changed files with 168 additions and 25 deletions.
48 changes: 26 additions & 22 deletions src/agent.py
@@ -5,9 +5,10 @@

from src.core.config import settings
from src.core.defs import AgentAction, AgentState
from src.feedback.feedback_module import FeedbackModule
from src.memory.memory_module import get_memory_module
from src.planning.planning_module import PlanningModule
from src.execution import ExecutionModule
from src.feedback import FeedbackModule
from src.memory import get_memory_module
from src.planning import PlanningModule
from src.workflows.analyze_signal import analyze_signal
from src.workflows.research_news import analyze_news_workflow

@@ -35,10 +36,10 @@ def __init__(self):
self.memory_module = get_memory_module()

#: Initialize Planning Module with persistent Q-table
self.planning_module = PlanningModule(
actions=list(AgentAction),
q_table_path=settings.PERSISTENT_Q_TABLE_PATH, # Persistent Q-table file
)
self.planning_module = PlanningModule()

#: Initialize Execution Module
self.execution_module = ExecutionModule()

#: Initialize Feedback Module
self.feedback_module = FeedbackModule()
@@ -73,10 +74,6 @@ def _update_planning_policy(
"""Update the Q-learning table in the PlanningModule."""
self.planning_module.update_q_table(state, action, reward, next_state)

def _collect_feedback(self, action: str, outcome: Optional[Any]) -> float:
"""Collect feedback for the action & outcome in the FeedbackModule."""
return self.feedback_module.collect_feedback(action, outcome)

# --------------------------------------------------------------
# RL-based PLANNING & EXECUTION
# --------------------------------------------------------------
@@ -130,23 +127,30 @@ async def start_runtime_loop(self) -> None:
# 1. Choose an action
# You might treat the entire system as one "state", or define states.
logger.info(f"Current state: {self.state.value}")
action_name = self.planning_module.get_action(self.state)
logger.info(f"Action chosen: {action_name.value}")
next_action = await self.planning_module.get_next_action(self.state)
logger.info(f"Action chosen: {next_action.value}")

# 2. Perform that action
outcome = await self._perform_planned_action(action_name)
logger.info(f"Outcome: {outcome}")
result = await self.execution_module.execute_action(next_action)
logger.info(f"Outcome: {result}")

# 3. Collect feedback
reward = self._collect_feedback(action_name.value, outcome)
reward = self.feedback_module.collect_feedback(
next_action.value, result.get("outcome")
)
logger.info(f"Reward: {reward}")

# 4. Update the planning policy
next_state = self.state
logger.info(f"Next state: {next_state.value}")
self._update_planning_policy(self.state, action_name, reward, next_state)

# 4. Sleep or yield
# 4. Update state and memory
self.state = AgentState.from_action(next_action)
# self.memory_module.store(
# {
# "state": self.state.value,
# "action": next_action.value,
# "outcome": result.get("outcome", "Unknown"),
# }
# )

# 5. Sleep or yield
logger.info("Let's rest a bit...")
await asyncio.sleep(settings.AGENT_REST_TIME)

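For orientation, a minimal sketch of how the reworked loop might be driven from an entry point; the Agent class name is an assumption, since the diff only shows its __init__ and loop methods:

import asyncio

from src.agent import Agent  # class name assumed; the diff only shows its methods


async def main() -> None:
    agent = Agent()
    await agent.start_runtime_loop()  # plan -> execute -> collect feedback -> update state, then rest


if __name__ == "__main__":
    asyncio.run(main())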
6 changes: 6 additions & 0 deletions src/core/config.py
@@ -27,6 +27,12 @@ class Settings(BaseSettings):

# --- Planning settings ---

#: Planning API URL
PLANNING_API_URL: str = "http://localhost:11434/api/generate"

#: Planning model
PLANNING_MODEL: str = "llama3.2"

#: Path to the persistent Q-table file
PERSISTENT_Q_TABLE_PATH: str = "persistent_q_table.json"

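A quick sketch of reading the new planning settings; the default values are taken from the diff, while overriding them via environment variables is an assumption based on how pydantic BaseSettings usually behaves:

from src.core.config import settings

print(settings.PLANNING_API_URL)  # "http://localhost:11434/api/generate" by default
print(settings.PLANNING_MODEL)    # "llama3.2" by default
# With pydantic BaseSettings these can usually be overridden through the
# PLANNING_API_URL / PLANNING_MODEL environment variables (assumption, not shown in the diff).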
18 changes: 17 additions & 1 deletion src/core/defs.py
@@ -24,10 +24,26 @@ class AgentState(Enum):

DEFAULT = "default"
IDLE = "idle"
WAITING_FOR_NEWS = "waiting_for_news"
JUST_ANALYZED_NEWS = "just_analyzed_news"
JUST_ANALYZED_SIGNAL = "just_analyzed_signal"

@classmethod
def from_action(cls, action: AgentAction) -> "AgentState":
"""Convert an AgentAction to the corresponding AgentState.
Args:
action: The AgentAction to convert
Returns:
The corresponding AgentState
"""
action_to_state = {
AgentAction.IDLE: cls.IDLE,
AgentAction.ANALYZE_NEWS: cls.JUST_ANALYZED_NEWS,
AgentAction.CHECK_SIGNAL: cls.JUST_ANALYZED_SIGNAL,
}
return action_to_state.get(action, cls.DEFAULT)


class MemoryBackendType(str, Enum):
"""Available memory backend types."""
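A small usage sketch of the new from_action mapping, using only the enum members visible in the diff:

from src.core.defs import AgentAction, AgentState

# Actions listed in the mapping resolve to their "just did X" state.
assert AgentState.from_action(AgentAction.ANALYZE_NEWS) is AgentState.JUST_ANALYZED_NEWS
assert AgentState.from_action(AgentAction.CHECK_SIGNAL) is AgentState.JUST_ANALYZED_SIGNAL
assert AgentState.from_action(AgentAction.IDLE) is AgentState.IDLE
# Any action without an explicit entry falls back to AgentState.DEFAULT.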
3 changes: 3 additions & 0 deletions src/execution/__init__.py
@@ -0,0 +1,3 @@
from src.execution.execution import ExecutionModule

__all__ = ["ExecutionModule"]
38 changes: 38 additions & 0 deletions src/execution/execution.py
@@ -0,0 +1,38 @@
from loguru import logger

from src.core.defs import AgentAction
from src.workflows.analyze_signal import analyze_signal
from src.workflows.research_news import analyze_news_workflow


class ExecutionModule:
"""Module to execute agent actions."""

async def execute_action(self, action: AgentAction) -> dict:
"""
Execute the given action and return the result.
Args:
action (AgentAction): The action to execute.
Returns:
dict: Outcome and status of the action.
"""
try:
if action == AgentAction.CHECK_SIGNAL:
result = await analyze_signal()
return {"status": "success", "outcome": result or "No signal"}

elif action == AgentAction.ANALYZE_NEWS:
recent_news = "Placeholder"
result = await analyze_news_workflow(recent_news)
return {"status": "success", "outcome": result or "No outcome"}

            elif action == AgentAction.IDLE:
                logger.info("Agent is idling.")
                return {"status": "success", "outcome": "Idling"}

            else:
                # Defensive fallback: no execution branch is defined for this action.
                logger.warning(f"Unknown action: {action}")
                return {"status": "error", "outcome": f"Unknown action: {action}"}

except Exception as e:
logger.error(f"Error executing action {action}: {e}")
return {"status": "error", "outcome": str(e)}
3 changes: 3 additions & 0 deletions src/feedback/__init__.py
@@ -0,0 +1,3 @@
from src.feedback.feedback import FeedbackModule

__all__ = ["FeedbackModule"]
File renamed without changes.
3 changes: 3 additions & 0 deletions src/memory/__init__.py
@@ -0,0 +1,3 @@
from src.memory.memory import MemoryModule, get_memory_module

__all__ = ["MemoryModule", "get_memory_module"]
File renamed without changes.
3 changes: 3 additions & 0 deletions src/planning/__init__.py
@@ -0,0 +1,3 @@
from src.planning.planning import PlanningModule

__all__ = ["PlanningModule"]
67 changes: 67 additions & 0 deletions src/planning/planning.py
@@ -0,0 +1,67 @@
import json

import httpx
from loguru import logger

from src.core.config import settings
from src.core.defs import AgentAction, AgentState


class PlanningModule:
"""Centralized Planning Module using LLM for action selection."""

def __init__(
self, api_url: str = settings.PLANNING_API_URL, model: str = settings.PLANNING_MODEL
):
"""
Initialize the Planning Module.
Args:
api_url (str): The URL for the LLM API.
model (str): The model name (e.g., 'llama3.2').
"""
self.api_url = api_url
self.model = model

async def get_next_action(self, state: AgentState) -> AgentAction:
"""
Call the LLM to decide the next action.
Args:
            state (AgentState): The current agent state, used to build the prompt.
Returns:
AgentAction: The next action to perform.
"""
        prompt = (
            "You are a helpful assistant who only decides the next action to perform. "
            f"The following states are possible: {AgentState.__members__}. "
            f"You have the following possible actions: {AgentAction.__members__}. "
            "You need to decide the next action to perform. "
            "Return only the name of the action to perform and nothing else! "
            f"Current state of the agent: {state.value}. "
        )
try:
payload = {
"model": self.model,
"prompt": prompt,
"stream": False,
}
async with httpx.AsyncClient() as client:
logger.debug(f"Sending payload: {payload}")
response = await client.post(self.api_url, json=payload)
logger.debug(f"Received response: {response}")
response.raise_for_status()

data = response.json()
logger.debug(f"LLM response: {data}")
next_action = data.get("response")
if not next_action or next_action not in AgentAction.__members__:
logger.warning(f"Invalid action returned by LLM: {next_action}")
return AgentAction.IDLE

logger.info(f"LLM suggested action: {next_action}")
return AgentAction[next_action]
except Exception as e:
logger.error(f"Failed to get next action from LLM: {e}")
return AgentAction.IDLE
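A minimal sketch of exercising the new PlanningModule; it assumes an Ollama-style server is reachable at settings.PLANNING_API_URL and returns the generated text in a "response" field, as the parsing above expects:

import asyncio

from src.core.defs import AgentState
from src.planning import PlanningModule


async def demo() -> None:
    planner = PlanningModule()  # uses settings.PLANNING_API_URL / PLANNING_MODEL by default
    # The LLM is expected to answer with a bare action name such as "ANALYZE_NEWS";
    # anything unrecognized falls back to AgentAction.IDLE.
    action = await planner.get_next_action(AgentState.IDLE)
    print(action)


asyncio.run(demo())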
2 changes: 1 addition & 1 deletion src/workflows/analyze_signal.py
@@ -3,7 +3,7 @@
from loguru import logger

from src.llm.llm import LLM
from src.memory.memory_module import MemoryModule, get_memory_module
from src.memory import MemoryModule, get_memory_module
from src.tools.get_signal import fetch_signal
from src.tools.twitter import post_twitter_thread

2 changes: 1 addition & 1 deletion tests/feedback/test_feedback_module.py
@@ -1,6 +1,6 @@
import pytest

from src.feedback.feedback_module import FeedbackModule
from src.feedback import FeedbackModule


@pytest.fixture
