Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for using nested Agents, languse upgrade needed #899

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from app.utils.prisma import prisma
from prisma.models import Agent as AgentModel
from prisma.enums import ToolType

from .base import BaseApiAgentManager

Expand Down Expand Up @@ -150,12 +151,12 @@ async def add_assistant(self, data: dict, _):
new_tool = await self.create_tool(
assistant=self.parent_agent.dict(),
data={
"name": new_agent.name,
"description": data.get("description"),
**data,
"metadata": {
**(data.get("metadata")),
"agentId": new_agent.id,
},
"type": "AGENT",
"type": ToolType.AGENT.value,
},
)

Expand Down Expand Up @@ -227,13 +228,19 @@ async def update_assistant(self, assistant: dict, data: dict):
logger.error(f"Error updating assistant: {assistant} - Error: {err}")

tool = await self.get_agent_tool(assistant)
tool_metadata = json.loads(tool.metadata)

try:
await api_update_tool(
tool_id=tool.id,
body=ToolUpdateRequest(
name=data.get("name"),
description=data.get("description"),
body=ToolUpdateRequest.parse_obj(
{
**data,
"metadata": {
**(data.get("metadata")),
"agentId": tool_metadata.get("agentId"),
},
}
),
api_user=self.api_user,
)
Expand Down
16 changes: 11 additions & 5 deletions libs/superagent/app/api/workflow_configs/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

SUPERRAG_MIME_TYPE_TO_EXTENSION = {
"text/plain": "TXT",
"text/markdown": "MARKDOWN",
"text/markdown": "MD",
"application/pdf": "PDF",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": "DOCX",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "XLSX",
Expand Down Expand Up @@ -110,7 +110,8 @@ def transform_assistant(self):
if self.assistant.get("type") == AgentType.LLM.value:
remove_key_if_present(self.assistant, "llmModel")
else:
self.assistant["llmModel"] = LLM_REVERSE_MAPPING.get(llm_model, llm_model)
self.assistant["llmModel"] = LLM_REVERSE_MAPPING.get(
llm_model, llm_model)

self.assistant["metadata"] = {
**(self.assistant.get("params") or {}),
Expand All @@ -122,7 +123,10 @@ def transform_tools(self):
tool_type = get_first_non_null_key(tool_obj)
tool = tool_obj.get(tool_type)

rename_and_remove_keys(tool, {"use_for": "description"})
rename_and_remove_keys(
tool, {"use_for": "description",
"return_direct": "returnDirect"}
)

if tool_type:
tool["type"] = tool_type.upper()
Expand Down Expand Up @@ -177,7 +181,8 @@ async def _set_database_provider(self, datasource: dict):

# this is for superrag
if database:
database_provider = REVERSE_VECTOR_DB_MAPPING.get(database.provider)
database_provider = REVERSE_VECTOR_DB_MAPPING.get(
database.provider)
credentials = get_superrag_compatible_credentials(database.options)
datasource["vector_database"] = {
"type": database_provider,
Expand Down Expand Up @@ -208,7 +213,8 @@ async def _set_superrag_files(self, datasource: dict):

def _get_file_type(self, url: str):
try:
file_type = SUPERRAG_MIME_TYPE_TO_EXTENSION[get_mimetype_from_url(url)]
file_type = SUPERRAG_MIME_TYPE_TO_EXTENSION[get_mimetype_from_url(
url)]
except KeyError:
supported_file_types = ", ".join(
value for value in SUPERRAG_MIME_TYPE_TO_EXTENSION.values()
Expand Down
15 changes: 8 additions & 7 deletions libs/superagent/app/api/workflow_configs/saml_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ class Tool(BaseModel):
name: str
use_for: str
metadata: Optional[dict[str, Any]]
return_direct: Optional[bool] = Field(
default=False,
description="""Whether to return the tool's output directly.
If this is set to true, the output of the tool will not be returned to the LLM""",
)


class ToolModel(BaseModel):
Expand Down Expand Up @@ -119,19 +124,15 @@ class OpenAIAgent(Assistant):
pass


class BaseAgentToolModel(BaseModel):
use_for: str


class SuperagentAgentTool(BaseAgentToolModel, SuperagentAgent):
class SuperagentAgentTool(Tool, SuperagentAgent):
pass


class OpenAIAgentTool(BaseAgentToolModel, OpenAIAgent):
class OpenAIAgentTool(Tool, OpenAIAgent):
pass


class LLMAgentTool(BaseAgentToolModel, LLMAgent):
class LLMAgentTool(Tool, LLMAgent):
pass


Expand Down
14 changes: 9 additions & 5 deletions libs/superagent/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from app.routers import router
from app.utils.prisma import prisma

# Create a color formatter
# Create a color formatter including file name and line number
formatter = colorlog.ColoredFormatter(
"%(log_color)s%(levelname)s: %(message)s",
"%(log_color)s%(asctime)s - %(levelname)s - %(pathname)s:%(lineno)d:\n%(message)s",
datefmt='%Y-%m-%d %H:%M:%S',
log_colors={
"DEBUG": "cyan",
"INFO": "green",
Expand All @@ -21,17 +22,20 @@
},
secondary_log_colors={},
style="%",
) # Create a console handler and set the formatter
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)

# Update basicConfig format to include file name and line number
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s: %(message)s",
format="%(asctime)s - %(levelname)s - %(pathname)s:%(lineno)d:\n%(message)s",
handlers=[console_handler],
force=True,
)

logger = logging.getLogger("main")

app = FastAPI(
title="Superagent",
docs_url="/",
Expand All @@ -54,7 +58,7 @@ async def add_process_time_header(request: Request, call_next):
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
print(f"Total request time: {process_time} secs")
logger.debug(f"Total request time: {process_time} secs")
return response


Expand Down
8 changes: 5 additions & 3 deletions libs/superagent/app/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):

done: asyncio.Event

TIMEOUT_SECONDS = 30
TIMEOUT_SECONDS = 60
is_stream_started = False

@property
Expand Down Expand Up @@ -81,14 +81,16 @@ async def aiter(self) -> AsyncIterator[str]:
timeout=self.TIMEOUT_SECONDS,
)
if not done:
logger.warning(f"{self.TIMEOUT_SECONDS} seconds of timeout reached")
logger.warning(
f"{self.TIMEOUT_SECONDS} seconds of timeout reached")
self.done.set()
break

for future in pending:
future.cancel()

token_or_done = cast(Union[str, Literal[True]], done.pop().result())
token_or_done = cast(
Union[str, Literal[True]], done.pop().result())

if token_or_done is True:
continue
Expand Down
Loading