Skip to content

Commit

Permalink
backend: Deployments refactor; Add deployment service and fix deploym…
Browse files Browse the repository at this point in the history
…ent config setting (#831)

* Deployments refactor; Add deployment service and fix deployment config setting

* Changes for code review

* Fix a number of integration and unit tests

* Fix failing chat tests

* Move some tests from unit/routers to integration/routers

* Fix a few more tests

* Fix a few more tests

* Fix remainder of broken integration tests

* Fix lint issues

* Run prettier on Coral

* Remove old, unused model crud helper

* Fix failing deployments unit tests

* Coral fix to account for agent.tools possibly being null

* Fix TS styling

* Provide a dummy Cohere API key during testing

* Update Coral to align with latest version of the backend API

* Fix lint issues in Coral

* Last few changes for code review

* Update generated API in assistants_web

* Fix assistants_web build

* Fix backend lint issues

* Simplify validate_deployment_header

* Don't seed the DB with deployment data, and fix a DeploymentDefinition serialization issue

* Fix backend lint issues

* Fix broken unit tests

* Skip cohere deployments tests since they're breaking other tests

* Fix deployment integration tests

* More fixes to deployments integration tests

* Fix deployment integration tests

* What API key are we using to call Cohere in the tests?

* Mock list_models of CoherePlatform model to avoid Cohere API calls

* Mask all deployment config values when looking up deployment info

* Fix integration tests

* Fix typecheck issue

* wip

* Minor changes

* use defaults

* resolve lint

* Update generated client

---------

Co-authored-by: Tianjing Li <[email protected]>
  • Loading branch information
malexw and tianjing-li authored Jan 14, 2025
1 parent 5ad48e6 commit 802c232
Show file tree
Hide file tree
Showing 88 changed files with 6,130 additions and 3,691 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from alembic import op

from backend.database_models.seeders.deplyments_models_seed import (
from backend.database_models.seeders.deployments_models_seed import (
delete_default_models,
deployments_models_seed,
)
Expand Down
29 changes: 9 additions & 20 deletions src/backend/chat/custom/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any

from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS,
get_default_deployment,
)
from backend.database_models.database import get_session
from backend.exceptions import DeploymentNotFoundError
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services import deployment as deployment_service


def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
Expand All @@ -16,22 +15,12 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
Returns:
BaseDeployment: Deployment implementation instance based on the deployment name.
Raises:
ValueError: If the deployment is not supported.
"""
kwargs["ctx"] = ctx
deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(name)

# Check provided deployment against config const
if deployment is not None:
return deployment.deployment_class(**kwargs, **deployment.kwargs)

# Fallback to first available deployment
default = get_default_deployment(**kwargs)
if default is not None:
return default
try:
session = next(get_session())
deployment = deployment_service.get_deployment_by_name(session, name, **kwargs)
except DeploymentNotFoundError:
deployment = deployment_service.get_default_deployment(**kwargs)

raise ValueError(
f"Deployment {name} is not supported, and no available deployments were found."
)
return deployment
4 changes: 2 additions & 2 deletions src/backend/config/default_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import datetime

from backend.config.deployments import ModelDeploymentName
from backend.config.tools import Tool
from backend.model_deployments.cohere_platform import CohereDeployment
from backend.schemas.agent import AgentPublic

DEFAULT_AGENT_ID = "default"
DEFAULT_DEPLOYMENT = ModelDeploymentName.CoherePlatform
DEFAULT_DEPLOYMENT = CohereDeployment.name()
DEFAULT_MODEL = "command-r-plus"

def get_default_agent() -> AgentPublic:
Expand Down
135 changes: 15 additions & 120 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
@@ -1,140 +1,35 @@
from enum import StrEnum

from backend.config.settings import Settings
from backend.model_deployments import (
AzureDeployment,
BedrockDeployment,
CohereDeployment,
SageMakerDeployment,
SingleContainerDeployment,
)
from backend.model_deployments.azure import AZURE_ENV_VARS
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.bedrock import BEDROCK_ENV_VARS
from backend.model_deployments.cohere_platform import COHERE_ENV_VARS
from backend.model_deployments.sagemaker import SAGE_MAKER_ENV_VARS
from backend.model_deployments.single_container import SC_ENV_VARS
from backend.schemas.deployment import Deployment
from backend.services.logger.utils import LoggerFactory

logger = LoggerFactory().get_logger()


class ModelDeploymentName(StrEnum):
CoherePlatform = "Cohere Platform"
SageMaker = "SageMaker"
Azure = "Azure"
Bedrock = "Bedrock"
SingleContainer = "Single Container"


use_community_features = Settings().get('feature_flags.use_community_features')
ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() }

# TODO names in the map below should not be the display names but ids
ALL_MODEL_DEPLOYMENTS = {
ModelDeploymentName.CoherePlatform: Deployment(
id="cohere_platform",
name=ModelDeploymentName.CoherePlatform,
deployment_class=CohereDeployment,
models=CohereDeployment.list_models(),
is_available=CohereDeployment.is_available(),
env_vars=COHERE_ENV_VARS,
),
ModelDeploymentName.SingleContainer: Deployment(
id="single_container",
name=ModelDeploymentName.SingleContainer,
deployment_class=SingleContainerDeployment,
models=SingleContainerDeployment.list_models(),
is_available=SingleContainerDeployment.is_available(),
env_vars=SC_ENV_VARS,
),
ModelDeploymentName.SageMaker: Deployment(
id="sagemaker",
name=ModelDeploymentName.SageMaker,
deployment_class=SageMakerDeployment,
models=SageMakerDeployment.list_models(),
is_available=SageMakerDeployment.is_available(),
env_vars=SAGE_MAKER_ENV_VARS,
),
ModelDeploymentName.Azure: Deployment(
id="azure",
name=ModelDeploymentName.Azure,
deployment_class=AzureDeployment,
models=AzureDeployment.list_models(),
is_available=AzureDeployment.is_available(),
env_vars=AZURE_ENV_VARS,
),
ModelDeploymentName.Bedrock: Deployment(
id="bedrock",
name=ModelDeploymentName.Bedrock,
deployment_class=BedrockDeployment,
models=BedrockDeployment.list_models(),
is_available=BedrockDeployment.is_available(),
env_vars=BEDROCK_ENV_VARS,
),
}

def get_available_deployments() -> list[type[BaseDeployment]]:
installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values())

def get_available_deployments() -> dict[ModelDeploymentName, Deployment]:
if use_community_features:
if Settings().get("feature_flags.use_community_features"):
try:
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)

model_deployments = ALL_MODEL_DEPLOYMENTS.copy()
model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
return model_deployments
except ImportError:
installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values())
except ImportError as e:
logger.warning(
event="[Deployments] No available community deployments have been configured"
event="[Deployments] No available community deployments have been configured", ex=e
)

deployments = Settings().get('deployments.enabled_deployments')
if deployments is not None and len(deployments) > 0:
return {
key: value
for key, value in ALL_MODEL_DEPLOYMENTS.items()
if value.id in Settings().get('deployments.enabled_deployments')
}

return ALL_MODEL_DEPLOYMENTS


def get_default_deployment(**kwargs) -> BaseDeployment:
# Fallback to the first available deployment
fallback = None
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.is_available:
fallback = deployment.deployment_class(**kwargs)
break

default = Settings().get('deployments.default_deployment')
if default:
return next(
(
v.deployment_class(**kwargs)
for k, v in AVAILABLE_MODEL_DEPLOYMENTS.items()
if v.id == default
),
fallback,
)
else:
return fallback


def find_config_by_deployment_id(deployment_id: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.id == deployment_id:
return deployment
return None


def find_config_by_deployment_name(deployment_name: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.name == deployment_name:
return deployment
return None
enabled_deployment_ids = Settings().get("deployments.enabled_deployments")
if enabled_deployment_ids:
return [
deployment
for deployment in installed_deployments
if deployment.id() in enabled_deployment_ids
]

return installed_deployments

AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments()
26 changes: 11 additions & 15 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os

from sqlalchemy.orm import Session

from backend.database_models import Deployment
from backend.model_deployments.utils import class_name_validator
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
from backend.services.transaction import validate_transaction
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS,
from backend.schemas.deployment import (
DeploymentCreate,
DeploymentDefinition,
DeploymentUpdate,
)
from backend.services.transaction import validate_transaction


@validate_transaction
Expand All @@ -19,7 +18,7 @@ def create_deployment(db: Session, deployment: DeploymentCreate) -> Deployment:
Args:
db (Session): Database session.
deployment (DeploymentSchema): Deployment data to be created.
deployment (DeploymentDefinition): Deployment data to be created.
Returns:
Deployment: Created deployment.
Expand Down Expand Up @@ -132,27 +131,24 @@ def delete_deployment(db: Session, deployment_id: str) -> None:


@validate_transaction
def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema) -> Deployment:
def create_deployment_by_config(db: Session, deployment_config: DeploymentDefinition) -> Deployment:
"""
Create a new deployment by config.
Args:
db (Session): Database session.
deployment (str): Deployment data to be created.
deployment_config (DeploymentSchema): Deployment config.
deployment_config (DeploymentDefinition): Deployment config.
Returns:
Deployment: Created deployment.
"""
deployment = Deployment(
name=deployment_config.name,
description="",
default_deployment_config= {
env_var: os.environ.get(env_var, "")
for env_var in deployment_config.env_vars
},
deployment_class_name=deployment_config.deployment_class.__name__,
is_community=deployment_config.name in COMMUNITY_DEPLOYMENTS
default_deployment_config=deployment_config.config,
deployment_class_name=deployment_config.class_name,
is_community=deployment_config.is_community,
)
db.add(deployment)
db.commit()
Expand Down
22 changes: 12 additions & 10 deletions src/backend/crud/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from sqlalchemy.orm import Session

from backend.database_models import Deployment
from backend.database_models.model import Model
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentDefinition
from backend.schemas.model import ModelCreate, ModelUpdate
from backend.services.logger.utils import LoggerFactory
from backend.services.transaction import validate_transaction

logger = LoggerFactory().get_logger()


@validate_transaction
def create_model(db: Session, model: ModelCreate) -> Model:
Expand Down Expand Up @@ -127,29 +129,29 @@ def delete_model(db: Session, model_id: str) -> None:
db.commit()


def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model:
def create_model_by_config(db: Session, deployment_config: DeploymentDefinition, deployment_id: str, model: str | None) -> Model:
"""
Create a new model by config if present
Args:
db (Session): Database session.
deployment (Deployment): Deployment data.
deployment_config (DeploymentSchema): Deployment config data.
model (str): Model data.
deployment_config (DeploymentDefinition): A deployment definition for any kind of deployment.
deployment_id (DeploymentDefinition): Deployment ID for a deployment from the DB.
model (str): Optional model name that should have its data returned from this call.
Returns:
Model: Created model.
"""
deployment_config_models = deployment_config.models
deployment_db_models = get_models_by_deployment_id(db, deployment.id)
logger.debug(event="create_model_by_config", deployment_models=deployment_config.models, deployment_id=deployment_id, model=model)
deployment_db_models = get_models_by_deployment_id(db, deployment_id)
model_to_return = None
for deployment_config_model in deployment_config_models:
for deployment_config_model in deployment_config.models:
model_in_db = any(record.name == deployment_config_model for record in deployment_db_models)
if not model_in_db:
new_model = Model(
name=deployment_config_model,
cohere_name=deployment_config_model,
deployment_id=deployment.id,
deployment_id=deployment_id,
)
db.add(new_model)
db.commit()
Expand Down
25 changes: 25 additions & 0 deletions src/backend/database_models/seeders/deployments_models_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from sqlalchemy.orm import Session

from backend.database_models import Deployment, Model, Organization


def deployments_models_seed(op):
"""
Seed default deployments, models, organization, user and agent.
"""
# Previously we would seed the default deployments and models here. We've changed this
# behaviour during a refactor of the deployments module so that deployments and models
# are inserted when they're first used. This solves an issue where seed data would
# sometimes be inserted with invalid config data.
pass


def delete_default_models(op):
"""
Delete deployments and models.
"""
session = Session(op.get_bind())
session.query(Deployment).delete()
session.query(Model).delete()
session.query(Organization).filter_by(id="default").delete()
session.commit()
Loading

0 comments on commit 802c232

Please sign in to comment.