backend: Deployments refactor; Add deployment service and fix deployment config setting #831

Open
wants to merge 43 commits into
base: main
Changes from 3 commits
080fde9
Deployments refactor; Add deployment service and fix deployment confi…
malexw Nov 6, 2024
54da111
Changes for code review
malexw Nov 14, 2024
6b8025e
Fix a number of integration and unit tests
malexw Nov 19, 2024
87d4367
Merge latest main and fix a few tests
malexw Nov 22, 2024
3775f56
Fix failing chat tests
malexw Nov 26, 2024
9a3436d
Move some tests from unit/routers to integration/routers
malexw Nov 28, 2024
0575161
Fix a few more tests
malexw Nov 29, 2024
ba6a829
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Nov 29, 2024
e154d30
Fix a few more tests
malexw Nov 29, 2024
a2cb2cc
Fix remainder of broken integration tests
malexw Dec 2, 2024
00867f4
Fix lint issues
malexw Dec 2, 2024
062dd41
Run prettier on Coral
malexw Dec 2, 2024
14bc51d
Remove old, unused model crud helper
malexw Dec 2, 2024
3617b64
Fix failing deployments unit tests
malexw Dec 2, 2024
883064d
Coral fix to account for agent.tools possibly being null
malexw Dec 2, 2024
04787a1
Fix TS styling
malexw Dec 2, 2024
721f447
Provide a dummy Cohere API key during testing
malexw Dec 2, 2024
b6ff9d7
Update Coral to align with latest version of the backend API
malexw Dec 5, 2024
6bdcad3
Fix lint issues in Coral
malexw Dec 5, 2024
dad938b
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 5, 2024
af5fca0
Last few changes for code review
malexw Dec 5, 2024
a14623e
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 6, 2024
fb7e0eb
Update generated API in assistants_web
malexw Dec 10, 2024
51c641e
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 10, 2024
5dbf9a4
Fix assistants_web build
malexw Dec 10, 2024
dc8ab67
Fix backend lint issues
malexw Dec 10, 2024
6b62703
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 10, 2024
609584d
Simplify validate_deployment_header
malexw Dec 10, 2024
e7e9d48
Don't seed the DB with deployment data, and fix a DeploymentDefinitio…
malexw Dec 14, 2024
b1dcce9
Fix backend lint issues
malexw Dec 14, 2024
81a5e88
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 14, 2024
99114b1
Fix broken unit tests
malexw Dec 15, 2024
025a455
Skip cohere deployments tests since they're breaking other tests
malexw Dec 15, 2024
28982bb
Fix deployment integration tests
malexw Dec 17, 2024
c67e213
More fixes to deployments integration tests
malexw Dec 17, 2024
faed53e
Fix deployment integration tests
malexw Dec 18, 2024
b8cc3c3
What API key are we using to call Cohere in the tests?
malexw Dec 18, 2024
18d1088
Mock list_models of CoherePlatform model to avoid Cohere API calls
malexw Dec 18, 2024
ccb6a6f
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 18, 2024
13c554f
Mask all deployment config values when looking up deployment info
malexw Dec 20, 2024
d65f7da
Fix integration tests
malexw Dec 20, 2024
33be802
Fix typecheck issue
malexw Dec 20, 2024
f24eaeb
Merge branch 'main' into alexw/tlk-801-when-i-fill-in-the-env-variabl…
malexw Dec 20, 2024
27 changes: 7 additions & 20 deletions src/backend/chat/custom/utils.py
@@ -1,11 +1,9 @@
 from typing import Any

-from backend.config.deployments import (
-    AVAILABLE_MODEL_DEPLOYMENTS,
-    get_default_deployment,
-)
+from backend.exceptions import DeploymentNotFoundError
 from backend.model_deployments.base import BaseDeployment
 from backend.schemas.context import Context
+from backend.services import deployment as deployment_service


 def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
@@ -16,22 +14,11 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:

     Returns:
         BaseDeployment: Deployment implementation instance based on the deployment name.
-
-    Raises:
-        ValueError: If the deployment is not supported.
     """
     kwargs["ctx"] = ctx
-    deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(name)
-
-    # Check provided deployment against config const
-    if deployment is not None:
-        return deployment.deployment_class(**kwargs, **deployment.kwargs)
-
-    # Fallback to first available deployment
-    default = get_default_deployment(**kwargs)
-    if default is not None:
-        return default
+    try:
+        deployment = deployment_service.get_deployment_by_name(name)
+    except DeploymentNotFoundError:
+        deployment = deployment_service.get_default_deployment()

-    raise ValueError(
-        f"Deployment {name} is not supported, and no available deployments were found."
-    )
+    return deployment(**kwargs)
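The lookup-then-fallback flow that `get_deployment` now follows can be sketched in isolation. This is a minimal illustration, not the toolkit's implementation: `FakeDeployment`, `_REGISTRY`, and the two module-level helper functions are hypothetical stand-ins for `backend.services.deployment`, whose real behavior is not shown in this diff.

```python
class DeploymentNotFoundError(Exception):
    """Mirrors the exception added in src/backend/exceptions.py."""

    def __init__(self, deployment_id: str):
        super().__init__(f"Deployment {deployment_id} not found")
        self.deployment_id = deployment_id


class FakeDeployment:
    """Hypothetical stand-in for a BaseDeployment subclass."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


# Hypothetical registry; the real service resolves deployments differently.
_REGISTRY = {"Fake": FakeDeployment}


def get_deployment_by_name(name: str) -> type:
    """Return the deployment class registered under `name`, or raise."""
    try:
        return _REGISTRY[name]
    except KeyError:
        raise DeploymentNotFoundError(name) from None


def get_default_deployment() -> type:
    """Assumed fallback: the first registered deployment class."""
    return next(iter(_REGISTRY.values()))


def get_deployment(name: str, ctx, **kwargs):
    """Resolve by name, fall back to the default, then instantiate."""
    kwargs["ctx"] = ctx
    try:
        deployment = get_deployment_by_name(name)
    except DeploymentNotFoundError:
        deployment = get_default_deployment()
    return deployment(**kwargs)
```

In this sketch an unknown name silently resolves to the default deployment, which matches the shape of the new `try`/`except` block above; whether the real default lookup can itself raise is not visible in this hunk.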
6 changes: 0 additions & 6 deletions src/backend/cli/main.py
@@ -20,9 +20,6 @@
 from backend.config.deployments import (
     AVAILABLE_MODEL_DEPLOYMENTS as MANAGED_DEPLOYMENTS_SETUP,
 )
-from community.config.deployments import (
-    AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
-)
 from community.config.tools import COMMUNITY_TOOLS_SETUP


@@ -51,9 +48,6 @@ def start():

     # SET UP ENVIRONMENT FOR DEPLOYMENTS
     all_deployments = MANAGED_DEPLOYMENTS_SETUP.copy()
-    if use_community_features:
-        all_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
-
     selected_deployments = select_deployments_prompt(all_deployments, secrets)

     for deployment in selected_deployments:
137 changes: 16 additions & 121 deletions src/backend/config/deployments.py
@@ -1,140 +1,35 @@
-from enum import StrEnum
-
 from backend.config.settings import Settings
-from backend.model_deployments import (
-    AzureDeployment,
-    BedrockDeployment,
-    CohereDeployment,
-    SageMakerDeployment,
-    SingleContainerDeployment,
-)
-from backend.model_deployments.azure import AZURE_ENV_VARS
 from backend.model_deployments.base import BaseDeployment
-from backend.model_deployments.bedrock import BEDROCK_ENV_VARS
-from backend.model_deployments.cohere_platform import COHERE_ENV_VARS
-from backend.model_deployments.sagemaker import SAGE_MAKER_ENV_VARS
-from backend.model_deployments.single_container import SC_ENV_VARS
-from backend.schemas.deployment import Deployment
 from backend.services.logger.utils import LoggerFactory

 logger = LoggerFactory().get_logger()


-class ModelDeploymentName(StrEnum):
-    CoherePlatform = "Cohere Platform"
-    SageMaker = "SageMaker"
-    Azure = "Azure"
-    Bedrock = "Bedrock"
-    SingleContainer = "Single Container"
-
-
-use_community_features = Settings().get('feature_flags.use_community_features')
+ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() }

-# TODO names in the map below should not be the display names but ids
-ALL_MODEL_DEPLOYMENTS = {
-    ModelDeploymentName.CoherePlatform: Deployment(
-        id="cohere_platform",
-        name=ModelDeploymentName.CoherePlatform,
-        deployment_class=CohereDeployment,
-        models=CohereDeployment.list_models(),
-        is_available=CohereDeployment.is_available(),
-        env_vars=COHERE_ENV_VARS,
-    ),
-    ModelDeploymentName.SingleContainer: Deployment(
-        id="single_container",
-        name=ModelDeploymentName.SingleContainer,
-        deployment_class=SingleContainerDeployment,
-        models=SingleContainerDeployment.list_models(),
-        is_available=SingleContainerDeployment.is_available(),
-        env_vars=SC_ENV_VARS,
-    ),
-    ModelDeploymentName.SageMaker: Deployment(
-        id="sagemaker",
-        name=ModelDeploymentName.SageMaker,
-        deployment_class=SageMakerDeployment,
-        models=SageMakerDeployment.list_models(),
-        is_available=SageMakerDeployment.is_available(),
-        env_vars=SAGE_MAKER_ENV_VARS,
-    ),
-    ModelDeploymentName.Azure: Deployment(
-        id="azure",
-        name=ModelDeploymentName.Azure,
-        deployment_class=AzureDeployment,
-        models=AzureDeployment.list_models(),
-        is_available=AzureDeployment.is_available(),
-        env_vars=AZURE_ENV_VARS,
-    ),
-    ModelDeploymentName.Bedrock: Deployment(
-        id="bedrock",
-        name=ModelDeploymentName.Bedrock,
-        deployment_class=BedrockDeployment,
-        models=BedrockDeployment.list_models(),
-        is_available=BedrockDeployment.is_available(),
-        env_vars=BEDROCK_ENV_VARS,
-    ),
-}

+def get_installed_deployments() -> list[type[BaseDeployment]]:
+    installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values())

-def get_available_deployments() -> dict[ModelDeploymentName, Deployment]:
-    if use_community_features:
+    if Settings().get("feature_flags.use_community_features"):
         try:
             from community.config.deployments import (
                 AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
             )
-
-            model_deployments = ALL_MODEL_DEPLOYMENTS.copy()
-            model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
-            return model_deployments
-        except ImportError:
+            installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values())
+        except ImportError as e:
             logger.warning(
-                event="[Deployments] No available community deployments have been configured"
+                event="[Deployments] No available community deployments have been configured", ex=e
             )

-    deployments = Settings().get('deployments.enabled_deployments')
-    if deployments is not None and len(deployments) > 0:
-        return {
-            key: value
-            for key, value in ALL_MODEL_DEPLOYMENTS.items()
-            if value.id in Settings().get('deployments.enabled_deployments')
-        }
-
-    return ALL_MODEL_DEPLOYMENTS
-
-
-def get_default_deployment(**kwargs) -> BaseDeployment:
-    # Fallback to the first available deployment
-    fallback = None
-    for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
-        if deployment.is_available:
-            fallback = deployment.deployment_class(**kwargs)
-            break
-
-    default = Settings().get('deployments.default_deployment')
-    if default:
-        return next(
-            (
-                v.deployment_class(**kwargs)
-                for k, v in AVAILABLE_MODEL_DEPLOYMENTS.items()
-                if v.id == default
-            ),
-            fallback,
-        )
-    else:
-        return fallback
-
-
-def find_config_by_deployment_id(deployment_id: str) -> Deployment:
-    for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
-        if deployment.id == deployment_id:
-            return deployment
-    return None
-
-
-def find_config_by_deployment_name(deployment_name: str) -> Deployment:
-    for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
-        if deployment.name == deployment_name:
-            return deployment
-    return None
+    enabled_deployment_ids = Settings().get("deployments.enabled_deployments")
+    if enabled_deployment_ids:
+        return [
+            deployment
+            for deployment in installed_deployments
+            if deployment.id() in enabled_deployment_ids
+        ]

+    return installed_deployments

-AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments()
+AVAILABLE_MODEL_DEPLOYMENTS = get_installed_deployments()
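The refactor replaces the hand-written `ALL_MODEL_DEPLOYMENTS` map with one built from `BaseDeployment.__subclasses__()`. A small standalone sketch of that pattern follows, assuming classmethod-style `id()` and `name()` accessors as the diff suggests; `AlphaDeployment`, `BetaDeployment`, and the `enabled_ids` parameter are hypothetical and stand in for the real classes and the `Settings()` lookup.

```python
class BaseDeployment:
    """Simplified base class; the real one has many more members."""

    @classmethod
    def id(cls) -> str:
        raise NotImplementedError

    @classmethod
    def name(cls) -> str:
        raise NotImplementedError


class AlphaDeployment(BaseDeployment):  # hypothetical deployment
    @classmethod
    def id(cls) -> str:
        return "alpha"

    @classmethod
    def name(cls) -> str:
        return "Alpha"


class BetaDeployment(BaseDeployment):  # hypothetical deployment
    @classmethod
    def id(cls) -> str:
        return "beta"

    @classmethod
    def name(cls) -> str:
        return "Beta"


# Registry keyed by display name, built from the direct subclasses
# that exist at the time this line runs.
ALL_MODEL_DEPLOYMENTS = {d.name(): d for d in BaseDeployment.__subclasses__()}


def get_installed_deployments(enabled_ids=None) -> list[type[BaseDeployment]]:
    """Return every registered class, optionally filtered by id."""
    installed = list(ALL_MODEL_DEPLOYMENTS.values())
    if enabled_ids:
        return [d for d in installed if d.id() in enabled_ids]
    return installed
```

Two properties of `__subclasses__()` are worth keeping in mind with this design: it only reports direct subclasses, and only those whose modules have already been imported when the registry is built, so import order affects what ends up in the map.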
18 changes: 9 additions & 9 deletions src/backend/crud/deployment.py
@@ -4,12 +4,12 @@

 from backend.database_models import AgentDeploymentModel, Deployment
 from backend.model_deployments.utils import class_name_validator
-from backend.schemas.deployment import Deployment as DeploymentSchema
-from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
-from backend.services.transaction import validate_transaction
-from community.config.deployments import (
-    AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS,
+from backend.schemas.deployment import (
+    DeploymentCreate,
+    DeploymentDefinition,
+    DeploymentUpdate,
 )
+from backend.services.transaction import validate_transaction


 @validate_transaction
@@ -19,7 +19,7 @@ def create_deployment(db: Session, deployment: DeploymentCreate) -> Deployment:

     Args:
         db (Session): Database session.
-        deployment (DeploymentSchema): Deployment data to be created.
+        deployment (DeploymentDefinition): Deployment data to be created.

     Returns:
         Deployment: Created deployment.
@@ -193,14 +193,14 @@ def delete_deployment(db: Session, deployment_id: str) -> None:


 @validate_transaction
-def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema) -> Deployment:
+def create_deployment_by_config(db: Session, deployment_config: DeploymentDefinition) -> Deployment:
     """
     Create a new deployment by config.

     Args:
         db (Session): Database session.
         deployment (str): Deployment data to be created.
-        deployment_config (DeploymentSchema): Deployment config.
+        deployment_config (DeploymentDefinition): Deployment config.

     Returns:
         Deployment: Created deployment.
@@ -213,7 +213,7 @@ def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema
             for env_var in deployment_config.env_vars
         },
         deployment_class_name=deployment_config.deployment_class.__name__,
-        is_community=deployment_config.name in COMMUNITY_DEPLOYMENTS
+        is_community=deployment_config.is_community,
     )
     db.add(deployment)
     db.commit()
6 changes: 3 additions & 3 deletions src/backend/crud/model.py
@@ -2,7 +2,7 @@

 from backend.database_models import AgentDeploymentModel, Deployment
 from backend.database_models.model import Model
-from backend.schemas.deployment import Deployment as DeploymentSchema
+from backend.schemas.deployment import DeploymentDefinition
 from backend.schemas.model import ModelCreate, ModelUpdate
 from backend.services.transaction import validate_transaction

@@ -157,14 +157,14 @@ def get_models_by_agent_id(
     )


-def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model:
+def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentDefinition, model: str) -> Model:
     """
     Create a new model by config if present

     Args:
         db (Session): Database session.
         deployment (Deployment): Deployment data.
-        deployment_config (DeploymentSchema): Deployment config data.
+        deployment_config (DeploymentDefinition): Deployment config data.
         model (str): Model data.

     Returns:
19 changes: 13 additions & 6 deletions src/backend/database_models/seeders/deplyments_models_seed.py
@@ -6,8 +6,15 @@
 from sqlalchemy import text
 from sqlalchemy.orm import Session

-from backend.config.deployments import ALL_MODEL_DEPLOYMENTS, ModelDeploymentName
+from backend.config.deployments import ALL_MODEL_DEPLOYMENTS
 from backend.database_models import Deployment, Model, Organization
+from backend.model_deployments import (
+    CohereDeployment,
+    SingleContainerDeployment,
+    SageMakerDeployment,
+    AzureDeployment,
+    BedrockDeployment,
+)
 from community.config.deployments import (
     AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
 )
@@ -18,7 +25,7 @@
 model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)

 MODELS_NAME_MAPPING = {
-    ModelDeploymentName.CoherePlatform: {
+    CohereDeployment.name(): {
         "command": {
             "cohere_name": "command",
             "is_default": False,
@@ -60,7 +67,7 @@
             "is_default": False,
         },
     },
-    ModelDeploymentName.SingleContainer: {
+    SingleContainerDeployment.name(): {
         "command": {
             "cohere_name": "command",
             "is_default": False,
@@ -102,19 +109,19 @@
             "is_default": False,
         },
     },
-    ModelDeploymentName.SageMaker: {
+    SageMakerDeployment.name(): {
         "sagemaker-command": {
             "cohere_name": "command",
             "is_default": True,
         },
     },
-    ModelDeploymentName.Azure: {
+    AzureDeployment.name(): {
         "azure-command": {
             "cohere_name": "command-r",
             "is_default": True,
         },
     },
-    ModelDeploymentName.Bedrock: {
+    BedrockDeployment.name(): {
         "cohere.command-r-plus-v1:0": {
             "cohere_name": "command-r-plus",
             "is_default": True,
13 changes: 13 additions & 0 deletions src/backend/exceptions.py
@@ -0,0 +1,13 @@
+class ToolkitException(Exception):
+    """
+    Base class for all toolkit exceptions.
+    """
+
+class DeploymentNotFoundError(ToolkitException):
+    def __init__(self, deployment_id: str):
+        super(DeploymentNotFoundError, self).__init__(f"Deployment {deployment_id} not found")
+        self.deployment_id = deployment_id
+
+class NoAvailableDeploymentsError(ToolkitException):
+    def __init__(self):
+        super(NoAvailableDeploymentsError, self).__init__("No deployments have been configured. Have the appropriate config values been added to configuration.yaml or secrets.yaml?")
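Because both new exceptions derive from `ToolkitException`, callers can catch either the specific error or the shared base. The sketch below illustrates this with a trimmed copy of the hierarchy; `find_deployment_id` and the `known_ids` set are hypothetical helpers introduced only for the example.

```python
class ToolkitException(Exception):
    """Base class for all toolkit exceptions."""


class DeploymentNotFoundError(ToolkitException):
    def __init__(self, deployment_id: str):
        super().__init__(f"Deployment {deployment_id} not found")
        self.deployment_id = deployment_id


def find_deployment_id(deployment_id: str, known_ids: set) -> str:
    """Hypothetical lookup that raises when the id is unknown."""
    if deployment_id not in known_ids:
        raise DeploymentNotFoundError(deployment_id)
    return deployment_id


try:
    find_deployment_id("azure", {"cohere_platform"})
except ToolkitException as exc:
    # Catching the base class still exposes the subclass attribute.
    message = str(exc)
    missing_id = exc.deployment_id
```

Keeping the offending id on the exception instance is what lets a handler log it as a structured field instead of parsing the message string.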
15 changes: 15 additions & 0 deletions src/backend/main.py
@@ -13,6 +13,7 @@
 )
 from backend.config.routers import ROUTER_DEPENDENCIES
 from backend.config.settings import Settings
+from backend.exceptions import DeploymentNotFoundError
 from backend.routers.agent import router as agent_router
 from backend.routers.auth import router as auth_router
 from backend.routers.chat import router as chat_router
@@ -111,6 +112,20 @@ async def validation_exception_handler(request: Request, exc: Exception):
     )


+@app.exception_handler(DeploymentNotFoundError)
+async def deployment_not_found_handler(request: Request, exc: DeploymentNotFoundError):
+    ctx = get_context(request)
+    logger = ctx.get_logger()
+    logger.error(
+        event="Deployment not found",
+        deployment_id=exc.deployment_id,
+    )
+    return JSONResponse(
+        status_code=404,
+        content={"detail": str(exc)},
+    )
+
+
 @app.on_event("startup")
 async def startup_event():
     """