Skip to content

Commit

Permalink
TLK-1864 agents deployments models refactoring - review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneLightsOn committed Nov 4, 2024
1 parent 0cea965 commit 46b536f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 51 deletions.
104 changes: 53 additions & 51 deletions src/backend/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
AgentToolMetadata as AgentToolMetadataModel,
)
from backend.database_models.database import DBSessionDep
from backend.routers.utils import get_deployment_model_from_agent
from backend.routers.utils import get_deployment_model_from_agent, get_default_deployment_model
from backend.schemas.agent import (
Agent,
AgentPublic,
Expand Down Expand Up @@ -53,6 +53,7 @@
)
router.name = RouterName.AGENT


@router.post(
"",
response_model=AgentPublic,
Expand All @@ -62,9 +63,9 @@
],
)
async def create_agent(
session: DBSessionDep,
agent: CreateAgentRequest,
ctx: Context = Depends(get_context),
session: DBSessionDep,
agent: CreateAgentRequest,
ctx: Context = Depends(get_context),
) -> AgentPublic:
"""
Create an agent.
Expand All @@ -83,6 +84,7 @@ async def create_agent(
logger = ctx.get_logger()

deployment_db, model_db = get_deployment_model_from_agent(agent, session)
default_deployment_db, default_model_db = get_default_deployment_model(session)
try:
if deployment_db and model_db:
agent_data = AgentModel(
Expand All @@ -94,8 +96,8 @@ async def create_agent(
organization_id=agent.organization_id,
tools=agent.tools,
is_private=agent.is_private,
deployment_id=deployment_db.id if deployment_db else None,
model_id=model_db.id if model_db else None,
deployment_id=deployment_db.id if deployment_db else default_deployment_db.id if default_deployment_db else None,
model_id=model_db.id if model_db else default_model_db.id if default_model_db else None,
)

created_agent = agent_crud.create_agent(session, agent_data)
Expand All @@ -117,13 +119,13 @@ async def create_agent(

@router.get("", response_model=list[AgentPublic])
async def list_agents(
*,
offset: int = 0,
limit: int = 100,
session: DBSessionDep,
visibility: AgentVisibility = AgentVisibility.ALL,
organization_id: Optional[str] = None,
ctx: Context = Depends(get_context),
*,
offset: int = 0,
limit: int = 100,
session: DBSessionDep,
visibility: AgentVisibility = AgentVisibility.ALL,
organization_id: Optional[str] = None,
ctx: Context = Depends(get_context),
) -> list[AgentPublic]:
"""
List all agents.
Expand Down Expand Up @@ -161,7 +163,7 @@ async def list_agents(

@router.get("/{agent_id}", response_model=AgentPublic)
async def get_agent_by_id(
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
) -> Agent:
"""
Args:
Expand Down Expand Up @@ -196,7 +198,7 @@ async def get_agent_by_id(

@router.get("/{agent_id}/deployments", response_model=list[DeploymentSchema])
async def get_agent_deployment(
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
) -> DeploymentSchema:
"""
Args:
Expand Down Expand Up @@ -228,10 +230,10 @@ async def get_agent_deployment(
],
)
async def update_agent(
agent_id: str,
new_agent: UpdateAgentRequest,
session: DBSessionDep,
ctx: Context = Depends(get_context),
agent_id: str,
new_agent: UpdateAgentRequest,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> AgentPublic:
"""
Update an agent by ID.
Expand Down Expand Up @@ -285,9 +287,9 @@ async def update_agent(

@router.delete("/{agent_id}", response_model=DeleteAgent)
async def delete_agent(
agent_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
agent_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> DeleteAgent:
"""
Delete an agent by ID.
Expand Down Expand Up @@ -319,10 +321,10 @@ async def delete_agent(


async def handle_tool_metadata_update(
agent: Agent,
new_agent: Agent,
session: DBSessionDep,
ctx: Context = Depends(get_context),
agent: Agent,
new_agent: Agent,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> Agent:
"""Update or create tool metadata for an agent.
Expand Down Expand Up @@ -360,10 +362,10 @@ async def handle_tool_metadata_update(


async def update_or_create_tool_metadata(
agent: Agent,
new_tool_metadata: AgentToolMetadata,
session: DBSessionDep,
ctx: Context = Depends(get_context),
agent: Agent,
new_tool_metadata: AgentToolMetadata,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> None:
"""Update or create tool metadata for an agent.
Expand All @@ -389,7 +391,7 @@ async def update_or_create_tool_metadata(

@router.get("/{agent_id}/tool-metadata", response_model=list[AgentToolMetadataPublic])
async def list_agent_tool_metadata(
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
) -> list[AgentToolMetadataPublic]:
"""
List all agent tool metadata by agent ID.
Expand Down Expand Up @@ -421,10 +423,10 @@ async def list_agent_tool_metadata(
response_model=AgentToolMetadataPublic,
)
def create_agent_tool_metadata(
session: DBSessionDep,
agent_id: str,
agent_tool_metadata: CreateAgentToolMetadataRequest,
ctx: Context = Depends(get_context),
session: DBSessionDep,
agent_id: str,
agent_tool_metadata: CreateAgentToolMetadataRequest,
ctx: Context = Depends(get_context),
) -> AgentToolMetadataPublic:
"""
Create an agent tool metadata.
Expand Down Expand Up @@ -470,11 +472,11 @@ def create_agent_tool_metadata(

@router.put("/{agent_id}/tool-metadata/{agent_tool_metadata_id}")
async def update_agent_tool_metadata(
agent_id: str,
agent_tool_metadata_id: str,
session: DBSessionDep,
new_agent_tool_metadata: UpdateAgentToolMetadataRequest,
ctx: Context = Depends(get_context),
agent_id: str,
agent_tool_metadata_id: str,
session: DBSessionDep,
new_agent_tool_metadata: UpdateAgentToolMetadataRequest,
ctx: Context = Depends(get_context),
) -> AgentToolMetadata:
"""
Update an agent tool metadata by ID.
Expand Down Expand Up @@ -514,10 +516,10 @@ async def update_agent_tool_metadata(

@router.delete("/{agent_id}/tool-metadata/{agent_tool_metadata_id}")
async def delete_agent_tool_metadata(
agent_id: str,
agent_tool_metadata_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
agent_id: str,
agent_tool_metadata_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> DeleteAgentToolMetadata:
"""
Delete an agent tool metadata by ID.
Expand Down Expand Up @@ -556,9 +558,9 @@ async def delete_agent_tool_metadata(

@router.post("/batch_upload_file", response_model=list[UploadAgentFileResponse])
async def batch_upload_file(
session: DBSessionDep,
files: list[FastAPIUploadFile] = RequestFile(...),
ctx: Context = Depends(get_context),
session: DBSessionDep,
files: list[FastAPIUploadFile] = RequestFile(...),
ctx: Context = Depends(get_context),
) -> UploadAgentFileResponse:
user_id = ctx.get_user_id()

Expand All @@ -580,10 +582,10 @@ async def batch_upload_file(

@router.delete("/{agent_id}/files/{file_id}")
async def delete_agent_file(
agent_id: str,
file_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
agent_id: str,
file_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> DeleteAgentFileResponse:
"""
Delete an agent file by ID.
Expand Down
18 changes: 18 additions & 0 deletions src/backend/routers/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from backend.config.deployments import ModelDeploymentName
from backend.database_models.database import DBSessionDep
from backend.schemas.agent import Agent

Expand All @@ -21,3 +22,20 @@ def get_deployment_model_from_agent(agent: Agent, session: DBSessionDep):
None,
)
return deployment_db, model_db


def get_default_deployment_model(session: DBSessionDep):
from backend.crud import deployment as deployment_crud

deployment_db = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform)
model_db = None
if deployment_db:
model_db = next(
(
model
for model in deployment_db.models
if model.name == 'command-r-plus'
),
None,
)
return deployment_db, model_db

0 comments on commit 46b536f

Please sign in to comment.