Skip to content

Commit

Permalink
Merge main changes
Browse files Browse the repository at this point in the history
  • Loading branch information
tianjing-li committed Nov 6, 2024
2 parents 36264f6 + 0bb1f67 commit 4d374d4
Show file tree
Hide file tree
Showing 39 changed files with 68 additions and 69 deletions.
2 changes: 1 addition & 1 deletion src/backend/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
config = context.config

# Overwrite alembic.file `sqlachemy.url` value
config.set_main_option("sqlalchemy.url", Settings().database.url)
config.set_main_option("sqlalchemy.url", Settings().get('database.url'))

# Interpret the config file for Python logging.
# This line sets up loggers basically.
Expand Down
6 changes: 3 additions & 3 deletions src/backend/config/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
SKIP_AUTH = os.getenv("SKIP_AUTH", None)
# Ex: [BasicAuthentication]
ENABLED_AUTH_STRATEGIES = []
if ENABLED_AUTH_STRATEGIES == [] and Settings().auth.enabled_auth is not None:
ENABLED_AUTH_STRATEGIES = [auth_map[auth] for auth in Settings().auth.enabled_auth]
if ENABLED_AUTH_STRATEGIES == [] and Settings().get('auth.enabled_auth') is not None:
ENABLED_AUTH_STRATEGIES = [auth_map[auth] for auth in Settings().get('auth.enabled_auth')]
if "pytest" in sys.modules or SKIP_AUTH == "true":
ENABLED_AUTH_STRATEGIES = []

Expand All @@ -30,7 +30,7 @@
ENABLED_AUTH_STRATEGY_MAPPING = {cls.NAME: cls() for cls in ENABLED_AUTH_STRATEGIES}

# Token to authorize migration requests
MIGRATE_TOKEN = Settings().database.migrate_token
MIGRATE_TOKEN = Settings().get('database.migrate_token')

security = HTTPBearer()

Expand Down
8 changes: 4 additions & 4 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ModelDeploymentName(StrEnum):
SingleContainer = "Single Container"


use_community_features = Settings().feature_flags.use_community_features
use_community_features = Settings().get('feature_flags.use_community_features')

# TODO names in the map below should not be the display names but ids
ALL_MODEL_DEPLOYMENTS = {
Expand Down Expand Up @@ -90,12 +90,12 @@ def get_available_deployments() -> dict[ModelDeploymentName, Deployment]:
event="[Deployments] No available community deployments have been configured"
)

deployments = Settings().deployments.enabled_deployments
deployments = Settings().get('deployments.enabled_deployments')
if deployments is not None and len(deployments) > 0:
return {
key: value
for key, value in ALL_MODEL_DEPLOYMENTS.items()
if value.id in Settings().deployments.enabled_deployments
if value.id in Settings().get('deployments.enabled_deployments')
}

return ALL_MODEL_DEPLOYMENTS
Expand All @@ -109,7 +109,7 @@ def get_default_deployment(**kwargs) -> BaseDeployment:
fallback = deployment.deployment_class(**kwargs)
break

default = Settings().deployments.default_deployment
default = Settings().get('deployments.default_deployment')
if default:
return next(
(
Expand Down
11 changes: 10 additions & 1 deletion src/backend/config/settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import List, Optional, Tuple, Type
from typing import Any, List, Optional, Tuple, Type

from pydantic import AliasChoices, BaseModel, Field
from pydantic_settings import (
Expand Down Expand Up @@ -362,6 +362,15 @@ class Settings(BaseSettings):
deployments: Optional[DeploymentSettings] = Field(default=DeploymentSettings())
logger: Optional[LoggerSettings] = Field(default=LoggerSettings())

def get(self, path: str) -> Any:
keys = path.split('.')
value = self
for key in keys:
value = getattr(value, key, None)
if value is None:
return None
return value

@classmethod
def settings_customise_sources(
cls,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/config/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_available_tools() -> dict[str, ToolDefinition]:
}

# Handle adding Community-implemented tools
use_community_tools = Settings().feature_flags.use_community_features
use_community_tools = Settings().get('feature_flags.use_community_features')
if use_community_tools:
try:
from community.config.tools import get_community_tools
Expand Down
2 changes: 1 addition & 1 deletion src/backend/database_models/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

load_dotenv()

SQLALCHEMY_DATABASE_URL = Settings().database.url
SQLALCHEMY_DATABASE_URL = Settings().get('database.url')
engine = create_engine(
SQLALCHEMY_DATABASE_URL, pool_size=5, max_overflow=10, pool_timeout=30
)
Expand Down
2 changes: 1 addition & 1 deletion src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def create_app():
dependencies_type = "default"
if is_authentication_enabled():
# Required to save temporary OAuth state in session
auth_secret = Settings().auth.secret_key
auth_secret = Settings().get('auth.secret_key')
app.add_middleware(SessionMiddleware, secret_key=auth_secret)
dependencies_type = "auth"
for router in routers:
Expand Down
2 changes: 1 addition & 1 deletion src/backend/model_deployments/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class AzureDeployment(BaseDeployment):

DEFAULT_MODELS = ["azure-command"]

azure_config = Settings().deployments.azure
azure_config = Settings().get('deployments.azure')
default_api_key = azure_config.api_key
default_chat_endpoint_url = azure_config.endpoint_url

Expand Down
2 changes: 1 addition & 1 deletion src/backend/model_deployments/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
class BedrockDeployment(BaseDeployment):
DEFAULT_MODELS = ["cohere.command-r-plus-v1:0"]

bedrock_config = Settings().deployments.bedrock
bedrock_config = Settings().get('deployments.bedrock')
region_name = bedrock_config.region_name
access_key = bedrock_config.access_key
secret_access_key = bedrock_config.secret_key
Expand Down
2 changes: 1 addition & 1 deletion src/backend/model_deployments/cohere_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class CohereDeployment(BaseDeployment):
"""Cohere Platform Deployment."""

client_name = "cohere-toolkit"
api_key = Settings().deployments.cohere_platform.api_key
api_key = Settings().get('deployments.cohere_platform.api_key')

def __init__(self, **kwargs: Any):
# Override the environment variable from the request
Expand Down
2 changes: 1 addition & 1 deletion src/backend/model_deployments/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class SageMakerDeployment(BaseDeployment):

DEFAULT_MODELS = ["sagemaker-command"]

sagemaker_config = Settings().deployments.sagemaker
sagemaker_config = Settings().get('deployments.sagemaker')
endpoint = sagemaker_config.endpoint_name
region_name = sagemaker_config.region_name
aws_access_key_id = sagemaker_config.access_key
Expand Down
2 changes: 1 addition & 1 deletion src/backend/model_deployments/single_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class SingleContainerDeployment(BaseDeployment):
"""Single Container Deployment."""

client_name = "cohere-toolkit"
config = Settings().deployments.single_container
config = Settings().get('deployments.single_container')
default_url = config.url
default_model = config.model

Expand Down
2 changes: 1 addition & 1 deletion src/backend/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ async def tool_auth(
HTTPException: If no redirect_uri set.
"""
logger = ctx.get_logger()
redirect_uri = Settings().auth.frontend_hostname
redirect_uri = Settings().get('auth.frontend_hostname')

if not redirect_uri:
raise HTTPException(
Expand Down
4 changes: 2 additions & 2 deletions src/backend/routers/experimental_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def list_experimental_features(ctx: Context = Depends(get_context)) -> dict[str,
Dict[str, bool]: Experimental feature and their isEnabled state
"""
experimental_features = {
"USE_AGENTS_VIEW": Settings().feature_flags.use_agents_view,
"USE_TEXT_TO_SPEECH_SYNTHESIS": bool(Settings().google_cloud.api_key),
"USE_AGENTS_VIEW": Settings().get('feature_flags.use_agents_view'),
"USE_TEXT_TO_SPEECH_SYNTHESIS": bool(Settings().get('google_cloud.api_key')),
}
return experimental_features
2 changes: 1 addition & 1 deletion src/backend/routers/scim.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from backend.services.context import get_context

SCIM_PREFIX = "/scim/v2"
scim_auth = Settings().auth.scim
scim_auth = Settings().get('auth.scim')
router = APIRouter(prefix=SCIM_PREFIX)
router.name = RouterName.SCIM

Expand Down
2 changes: 1 addition & 1 deletion src/backend/services/auth/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def get_cipher() -> Fernet:
"""

# 1. Get env var
auth_key = Settings().auth.secret_key
auth_key = Settings().get('auth.secret_key')
# 2. Hash env var using SHA-256
hash_digest = hashlib.sha256(auth_key.encode()).digest()
# 3. Base64 encode hash and get 32-byte key
Expand Down
2 changes: 1 addition & 1 deletion src/backend/services/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class JWTService:
ALGORITHM = "HS256"

def __init__(self):
secret_key = Settings().auth.secret_key
secret_key = Settings().get('auth.secret_key')

if not secret_key:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions src/backend/services/auth/strategies/google_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ class GoogleOAuth(BaseOAuthStrategy):

def __init__(self):
try:
self.settings = Settings().auth.google_oauth
self.settings = Settings().get('auth.google_oauth')
self.REDIRECT_URI = (
f"{Settings().auth.frontend_hostname}/auth/{self.NAME.lower()}"
f"{Settings().get('auth.frontend_hostname')}/auth/{self.NAME.lower()}"
)
self.client = OAuth2Session(
client_id=self.settings.client_id,
Expand Down
4 changes: 2 additions & 2 deletions src/backend/services/auth/strategies/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class OpenIDConnect(BaseOAuthStrategy):

def __init__(self):
try:
self.settings = Settings().auth.oidc
self.settings = Settings().get('auth.oidc')
self.REDIRECT_URI = (
f"{Settings().auth.frontend_hostname}/auth/{self.NAME.lower()}"
f"{Settings().get('auth.frontend_hostname')}/auth/{self.NAME.lower()}"
)
self.WELL_KNOWN_ENDPOINT = self.settings.well_known_endpoint
self.client = OAuth2Session(
Expand Down
2 changes: 1 addition & 1 deletion src/backend/services/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def get_client() -> Redis:
redis_url = Settings().redis.url
redis_url = Settings().get('redis.url')

if not redis_url:
error = "Tried retrieving Redis client but redis.url in configuration.yaml is not set."
Expand Down
6 changes: 3 additions & 3 deletions src/backend/services/logger/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def get_logger(self) -> BaseLogger:
if self.logger is not None:
return self.logger

strategy = Settings().logger.strategy
level = Settings().logger.level
renderer = Settings().logger.renderer
strategy = Settings().get('logger.strategy')
level = Settings().get('logger.level')
renderer = Settings().get('logger.renderer')

if strategy == "structlog":
return StructuredLogging(level, renderer)
Expand Down
2 changes: 1 addition & 1 deletion src/backend/services/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _validate_google_cloud_api_key() -> str:
Raises:
ValueError: If the API key is not found in the settings or is empty.
"""
google_cloud = Settings().google_cloud
google_cloud = Settings().get('google_cloud')

if not google_cloud:
raise ValueError("google_cloud in secrets.yaml is missing.")
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/integration/routers/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_generate_title_error_invalid_model(
# SYNTHESIZE


is_google_cloud_api_key_set = bool(Settings().google_cloud.api_key)
is_google_cloud_api_key_set = bool(Settings().get('google_cloud.api_key'))


@pytest.mark.skipif(not is_google_cloud_api_key_set, reason="Google Cloud API key not set, skipping test")
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/integration/services/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from backend.services.cache import get_client

# skip if redis is not available
is_redis_env_set = Settings().redis.url
is_redis_env_set = Settings().get('redis.url')


@pytest.mark.skipif(not is_redis_env_set, reason="Redis is not set")
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/unit/routers/test_scim.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import backend.crud.user as user_repo
from backend.config import Settings

scim = Settings().auth.scim
scim = Settings().get('auth.scim')
encoded_auth = base64.b64encode(
f"{scim.username}:{scim.password}".encode("utf-8")
).decode("utf-8")
Expand Down
6 changes: 3 additions & 3 deletions src/backend/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ class BaseToolAuthentication(ABC):
"""

def __init__(self, *args, **kwargs):
self.BACKEND_HOST = Settings().auth.backend_hostname
self.FRONTEND_HOST = Settings().auth.frontend_hostname
self.AUTH_SECRET_KEY = Settings().auth.secret_key
self.BACKEND_HOST = Settings().get('auth.backend_hostname')
self.FRONTEND_HOST = Settings().get('auth.frontend_hostname')
self.AUTH_SECRET_KEY = Settings().get('auth.secret_key')

self._post_init_check()

Expand Down
2 changes: 1 addition & 1 deletion src/backend/tools/brave_search/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class BraveWebSearch(BaseTool, WebSearchFilteringMixin):
ID = "brave_web_search"
BRAVE_API_KEY = Settings().tools.brave_web_search.api_key
BRAVE_API_KEY = Settings().get('tools.brave_web_search.api_key')

def __init__(self):
self.client = BraveClient(api_key=self.BRAVE_API_KEY)
Expand Down
4 changes: 2 additions & 2 deletions src/backend/tools/google_drive/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class GoogleDriveAuth(BaseToolAuthentication, ToolAuthenticationCacheMixin):

def __init__(self):
super().__init__()
self.GOOGLE_DRIVE_CLIENT_ID = Settings().tools.google_drive.client_id
self.GOOGLE_DRIVE_CLIENT_SECRET = Settings().tools.google_drive.client_secret
self.GOOGLE_DRIVE_CLIENT_ID = Settings().get('tools.google_drive.client_id')
self.GOOGLE_DRIVE_CLIENT_SECRET = Settings().get('tools.google_drive.client_secret')
self.REDIRECT_URL = f"{self.BACKEND_HOST}/v1/tool/auth"

if (
Expand Down
5 changes: 2 additions & 3 deletions src/backend/tools/google_drive/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ class GoogleDrive(BaseTool):
Tool that searches Google Drive
"""
ID = GOOGLE_DRIVE_TOOL_ID

CLIENT_ID = Settings().tools.google_drive.client_id
CLIENT_SECRET = Settings().tools.google_drive.client_secret
CLIENT_ID = Settings().get('tools.google_drive.client_id')
CLIENT_SECRET = Settings().get('tools.google_drive.client_secret')

@classmethod
def is_available(cls) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions src/backend/tools/google_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

class GoogleWebSearch(BaseTool, WebSearchFilteringMixin):
ID = "google_web_search"
API_KEY = Settings().tools.google_web_search.api_key
CSE_ID = Settings().tools.google_web_search.cse_id
API_KEY = Settings().get('tools.google_web_search.api_key')
CSE_ID = Settings().get('tools.google_web_search.cse_id')

def __init__(self):
self.client = build("customsearch", "v1", developerKey=self.API_KEY)
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tools/hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class HybridWebSearch(BaseTool, WebSearchFilteringMixin):
ID = "hybrid_web_search"
POST_RERANK_MAX_RESULTS = 6
AVAILABLE_WEB_SEARCH_TOOLS = [TavilyWebSearch, GoogleWebSearch, BraveWebSearch]
ENABLED_WEB_SEARCH_TOOLS = Settings().tools.hybrid_web_search.enabled_web_searches
ENABLED_WEB_SEARCH_TOOLS = Settings().get('tools.hybrid_web_search.enabled_web_searches')
WEB_SCRAPE_TOOL = WebScrapeTool

def __init__(self):
Expand Down
3 changes: 1 addition & 2 deletions src/backend/tools/lang_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ class LangChainVectorDBRetriever(BaseTool):
"""
This class retrieves documents from a vector database using the langchain package.
"""

ID = "vector_retriever"
COHERE_API_KEY = Settings().deployments.cohere_platform.api_key
COHERE_API_KEY = Settings().get('deployments.cohere_platform.api_key')

def __init__(self, filepath: str):
self.filepath = filepath
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tools/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class PythonInterpreter(BaseTool):
"""

ID = "toolkit_python_interpreter"
INTERPRETER_URL = Settings().tools.python_interpreter.url
INTERPRETER_URL = Settings().get('tools.python_interpreter.url')

@classmethod
def is_available(cls) -> bool:
Expand Down
9 changes: 3 additions & 6 deletions src/backend/tools/slack/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,9 @@ class SlackAuth(BaseToolAuthentication, ToolAuthenticationCacheMixin):
def __init__(self):
super().__init__()

settings = Settings()
slack_settings = settings.tools.slack if settings.tools and settings.tools.slack else None
self.SLACK_CLIENT_ID = getattr(slack_settings, 'client_id', None)
self.SLACK_CLIENT_SECRET = getattr(slack_settings, 'client_secret', None)
self.USER_SCOPES = getattr(slack_settings, 'user_scopes', None) or self.DEFAULT_USER_SCOPES

self.SLACK_CLIENT_ID = Settings().get('tools.slack.client_id')
self.SLACK_CLIENT_SECRET = Settings().get('tools.slack.client_secret')
self.USER_SCOPES = Settings().get('tools.slack.user_scopes') or self.DEFAULT_USER_SCOPES
self.REDIRECT_URL = f"{self.BACKEND_HOST}/v1/tool/auth"

if any([
Expand Down
Loading

0 comments on commit 4d374d4

Please sign in to comment.