Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLK-1864 agents deployments models refactoring #824

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""update agent deployment model
Revision ID: 74ba7e1b4810
Revises: 20b03fd331e8
Create Date: 2024-10-28 13:27:22.299287
"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '74ba7e1b4810'
down_revision: Union[str, None] = '20b03fd331e8'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('agents', sa.Column('deployment_id', sa.String(), nullable=True))
op.add_column('agents', sa.Column('model_id', sa.String(), nullable=True))
op.create_foreign_key('agents_model_id_fkey', 'agents', 'models', ['model_id'], ['id'], ondelete='CASCADE')
op.create_foreign_key('agents_deployment_id_fkey', 'agents', 'deployments', ['deployment_id'], ['id'], ondelete='CASCADE')
# set the deployment_id and model_id for the agents using agent_deployment_model table
# and then drop the table agent_deployment_model
op.execute(
"""
UPDATE agents
SET deployment_id = agent_deployment_model.deployment_id,
model_id = agent_deployment_model.model_id
FROM agent_deployment_model
WHERE agents.id = agent_deployment_model.agent_id;
"""
)
op.drop_table('agent_deployment_model')
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint('agents_deployment_id_fkey', 'agents', type_='foreignkey')
op.drop_constraint('agents_model_id_fkey', 'agents', type_='foreignkey')
op.drop_column('agents', 'model_id')
op.drop_column('agents', 'deployment_id')
# ### end Alembic commands ###
148 changes: 6 additions & 142 deletions src/backend/crud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import false, true

from backend.database_models import Deployment
from backend.database_models.agent import Agent, AgentDeploymentModel
from backend.schemas.agent import AgentVisibility, UpdateAgentRequest
from backend.database_models.agent import Agent
from backend.schemas.agent import AgentVisibility, UpdateAgentDB
from backend.services.transaction import validate_transaction


Expand Down Expand Up @@ -78,59 +77,6 @@ def get_agent_by_name(db: Session, agent_name: str, user_id: str) -> Agent:
return agent


@validate_transaction
def get_association_by_deployment_name(
db: Session, agent: Agent, deployment_name: str
) -> AgentDeploymentModel:
"""
Get an agent deployment model association by deployment name.
Args:
db (Session): Database session.
agent (Agent): Agent to get the association.
deployment_name (str): Deployment name.
Returns:
AgentDeploymentModel: Agent deployment model association.
"""
return (
db.query(AgentDeploymentModel)
.join(Deployment, Deployment.id == AgentDeploymentModel.deployment_id)
.filter(
Deployment.name == deployment_name,
AgentDeploymentModel.agent_id == agent.id,
)
.first()
)


@validate_transaction
def get_association_by_deployment_id(
db: Session, agent: Agent, deployment_id: str
) -> AgentDeploymentModel:
"""
Get an agent deployment model association by deployment id.
Args:
db (Session): Database session.
agent (Agent): Agent to get the association.
deployment_id (str): Deployment ID.
Returns:
AgentDeploymentModel: Agent deployment model association.
"""
return (
db.query(AgentDeploymentModel)
.filter(
AgentDeploymentModel.deployment_id == deployment_id,
AgentDeploymentModel.agent_id == agent.id,
AgentDeploymentModel.is_default_deployment == true(),
AgentDeploymentModel.is_default_model == true(),
)
.first()
)


@validate_transaction
def get_agents(
db: Session,
Expand Down Expand Up @@ -176,93 +122,9 @@ def get_agents(
return query.all()


@validate_transaction
def get_agent_model_deployment_association(
db: Session, agent: Agent, model_id: str, deployment_id: str
) -> AgentDeploymentModel:
"""
Get an agent model deployment association.
Args:
db (Session): Database session.
agent (Agent): Agent to get the association.
model_id (str): Model ID.
deployment_id (str): Deployment ID.
Returns:
AgentDeploymentModel: Agent model deployment association.
"""
return (
db.query(AgentDeploymentModel)
.filter(
AgentDeploymentModel.agent_id == agent.id,
AgentDeploymentModel.model_id == model_id,
AgentDeploymentModel.deployment_id == deployment_id,
)
.first()
)


@validate_transaction
def delete_agent_model_deployment_association(
db: Session, agent: Agent, model_id: str, deployment_id: str
):
"""
Delete an agent model deployment association.
Args:
db (Session): Database session.
agent (Agent): Agent to delete the association.
model_id (str): Model ID.
deployment_id (str): Deployment ID.
"""
db.query(AgentDeploymentModel).filter(
AgentDeploymentModel.agent_id == agent.id,
AgentDeploymentModel.model_id == model_id,
AgentDeploymentModel.deployment_id == deployment_id,
).delete()
db.commit()


@validate_transaction
def assign_model_deployment_to_agent(
db: Session,
agent: Agent,
model_id: str,
deployment_id: str,
deployment_config: dict[str, str] = {},
set_default: bool = False,
) -> Agent:
"""
Assign a model and deployment to an agent.
Args:
agent (Agent): Agent to assign the model and deployment.
model_id (str): Model ID.
deployment_id (str): Deployment ID.
deployment_config (dict[str, str]): Deployment configuration.
set_default (bool): Set the model and deployment as default.
Returns:
Agent: Agent with the assigned model and deployment.
"""
agent_deployment = AgentDeploymentModel(
agent_id=agent.id,
model_id=model_id,
deployment_id=deployment_id,
is_default_deployment=set_default,
is_default_model=set_default,
deployment_config=deployment_config,
)
db.add(agent_deployment)
db.commit()
db.refresh(agent)
return agent


@validate_transaction
def update_agent(
db: Session, agent: Agent, new_agent: UpdateAgentRequest, user_id: str
db: Session, agent: Agent, new_agent: UpdateAgentDB, user_id: str
) -> Agent:
"""
Update an agent.
Expand All @@ -278,7 +140,9 @@ def update_agent(
if agent.is_private and agent.user_id != user_id:
return None

for attr, value in new_agent.model_dump(exclude_none=True).items():
new_agent_cleaned = new_agent.dict(exclude_unset=True, exclude_none=True)

for attr, value in new_agent_cleaned.items():
setattr(agent, attr, value)

db.commit()
Expand Down
65 changes: 2 additions & 63 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy.orm import Session

from backend.database_models import AgentDeploymentModel, Deployment
from backend.database_models import Deployment
from backend.model_deployments.utils import class_name_validator
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
Expand Down Expand Up @@ -92,70 +92,9 @@ def get_available_deployments(
"""
all_deployments = db.query(Deployment).all()
return [deployment for deployment in all_deployments if deployment.is_available][
offset : offset + limit
offset: offset + limit
]


def get_deployments_by_agent_id(
db: Session, agent_id: str, offset: int = 0, limit: int = 100
) -> list[Deployment]:
"""
List all deployments by user id
Args:
db (Session): Database session.
agent_id (str): User ID
offset (int): Offset to start the list.
limit (int): Limit of deployments to be listed.
Returns:
list[Deployment]: List of deployments.
"""
return (
db.query(Deployment)
.join(
AgentDeploymentModel,
Deployment.id == AgentDeploymentModel.deployment_id,
)
.filter(AgentDeploymentModel.agent_id == agent_id)
.limit(limit)
.offset(offset)
.all()
)


def get_available_deployments_by_agent_id(
db: Session, agent_id: str, offset: int = 0, limit: int = 100
) -> list[Deployment]:
"""
List all deployments by user id
Args:
db (Session): Database session.
agent_id (str): User ID
offset (int): Offset to start the list.
limit (int): Limit of deployments to be listed.
Returns:
list[Deployment]: List of deployments.
"""
agent_deployments = (
db.query(Deployment)
.join(
AgentDeploymentModel,
Deployment.id == AgentDeploymentModel.deployment_id,
)
.filter(AgentDeploymentModel.agent_id == agent_id)
.limit(limit)
.offset(offset)
.all()
)

return [deployment for deployment in agent_deployments if deployment.is_available][
offset : offset + limit
]


@validate_transaction
def update_deployment(
db: Session, deployment: Deployment, new_deployment: DeploymentUpdate
Expand Down
32 changes: 1 addition & 31 deletions src/backend/crud/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from backend.database_models import AgentDeploymentModel, Deployment
from backend.database_models import Deployment
from backend.database_models.model import Model
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.model import ModelCreate, ModelUpdate
Expand Down Expand Up @@ -127,36 +127,6 @@ def delete_model(db: Session, model_id: str) -> None:
db.commit()


def get_models_by_agent_id(
db: Session, agent_id: str, offset: int = 0, limit: int = 100
) -> list[Model]:
"""
List all models by user id

Args:
db (Session): Database session.
agent_id (str): User ID
offset (int): Offset to start the list.
limit (int): Limit of models to be listed.

Returns:
list[Model]: List of models.
"""

return (
db.query(Model)
.join(
AgentDeploymentModel,
agent_id == AgentDeploymentModel.agent_id,
)
.filter(Model.deployment_id == AgentDeploymentModel.deployment_id)
.order_by(Model.name)
.limit(limit)
.offset(offset)
.all()
)


def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model:
"""
Create a new model by config if present
Expand Down
Loading
Loading