diff --git a/argilla-server/src/argilla_server/_app.py b/argilla-server/src/argilla_server/_app.py index a844abde70..5a58596096 100644 --- a/argilla-server/src/argilla_server/_app.py +++ b/argilla-server/src/argilla_server/_app.py @@ -66,7 +66,7 @@ async def redirect_api(): for app_configure in [ configure_app_logging, configure_database, - ping_search_engine, + configure_search_engine, configure_telemetry, configure_middleware, configure_app_security, @@ -148,18 +148,35 @@ def _create_statics_folder(path_from): ) -def ping_search_engine(app: FastAPI): +def configure_search_engine(app: FastAPI): + @app.on_event("startup") + async def configure_elasticsearch(): + if not settings.search_engine_is_elasticsearch: + return + + logging.getLogger("elasticsearch").setLevel(logging.ERROR) + logging.getLogger("elastic_transport").setLevel(logging.ERROR) + + @app.on_event("startup") + async def configure_opensearch(): + if not settings.search_engine_is_opensearch: + return + + logging.getLogger("opensearch").setLevel(logging.ERROR) + logging.getLogger("opensearch_transport").setLevel(logging.ERROR) + @app.on_event("startup") @backoff.on_exception(backoff.expo, ConnectionError, max_time=60) - async def _ping_search_engine(): + async def ping_search_engine(): async for search_engine in get_search_engine(): if not await search_engine.ping(): raise ConnectionError( - f"Your Elasticsearch endpoint at {settings.obfuscated_elasticsearch()} is not available or not responding.\n" - "Please make sure your Elasticsearch instance is launched and correctly running and\n" + f"Your {settings.search_engine} endpoint at {settings.obfuscated_elasticsearch()} is not available or not responding.\n" + f"Please make sure your {settings.search_engine} instance is launched and correctly running and\n" "you have the necessary access permissions. Once you have verified this, restart the argilla server.\n" ) + def configure_app_security(app: FastAPI): auth.configure_app(app) diff --git a/argilla-server/src/argilla_server/constants.py b/argilla-server/src/argilla_server/constants.py index 176d8da54a..6bda197c27 100644 --- a/argilla-server/src/argilla_server/constants.py +++ b/argilla-server/src/argilla_server/constants.py @@ -15,6 +15,9 @@ API_KEY_HEADER_NAME = "X-Argilla-Api-Key" WORKSPACE_HEADER_NAME = "X-Argilla-Workspace" +SEARCH_ENGINE_ELASTICSEARCH = "elasticsearch" +SEARCH_ENGINE_OPENSEARCH = "opensearch" + DEFAULT_USERNAME = "argilla" DEFAULT_PASSWORD = "1234" DEFAULT_API_KEY = "argilla.apikey" diff --git a/argilla-server/src/argilla_server/search_engine/elasticsearch.py b/argilla-server/src/argilla_server/search_engine/elasticsearch.py index 2e1b47f7f4..b5bd7cef97 100644 --- a/argilla-server/src/argilla_server/search_engine/elasticsearch.py +++ b/argilla-server/src/argilla_server/search_engine/elasticsearch.py @@ -18,6 +18,7 @@ from elasticsearch8 import AsyncElasticsearch, helpers +from argilla_server.constants import SEARCH_ENGINE_ELASTICSEARCH from argilla_server.models import VectorSettings from argilla_server.search_engine import SearchEngine from argilla_server.search_engine.commons import ( @@ -37,7 +38,7 @@ def _compute_num_candidates_from_k(k: int) -> int: return 2000 -@SearchEngine.register(engine_name="elasticsearch") +@SearchEngine.register(engine_name=SEARCH_ENGINE_ELASTICSEARCH) @dataclasses.dataclass class ElasticSearchEngine(BaseElasticAndOpenSearchEngine): config: Dict[str, Any] = dataclasses.field(default_factory=dict) diff --git a/argilla-server/src/argilla_server/search_engine/opensearch.py b/argilla-server/src/argilla_server/search_engine/opensearch.py index 927ce1cb0f..e4d4860471 100644 --- a/argilla-server/src/argilla_server/search_engine/opensearch.py +++ b/argilla-server/src/argilla_server/search_engine/opensearch.py @@ -18,6 +18,7 @@ from opensearchpy import AsyncOpenSearch, helpers +from argilla_server.constants import SEARCH_ENGINE_OPENSEARCH from argilla_server.models import VectorSettings from argilla_server.search_engine.base import SearchEngine from argilla_server.search_engine.commons import ( @@ -29,7 +30,7 @@ from argilla_server.settings import settings -@SearchEngine.register(engine_name="opensearch") +@SearchEngine.register(engine_name=SEARCH_ENGINE_OPENSEARCH) @dataclasses.dataclass class OpenSearchEngine(BaseElasticAndOpenSearchEngine): config: Dict[str, Any] = dataclasses.field(default_factory=dict) diff --git a/argilla-server/src/argilla_server/settings.py b/argilla-server/src/argilla_server/settings.py index e558b09467..32ba6c4778 100644 --- a/argilla-server/src/argilla_server/settings.py +++ b/argilla-server/src/argilla_server/settings.py @@ -29,6 +29,8 @@ DEFAULT_MAX_KEYWORD_LENGTH, DEFAULT_SPAN_OPTIONS_MAX_ITEMS, DEFAULT_TELEMETRY_KEY, + SEARCH_ENGINE_ELASTICSEARCH, + SEARCH_ENGINE_OPENSEARCH, ) from argilla_server.pydantic_v1 import BaseSettings, Field, root_validator, validator @@ -97,7 +99,7 @@ class Settings(BaseSettings): es_mapping_total_fields_limit: int = 2000 - search_engine: str = "elasticsearch" + search_engine: str = SEARCH_ENGINE_ELASTICSEARCH vectors_fields_limit: int = Field( default=5, @@ -217,6 +219,14 @@ def old_dataset_records_index_name(self) -> str: return index_name.replace("", "") return index_name.replace("", f".{ns}") + @property + def search_engine_is_elasticsearch(self) -> bool: + return self.search_engine == SEARCH_ENGINE_ELASTICSEARCH + + @property + def search_engine_is_opensearch(self) -> bool: + return self.search_engine == SEARCH_ENGINE_OPENSEARCH + def obfuscated_elasticsearch(self) -> str: """Returns configured elasticsearch url obfuscating the provided password, if any""" parsed = urlparse(self.elasticsearch)