-
Notifications
You must be signed in to change notification settings - Fork 381
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
backend: Deployments refactor; Add deployment service and fix deployment config setting (#831)

* Deployments refactor; Add deployment service and fix deployment config setting
* Changes for code review
* Fix a number of integration and unit tests
* Fix failing chat tests
* Move some tests from unit/routers to integration/routers
* Fix a few more tests
* Fix a few more tests
* Fix remainder of broken integration tests
* Fix lint issues
* Run prettier on Coral
* Remove old, unused model crud helper
* Fix failing deployments unit tests
* Coral fix to account for agent.tools possibly being null
* Fix TS styling
* Provide a dummy Cohere API key during testing
* Update Coral to align with latest version of the backend API
* Fix lint issues in Coral
* Last few changes for code review
* Update generated API in assistants_web
* Fix assistants_web build
* Fix backend lint issues
* Simplify validate_deployment_header
* Don't seed the DB with deployment data, and fix a DeploymentDefinition serialization issue
* Fix backend lint issues
* Fix broken unit tests
* Skip cohere deployments tests since they're breaking other tests
* Fix deployment integration tests
* More fixes to deployments integration tests
* Fix deployment integration tests
* What API key are we using to call Cohere in the tests?
* Mock list_models of CoherePlatform model to avoid Cohere API calls
* Mask all deployment config values when looking up deployment info
* Fix integration tests
* Fix typecheck issue
* wip
* Minor changes
* use defaults
* resolve lint
* Update generated client

---------

Co-authored-by: Tianjing Li <[email protected]>
- Loading branch information
1 parent
5ad48e6
commit 802c232
Showing
88 changed files
with
6,130 additions
and
3,691 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,140 +1,35 @@ | ||
from enum import StrEnum | ||
|
||
from backend.config.settings import Settings | ||
from backend.model_deployments import ( | ||
AzureDeployment, | ||
BedrockDeployment, | ||
CohereDeployment, | ||
SageMakerDeployment, | ||
SingleContainerDeployment, | ||
) | ||
from backend.model_deployments.azure import AZURE_ENV_VARS | ||
from backend.model_deployments.base import BaseDeployment | ||
from backend.model_deployments.bedrock import BEDROCK_ENV_VARS | ||
from backend.model_deployments.cohere_platform import COHERE_ENV_VARS | ||
from backend.model_deployments.sagemaker import SAGE_MAKER_ENV_VARS | ||
from backend.model_deployments.single_container import SC_ENV_VARS | ||
from backend.schemas.deployment import Deployment | ||
from backend.services.logger.utils import LoggerFactory | ||
|
||
logger = LoggerFactory().get_logger() | ||
|
||
|
||
class ModelDeploymentName(StrEnum):
    """Display names of the built-in model deployments."""

    CoherePlatform = "Cohere Platform"
    SageMaker = "SageMaker"
    Azure = "Azure"
    Bedrock = "Bedrock"
    SingleContainer = "Single Container"
|
||
|
||
use_community_features = Settings().get('feature_flags.use_community_features') | ||
ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() } | ||
|
||
# TODO names in the map below should not be the display names but ids
# Registry of the built-in deployments, keyed by ModelDeploymentName.
# NOTE: every entry calls list_models() and is_available() on its deployment
# class while this module-level literal is being evaluated.
ALL_MODEL_DEPLOYMENTS = {
    ModelDeploymentName.CoherePlatform: Deployment(
        id="cohere_platform",
        name=ModelDeploymentName.CoherePlatform,
        deployment_class=CohereDeployment,
        models=CohereDeployment.list_models(),
        is_available=CohereDeployment.is_available(),
        env_vars=COHERE_ENV_VARS,
    ),
    ModelDeploymentName.SingleContainer: Deployment(
        id="single_container",
        name=ModelDeploymentName.SingleContainer,
        deployment_class=SingleContainerDeployment,
        models=SingleContainerDeployment.list_models(),
        is_available=SingleContainerDeployment.is_available(),
        env_vars=SC_ENV_VARS,
    ),
    ModelDeploymentName.SageMaker: Deployment(
        id="sagemaker",
        name=ModelDeploymentName.SageMaker,
        deployment_class=SageMakerDeployment,
        models=SageMakerDeployment.list_models(),
        is_available=SageMakerDeployment.is_available(),
        env_vars=SAGE_MAKER_ENV_VARS,
    ),
    ModelDeploymentName.Azure: Deployment(
        id="azure",
        name=ModelDeploymentName.Azure,
        deployment_class=AzureDeployment,
        models=AzureDeployment.list_models(),
        is_available=AzureDeployment.is_available(),
        env_vars=AZURE_ENV_VARS,
    ),
    ModelDeploymentName.Bedrock: Deployment(
        id="bedrock",
        name=ModelDeploymentName.Bedrock,
        deployment_class=BedrockDeployment,
        models=BedrockDeployment.list_models(),
        is_available=BedrockDeployment.is_available(),
        env_vars=BEDROCK_ENV_VARS,
    ),
}
|
||
def get_available_deployments() -> list[type[BaseDeployment]]: | ||
installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values()) | ||
|
||
def get_available_deployments() -> dict[ModelDeploymentName, Deployment]: | ||
if use_community_features: | ||
if Settings().get("feature_flags.use_community_features"): | ||
try: | ||
from community.config.deployments import ( | ||
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP, | ||
) | ||
|
||
model_deployments = ALL_MODEL_DEPLOYMENTS.copy() | ||
model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP) | ||
return model_deployments | ||
except ImportError: | ||
installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values()) | ||
except ImportError as e: | ||
logger.warning( | ||
event="[Deployments] No available community deployments have been configured" | ||
event="[Deployments] No available community deployments have been configured", ex=e | ||
) | ||
|
||
deployments = Settings().get('deployments.enabled_deployments') | ||
if deployments is not None and len(deployments) > 0: | ||
return { | ||
key: value | ||
for key, value in ALL_MODEL_DEPLOYMENTS.items() | ||
if value.id in Settings().get('deployments.enabled_deployments') | ||
} | ||
|
||
return ALL_MODEL_DEPLOYMENTS | ||
|
||
|
||
def get_default_deployment(**kwargs) -> BaseDeployment | None:
    """Instantiate the deployment to use when the caller did not pick one.

    The deployment whose ``id`` equals the ``deployments.default_deployment``
    setting is preferred. When that setting is empty, or matches no entry of
    ``AVAILABLE_MODEL_DEPLOYMENTS``, the first entry whose ``is_available``
    flag is truthy is used instead.

    Args:
        **kwargs: Forwarded unchanged to the chosen ``deployment_class``.

    Returns:
        A new instance of the chosen deployment class, or ``None`` when there
        is neither a matching default nor any available deployment.
    """
    candidates = list(AVAILABLE_MODEL_DEPLOYMENTS.values())

    default_id = Settings().get("deployments.default_deployment")
    if default_id:
        for deployment in candidates:
            if deployment.id == default_id:
                return deployment.deployment_class(**kwargs)

    # Fall back to the first available deployment. It is instantiated only
    # here, once it is known to be needed, so a configured default no longer
    # pays for constructing a second deployment that is then discarded.
    for deployment in candidates:
        if deployment.is_available:
            return deployment.deployment_class(**kwargs)
    return None
|
||
|
||
def find_config_by_deployment_id(deployment_id: str) -> Deployment | None:
    """Look up an available deployment config by its ``id``.

    Args:
        deployment_id: Value compared against each deployment's ``id``.

    Returns:
        The first entry of ``AVAILABLE_MODEL_DEPLOYMENTS`` whose ``id``
        equals *deployment_id*, or ``None`` when there is no such entry.
    """
    return next(
        (
            deployment
            for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values()
            if deployment.id == deployment_id
        ),
        None,
    )
|
||
|
||
def find_config_by_deployment_name(deployment_name: str) -> Deployment | None:
    """Look up an available deployment config by its ``name``.

    Args:
        deployment_name: Value compared against each deployment's ``name``.

    Returns:
        The first entry of ``AVAILABLE_MODEL_DEPLOYMENTS`` whose ``name``
        equals *deployment_name*, or ``None`` when there is no such entry.
    """
    return next(
        (
            deployment
            for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values()
            if deployment.name == deployment_name
        ),
        None,
    )
enabled_deployment_ids = Settings().get("deployments.enabled_deployments") | ||
if enabled_deployment_ids: | ||
return [ | ||
deployment | ||
for deployment in installed_deployments | ||
if deployment.id() in enabled_deployment_ids | ||
] | ||
|
||
return installed_deployments | ||
|
||
AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
src/backend/database_models/seeders/deployments_models_seed.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from sqlalchemy.orm import Session | ||
|
||
from backend.database_models import Deployment, Model, Organization | ||
|
||
|
||
def deployments_models_seed(op):
    """
    No-op placeholder for the former deployments/models seed step.

    Args:
        op: Unused.

    Returns:
        None.
    """
    # Previously we would seed the default deployments and models here. We've changed this
    # behaviour during a refactor of the deployments module so that deployments and models
    # are inserted when they're first used. This solves an issue where seed data would
    # sometimes be inserted with invalid config data.
|
||
|
||
def delete_default_models(op):
    """
    Delete every deployment and model row, plus the "default" organization.

    Args:
        op: Migration operations object; only ``op.get_bind()`` is used, to
            obtain the bind the session runs on.
    """
    session = Session(op.get_bind())
    try:
        session.query(Deployment).delete()
        session.query(Model).delete()
        session.query(Organization).filter_by(id="default").delete()
        session.commit()
    finally:
        # Always release the session, including when a delete or the commit
        # raises, instead of leaving it open.
        session.close()
Oops, something went wrong.