From 0bb1f670987faeab3b163f61d4b36d89560ada28 Mon Sep 17 00:00:00 2001 From: Eugene P <144219719+EugeneLightsOn@users.noreply.github.com> Date: Mon, 4 Nov 2024 18:21:35 +0100 Subject: [PATCH] Config settings improvements (#828) * TLK-1987 Initial commit * TLK-1987 Slack tool config changes are applied * TLK-1987 Slack tool small improvement * TLK-1987 Lint --- src/backend/alembic/env.py | 2 +- src/backend/config/auth.py | 6 +++--- src/backend/config/deployments.py | 8 ++++---- src/backend/config/settings.py | 11 ++++++++++- src/backend/config/tools.py | 4 ++-- src/backend/database_models/database.py | 2 +- src/backend/main.py | 2 +- src/backend/model_deployments/azure.py | 2 +- src/backend/model_deployments/bedrock.py | 2 +- src/backend/model_deployments/cohere_platform.py | 2 +- src/backend/model_deployments/sagemaker.py | 2 +- src/backend/model_deployments/single_container.py | 2 +- src/backend/routers/auth.py | 2 +- src/backend/routers/experimental_features.py | 4 ++-- src/backend/routers/scim.py | 2 +- src/backend/services/auth/crypto.py | 2 +- src/backend/services/auth/jwt.py | 2 +- src/backend/services/auth/strategies/google_oauth.py | 4 ++-- src/backend/services/auth/strategies/oidc.py | 4 ++-- src/backend/services/cache.py | 2 +- src/backend/services/logger/utils.py | 6 +++--- src/backend/services/synthesizer.py | 2 +- .../tests/integration/routers/test_conversation.py | 2 +- src/backend/tests/integration/services/test_cache.py | 2 +- src/backend/tests/unit/routers/test_scim.py | 2 +- src/backend/tools/base.py | 6 +++--- src/backend/tools/brave_search/tool.py | 2 +- src/backend/tools/google_drive/auth.py | 4 ++-- src/backend/tools/google_drive/tool.py | 4 ++-- src/backend/tools/google_search.py | 4 ++-- src/backend/tools/hybrid_search.py | 2 +- src/backend/tools/lang_chain.py | 2 +- src/backend/tools/python_interpreter.py | 2 +- src/backend/tools/slack/auth.py | 9 +++------ src/backend/tools/slack/tool.py | 9 ++------- src/backend/tools/slack/utils.py | 6 +++--- src/backend/tools/tavily_search.py | 2 +- src/community/tools/llama_index.py | 2 +- src/community/tools/wolfram.py | 2 +- 39 files changed, 69 insertions(+), 68 deletions(-) diff --git a/src/backend/alembic/env.py b/src/backend/alembic/env.py index d1f51303ef..ac7d474b61 100644 --- a/src/backend/alembic/env.py +++ b/src/backend/alembic/env.py @@ -16,7 +16,7 @@ config = context.config # Overwrite alembic.file `sqlachemy.url` value -config.set_main_option("sqlalchemy.url", Settings().database.url) +config.set_main_option("sqlalchemy.url", Settings().get('database.url')) # Interpret the config file for Python logging. # This line sets up loggers basically. diff --git a/src/backend/config/auth.py b/src/backend/config/auth.py index 6cbdc96788..4c55d0bbe8 100644 --- a/src/backend/config/auth.py +++ b/src/backend/config/auth.py @@ -19,8 +19,8 @@ SKIP_AUTH = os.getenv("SKIP_AUTH", None) # Ex: [BasicAuthentication] ENABLED_AUTH_STRATEGIES = [] -if ENABLED_AUTH_STRATEGIES == [] and Settings().auth.enabled_auth is not None: - ENABLED_AUTH_STRATEGIES = [auth_map[auth] for auth in Settings().auth.enabled_auth] +if ENABLED_AUTH_STRATEGIES == [] and Settings().get('auth.enabled_auth') is not None: + ENABLED_AUTH_STRATEGIES = [auth_map[auth] for auth in Settings().get('auth.enabled_auth')] if "pytest" in sys.modules or SKIP_AUTH == "true": ENABLED_AUTH_STRATEGIES = [] @@ -30,7 +30,7 @@ ENABLED_AUTH_STRATEGY_MAPPING = {cls.NAME: cls() for cls in ENABLED_AUTH_STRATEGIES} # Token to authorize migration requests -MIGRATE_TOKEN = Settings().database.migrate_token +MIGRATE_TOKEN = Settings().get('database.migrate_token') security = HTTPBearer() diff --git a/src/backend/config/deployments.py b/src/backend/config/deployments.py index 2397ce9eff..32eb8e0e59 100644 --- a/src/backend/config/deployments.py +++ b/src/backend/config/deployments.py @@ -28,7 +28,7 @@ class ModelDeploymentName(StrEnum): SingleContainer = "Single Container" -use_community_features = Settings().feature_flags.use_community_features +use_community_features = Settings().get('feature_flags.use_community_features') # TODO names in the map below should not be the display names but ids ALL_MODEL_DEPLOYMENTS = { @@ -90,12 +90,12 @@ def get_available_deployments() -> dict[ModelDeploymentName, Deployment]: event="[Deployments] No available community deployments have been configured" ) - deployments = Settings().deployments.enabled_deployments + deployments = Settings().get('deployments.enabled_deployments') if deployments is not None and len(deployments) > 0: return { key: value for key, value in ALL_MODEL_DEPLOYMENTS.items() - if value.id in Settings().deployments.enabled_deployments + if value.id in Settings().get('deployments.enabled_deployments') } return ALL_MODEL_DEPLOYMENTS @@ -109,7 +109,7 @@ def get_default_deployment(**kwargs) -> BaseDeployment: fallback = deployment.deployment_class(**kwargs) break - default = Settings().deployments.default_deployment + default = Settings().get('deployments.default_deployment') if default: return next( ( diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index 1d14053987..a60be0942b 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -1,5 +1,5 @@ import sys -from typing import List, Optional, Tuple, Type +from typing import Any, List, Optional, Tuple, Type from pydantic import AliasChoices, BaseModel, Field from pydantic_settings import ( @@ -363,6 +363,15 @@ class Settings(BaseSettings): deployments: Optional[DeploymentSettings] = Field(default=DeploymentSettings()) logger: Optional[LoggerSettings] = Field(default=LoggerSettings()) + def get(self, path: str) -> Any: + keys = path.split('.') + value = self + for key in keys: + value = getattr(value, key, None) + if value is None: + return None + return value + @classmethod def settings_customise_sources( cls, diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index eef89f556b..6a5d7a13b4 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -263,7 +263,7 @@ class ToolName(StrEnum): def get_available_tools() -> dict[ToolName, dict]: - use_community_tools = Settings().feature_flags.use_community_features + use_community_tools = Settings().get('feature_flags.use_community_features') tools = ALL_TOOLS.copy() if use_community_tools: @@ -283,7 +283,7 @@ def get_available_tools() -> dict[ToolName, dict]: # Retrieve name tool.name = tool.implementation.NAME - enabled_tools = Settings().tools.enabled_tools + enabled_tools = Settings().get('tools.enabled_tools') if enabled_tools is not None and len(enabled_tools) > 0: tools = {key: value for key, value in tools.items() if key in enabled_tools} return tools diff --git a/src/backend/database_models/database.py b/src/backend/database_models/database.py index 7f6ec597f0..937fa822e5 100644 --- a/src/backend/database_models/database.py +++ b/src/backend/database_models/database.py @@ -10,7 +10,7 @@ load_dotenv() -SQLALCHEMY_DATABASE_URL = Settings().database.url +SQLALCHEMY_DATABASE_URL = Settings().get('database.url') engine = create_engine( SQLALCHEMY_DATABASE_URL, pool_size=5, max_overflow=10, pool_timeout=30 ) diff --git a/src/backend/main.py b/src/backend/main.py index 9b79139eb4..9569cde052 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -58,7 +58,7 @@ def create_app(): dependencies_type = "default" if is_authentication_enabled(): # Required to save temporary OAuth state in session - auth_secret = Settings().auth.secret_key + auth_secret = Settings().get('auth.secret_key') app.add_middleware(SessionMiddleware, secret_key=auth_secret) dependencies_type = "auth" for router in routers: diff --git a/src/backend/model_deployments/azure.py b/src/backend/model_deployments/azure.py index 4c373087f3..e7849f0371 100644 --- a/src/backend/model_deployments/azure.py +++ b/src/backend/model_deployments/azure.py @@ -25,7 +25,7 @@ class AzureDeployment(BaseDeployment): DEFAULT_MODELS = ["azure-command"] - azure_config = Settings().deployments.azure + azure_config = Settings().get('deployments.azure') default_api_key = azure_config.api_key default_chat_endpoint_url = azure_config.endpoint_url diff --git a/src/backend/model_deployments/bedrock.py b/src/backend/model_deployments/bedrock.py index fa3eb5613b..094ed243a3 100644 --- a/src/backend/model_deployments/bedrock.py +++ b/src/backend/model_deployments/bedrock.py @@ -24,7 +24,7 @@ class BedrockDeployment(BaseDeployment): DEFAULT_MODELS = ["cohere.command-r-plus-v1:0"] - bedrock_config = Settings().deployments.bedrock + bedrock_config = Settings().get('deployments.bedrock') region_name = bedrock_config.region_name access_key = bedrock_config.access_key secret_access_key = bedrock_config.secret_key diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index 9f6b042e08..f8da71693d 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -20,7 +20,7 @@ class CohereDeployment(BaseDeployment): """Cohere Platform Deployment.""" client_name = "cohere-toolkit" - api_key = Settings().deployments.cohere_platform.api_key + api_key = Settings().get('deployments.cohere_platform.api_key') def __init__(self, **kwargs: Any): # Override the environment variable from the request diff --git a/src/backend/model_deployments/sagemaker.py b/src/backend/model_deployments/sagemaker.py index 5eafbd763a..56d2a96555 100644 --- a/src/backend/model_deployments/sagemaker.py +++ b/src/backend/model_deployments/sagemaker.py @@ -33,7 +33,7 @@ class SageMakerDeployment(BaseDeployment): DEFAULT_MODELS = ["sagemaker-command"] - sagemaker_config = Settings().deployments.sagemaker + sagemaker_config = Settings().get('deployments.sagemaker') endpoint = sagemaker_config.endpoint_name region_name = sagemaker_config.region_name aws_access_key_id = sagemaker_config.access_key diff --git a/src/backend/model_deployments/single_container.py b/src/backend/model_deployments/single_container.py index 2cfc36cd31..9c727a2186 100644 --- a/src/backend/model_deployments/single_container.py +++ b/src/backend/model_deployments/single_container.py @@ -19,7 +19,7 @@ class SingleContainerDeployment(BaseDeployment): """Single Container Deployment.""" client_name = "cohere-toolkit" - config = Settings().deployments.single_container + config = Settings().get('deployments.single_container') default_url = config.url default_model = config.model diff --git a/src/backend/routers/auth.py b/src/backend/routers/auth.py index 197bdbfcb5..b726fef898 100644 --- a/src/backend/routers/auth.py +++ b/src/backend/routers/auth.py @@ -259,7 +259,7 @@ async def tool_auth( HTTPException: If no redirect_uri set. """ logger = ctx.get_logger() - redirect_uri = Settings().auth.frontend_hostname + redirect_uri = Settings().get('auth.frontend_hostname') if not redirect_uri: raise HTTPException( diff --git a/src/backend/routers/experimental_features.py b/src/backend/routers/experimental_features.py index 55f6450efe..944a34684d 100644 --- a/src/backend/routers/experimental_features.py +++ b/src/backend/routers/experimental_features.py @@ -23,7 +23,7 @@ def list_experimental_features(ctx: Context = Depends(get_context)) -> dict[str, Dict[str, bool]: Experimental feature and their isEnabled state """ experimental_features = { - "USE_AGENTS_VIEW": Settings().feature_flags.use_agents_view, - "USE_TEXT_TO_SPEECH_SYNTHESIS": bool(Settings().google_cloud.api_key), + "USE_AGENTS_VIEW": Settings().get('feature_flags.use_agents_view'), + "USE_TEXT_TO_SPEECH_SYNTHESIS": bool(Settings().get('google_cloud.api_key')), } return experimental_features diff --git a/src/backend/routers/scim.py b/src/backend/routers/scim.py index 09a0504fe2..1e03dd0b30 100644 --- a/src/backend/routers/scim.py +++ b/src/backend/routers/scim.py @@ -26,7 +26,7 @@ from backend.services.context import get_context SCIM_PREFIX = "/scim/v2" -scim_auth = Settings().auth.scim +scim_auth = Settings().get('auth.scim') router = APIRouter(prefix=SCIM_PREFIX) router.name = RouterName.SCIM diff --git a/src/backend/services/auth/crypto.py b/src/backend/services/auth/crypto.py index 5523aac528..97ed037ef1 100644 --- a/src/backend/services/auth/crypto.py +++ b/src/backend/services/auth/crypto.py @@ -12,7 +12,7 @@ def get_cipher() -> Fernet: """ # 1. Get env var - auth_key = Settings().auth.secret_key + auth_key = Settings().get('auth.secret_key') # 2. Hash env var using SHA-256 hash_digest = hashlib.sha256(auth_key.encode()).digest() # 3. Base64 encode hash and get 32-byte key diff --git a/src/backend/services/auth/jwt.py b/src/backend/services/auth/jwt.py index ece3bbc5d7..b7346faa94 100644 --- a/src/backend/services/auth/jwt.py +++ b/src/backend/services/auth/jwt.py @@ -15,7 +15,7 @@ class JWTService: ALGORITHM = "HS256" def __init__(self): - secret_key = Settings().auth.secret_key + secret_key = Settings().get('auth.secret_key') if not secret_key: raise ValueError( diff --git a/src/backend/services/auth/strategies/google_oauth.py b/src/backend/services/auth/strategies/google_oauth.py index 0f14515054..50b073bf1e 100644 --- a/src/backend/services/auth/strategies/google_oauth.py +++ b/src/backend/services/auth/strategies/google_oauth.py @@ -19,9 +19,9 @@ class GoogleOAuth(BaseOAuthStrategy): def __init__(self): try: - self.settings = Settings().auth.google_oauth + self.settings = Settings().get('auth.google_oauth') self.REDIRECT_URI = ( - f"{Settings().auth.frontend_hostname}/auth/{self.NAME.lower()}" + f"{Settings().get('auth.frontend_hostname')}/auth/{self.NAME.lower()}" ) self.client = OAuth2Session( client_id=self.settings.client_id, diff --git a/src/backend/services/auth/strategies/oidc.py b/src/backend/services/auth/strategies/oidc.py index 3a3fc2d237..b7db636214 100644 --- a/src/backend/services/auth/strategies/oidc.py +++ b/src/backend/services/auth/strategies/oidc.py @@ -20,9 +20,9 @@ class OpenIDConnect(BaseOAuthStrategy): def __init__(self): try: - self.settings = Settings().auth.oidc + self.settings = Settings().get('auth.oidc') self.REDIRECT_URI = ( - f"{Settings().auth.frontend_hostname}/auth/{self.NAME.lower()}" + f"{Settings().get('auth.frontend_hostname')}/auth/{self.NAME.lower()}" ) self.WELL_KNOWN_ENDPOINT = self.settings.well_known_endpoint self.client = OAuth2Session( diff --git a/src/backend/services/cache.py b/src/backend/services/cache.py index b24938976b..698598caef 100644 --- a/src/backend/services/cache.py +++ b/src/backend/services/cache.py @@ -9,7 +9,7 @@ def get_client() -> Redis: - redis_url = Settings().redis.url + redis_url = Settings().get('redis.url') if not redis_url: error = "Tried retrieving Redis client but redis.url in configuration.yaml is not set." diff --git a/src/backend/services/logger/utils.py b/src/backend/services/logger/utils.py index 49580caa4f..dc69af7fd7 100644 --- a/src/backend/services/logger/utils.py +++ b/src/backend/services/logger/utils.py @@ -11,9 +11,9 @@ def get_logger(self) -> BaseLogger: if self.logger is not None: return self.logger - strategy = Settings().logger.strategy - level = Settings().logger.level - renderer = Settings().logger.renderer + strategy = Settings().get('logger.strategy') + level = Settings().get('logger.level') + renderer = Settings().get('logger.renderer') if strategy == "structlog": return StructuredLogging(level, renderer) diff --git a/src/backend/services/synthesizer.py b/src/backend/services/synthesizer.py index 04c376b778..0c320d1af7 100644 --- a/src/backend/services/synthesizer.py +++ b/src/backend/services/synthesizer.py @@ -68,7 +68,7 @@ def _validate_google_cloud_api_key() -> str: Raises: ValueError: If the API key is not found in the settings or is empty. """ - google_cloud = Settings().google_cloud + google_cloud = Settings().get('google_cloud') if not google_cloud: raise ValueError("google_cloud in secrets.yaml is missing.") diff --git a/src/backend/tests/integration/routers/test_conversation.py b/src/backend/tests/integration/routers/test_conversation.py index 7b330a819c..7d48fc4305 100644 --- a/src/backend/tests/integration/routers/test_conversation.py +++ b/src/backend/tests/integration/routers/test_conversation.py @@ -163,7 +163,7 @@ def test_generate_title_error_invalid_model( # SYNTHESIZE -is_google_cloud_api_key_set = bool(Settings().google_cloud.api_key) +is_google_cloud_api_key_set = bool(Settings().get('google_cloud.api_key')) @pytest.mark.skipif(not is_google_cloud_api_key_set, reason="Google Cloud API key not set, skipping test") diff --git a/src/backend/tests/integration/services/test_cache.py b/src/backend/tests/integration/services/test_cache.py index 87a4e620e1..0f9f409fca 100644 --- a/src/backend/tests/integration/services/test_cache.py +++ b/src/backend/tests/integration/services/test_cache.py @@ -4,7 +4,7 @@ from backend.services.cache import get_client # skip if redis is not available -is_redis_env_set = Settings().redis.url +is_redis_env_set = Settings().get('redis.url') @pytest.mark.skipif(not is_redis_env_set, reason="Redis is not set") diff --git a/src/backend/tests/unit/routers/test_scim.py b/src/backend/tests/unit/routers/test_scim.py index 3633ea9d41..c2212b606b 100644 --- a/src/backend/tests/unit/routers/test_scim.py +++ b/src/backend/tests/unit/routers/test_scim.py @@ -7,7 +7,7 @@ import backend.crud.user as user_repo from backend.config import Settings -scim = Settings().auth.scim +scim = Settings().get('auth.scim') encoded_auth = base64.b64encode( f"{scim.username}:{scim.password}".encode("utf-8") ).decode("utf-8") diff --git a/src/backend/tools/base.py b/src/backend/tools/base.py index f396360678..aa66fcd2c9 100644 --- a/src/backend/tools/base.py +++ b/src/backend/tools/base.py @@ -55,9 +55,9 @@ class BaseToolAuthentication: """ def __init__(self, *args, **kwargs): - self.BACKEND_HOST = Settings().auth.backend_hostname - self.FRONTEND_HOST = Settings().auth.frontend_hostname - self.AUTH_SECRET_KEY = Settings().auth.secret_key + self.BACKEND_HOST = Settings().get('auth.backend_hostname') + self.FRONTEND_HOST = Settings().get('auth.frontend_hostname') + self.AUTH_SECRET_KEY = Settings().get('auth.secret_key') self._post_init_check() diff --git a/src/backend/tools/brave_search/tool.py b/src/backend/tools/brave_search/tool.py index 274c966fa8..85899b6a9d 100644 --- a/src/backend/tools/brave_search/tool.py +++ b/src/backend/tools/brave_search/tool.py @@ -11,7 +11,7 @@ class BraveWebSearch(BaseTool, WebSearchFilteringMixin): NAME = "brave_web_search" - BRAVE_API_KEY = Settings().tools.brave_web_search.api_key + BRAVE_API_KEY = Settings().get('tools.brave_web_search.api_key') def __init__(self): self.client = BraveClient(api_key=self.BRAVE_API_KEY) diff --git a/src/backend/tools/google_drive/auth.py b/src/backend/tools/google_drive/auth.py index 7e451c36ad..85e9c1a7ac 100644 --- a/src/backend/tools/google_drive/auth.py +++ b/src/backend/tools/google_drive/auth.py @@ -28,8 +28,8 @@ class GoogleDriveAuth(BaseToolAuthentication, ToolAuthenticationCacheMixin): def __init__(self): super().__init__() - self.GOOGLE_DRIVE_CLIENT_ID = Settings().tools.google_drive.client_id - self.GOOGLE_DRIVE_CLIENT_SECRET = Settings().tools.google_drive.client_secret + self.GOOGLE_DRIVE_CLIENT_ID = Settings().get('tools.google_drive.client_id') + self.GOOGLE_DRIVE_CLIENT_SECRET = Settings().get('tools.google_drive.client_secret') self.REDIRECT_URL = f"{self.BACKEND_HOST}/v1/tool/auth" if ( diff --git a/src/backend/tools/google_drive/tool.py b/src/backend/tools/google_drive/tool.py index cae3b54feb..3691b75b56 100644 --- a/src/backend/tools/google_drive/tool.py +++ b/src/backend/tools/google_drive/tool.py @@ -26,8 +26,8 @@ class GoogleDrive(BaseTool): NAME = GOOGLE_DRIVE_TOOL_ID - CLIENT_ID = Settings().tools.google_drive.client_id - CLIENT_SECRET = Settings().tools.google_drive.client_secret + CLIENT_ID = Settings().get('tools.google_drive.client_id') + CLIENT_SECRET = Settings().get('tools.google_drive.client_secret') @classmethod def is_available(cls) -> bool: diff --git a/src/backend/tools/google_search.py b/src/backend/tools/google_search.py index 02d3134f14..c8df4216e6 100644 --- a/src/backend/tools/google_search.py +++ b/src/backend/tools/google_search.py @@ -11,8 +11,8 @@ class GoogleWebSearch(BaseTool, WebSearchFilteringMixin): NAME = "google_web_search" - API_KEY = Settings().tools.google_web_search.api_key - CSE_ID = Settings().tools.google_web_search.cse_id + API_KEY = Settings().get('tools.google_web_search.api_key') + CSE_ID = Settings().get('tools.google_web_search.cse_id') def __init__(self): self.client = build("customsearch", "v1", developerKey=self.API_KEY) diff --git a/src/backend/tools/hybrid_search.py b/src/backend/tools/hybrid_search.py index 1b9963ec17..8af1e98cc3 100644 --- a/src/backend/tools/hybrid_search.py +++ b/src/backend/tools/hybrid_search.py @@ -18,7 +18,7 @@ class HybridWebSearch(BaseTool, WebSearchFilteringMixin): NAME = "hybrid_web_search" POST_RERANK_MAX_RESULTS = 6 AVAILABLE_WEB_SEARCH_TOOLS = [TavilyWebSearch, GoogleWebSearch, BraveWebSearch] - ENABLED_WEB_SEARCH_TOOLS = Settings().tools.hybrid_web_search.enabled_web_searches + ENABLED_WEB_SEARCH_TOOLS = Settings().get('tools.hybrid_web_search.enabled_web_searches') WEB_SCRAPE_TOOL = WebScrapeTool def __init__(self): diff --git a/src/backend/tools/lang_chain.py b/src/backend/tools/lang_chain.py index 597f5f8ddd..9dd64f8eec 100644 --- a/src/backend/tools/lang_chain.py +++ b/src/backend/tools/lang_chain.py @@ -60,7 +60,7 @@ class LangChainVectorDBRetriever(BaseTool): """ NAME = "vector_retriever" - COHERE_API_KEY = Settings().deployments.cohere_platform.api_key + COHERE_API_KEY = Settings().get('deployments.cohere_platform.api_key') def __init__(self, filepath: str): self.filepath = filepath diff --git a/src/backend/tools/python_interpreter.py b/src/backend/tools/python_interpreter.py index b60067d3c8..3ebc664124 100644 --- a/src/backend/tools/python_interpreter.py +++ b/src/backend/tools/python_interpreter.py @@ -17,7 +17,7 @@ class PythonInterpreter(BaseTool): """ NAME = "toolkit_python_interpreter" - INTERPRETER_URL = Settings().tools.python_interpreter.url + INTERPRETER_URL = Settings().get('tools.python_interpreter.url') @classmethod def is_available(cls) -> bool: diff --git a/src/backend/tools/slack/auth.py b/src/backend/tools/slack/auth.py index c01cd3b4f6..fbe254b978 100644 --- a/src/backend/tools/slack/auth.py +++ b/src/backend/tools/slack/auth.py @@ -31,12 +31,9 @@ class SlackAuth(BaseToolAuthentication, ToolAuthenticationCacheMixin): def __init__(self): super().__init__() - settings = Settings() - slack_settings = settings.tools.slack if settings.tools and settings.tools.slack else None - self.SLACK_CLIENT_ID = getattr(slack_settings, 'client_id', None) - self.SLACK_CLIENT_SECRET = getattr(slack_settings, 'client_secret', None) - self.USER_SCOPES = getattr(slack_settings, 'user_scopes', None) or self.DEFAULT_USER_SCOPES - + self.SLACK_CLIENT_ID = Settings().get('tools.slack.client_id') + self.SLACK_CLIENT_SECRET = Settings().get('tools.slack.client_secret') + self.USER_SCOPES = Settings().get('tools.slack.user_scopes') or self.DEFAULT_USER_SCOPES self.REDIRECT_URL = f"{self.BACKEND_HOST}/v1/tool/auth" if any([ diff --git a/src/backend/tools/slack/tool.py b/src/backend/tools/slack/tool.py index 2abbe0cb2a..c1adee118e 100644 --- a/src/backend/tools/slack/tool.py +++ b/src/backend/tools/slack/tool.py @@ -16,16 +16,11 @@ class SlackTool(BaseTool): """ NAME = SLACK_TOOL_ID - CLIENT_ID = "" - CLIENT_SECRET = "" + CLIENT_ID = Settings().get('tools.slack.client_id') + CLIENT_SECRET = Settings().get('tools.slack.client_secret') @classmethod def is_available(cls) -> bool: - settings = Settings() - slack_settings = settings.tools.slack if settings.tools and settings.tools.slack else None - cls.CLIENT_ID = getattr(slack_settings, 'client_id', None) - cls.CLIENT_SECRET = getattr(slack_settings, 'client_secret', None) - return cls.CLIENT_ID is not None and cls.CLIENT_SECRET is not None @classmethod diff --git a/src/backend/tools/slack/utils.py b/src/backend/tools/slack/utils.py index 716e1bb1e0..f279ee7d11 100644 --- a/src/backend/tools/slack/utils.py +++ b/src/backend/tools/slack/utils.py @@ -24,7 +24,7 @@ def serialize_results(self, response): document = self.extract_message_data(match) results.append(document) for match in response["files"]["matches"]: - document = self.extract_files_data(match) + document = self.extract_files_data(match, response["query"]) results.append(document) return results @@ -43,14 +43,14 @@ def extract_message_data(message_json): return document @staticmethod - def extract_files_data(message_json): + def extract_files_data(message_json, query=""): document = {} document["type"] = "file" if "permalink" in message_json: document["url"] = str(message_json.pop("permalink")) if "title" in message_json: document["title"] = str(message_json["title"]) - document["text"] = str(message_json["title"]) + document["text"] = f"{query} in {str(message_json['title'])}" return document diff --git a/src/backend/tools/tavily_search.py b/src/backend/tools/tavily_search.py index 3719de8999..abf30db883 100644 --- a/src/backend/tools/tavily_search.py +++ b/src/backend/tools/tavily_search.py @@ -12,7 +12,7 @@ class TavilyWebSearch(BaseTool, WebSearchFilteringMixin): NAME = "tavily_web_search" - TAVILY_API_KEY = Settings().tools.tavily_web_search.api_key + TAVILY_API_KEY = Settings().get('tools.tavily_web_search.api_key') POST_RERANK_MAX_RESULTS = 6 def __init__(self): diff --git a/src/community/tools/llama_index.py b/src/community/tools/llama_index.py index c75aad4bd7..6cdef8da4c 100644 --- a/src/community/tools/llama_index.py +++ b/src/community/tools/llama_index.py @@ -29,7 +29,7 @@ class LlamaIndexUploadPDFRetriever(BaseTool): CHUNK_SIZE = 512 def __init__(self): - self.COHERE_API_KEY = Settings().deployments.cohere_platform.api_key + self.COHERE_API_KEY = Settings().get('deployments.cohere_platform.api_key') def _get_embedding(self, embed_type): diff --git a/src/community/tools/wolfram.py b/src/community/tools/wolfram.py index a88308eda5..9fc022ebab 100644 --- a/src/community/tools/wolfram.py +++ b/src/community/tools/wolfram.py @@ -15,7 +15,7 @@ class WolframAlpha(BaseTool): NAME = "wolfram_alpha" - wolfram_app_id = Settings().tools.wolfram_alpha.app_id + wolfram_app_id = Settings().get('tools.wolfram_alpha.app_id') def __init__(self): self.app_id = self.wolfram_app_id