From 6c888abd5e209839344de7c17f703bb6cd6059fb Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Fri, 12 Apr 2024 15:41:41 -0700 Subject: [PATCH] Add some tests for `ext::ai` --- edb/testbase/http.py | 26 ++++-- edb/testbase/server.py | 18 ++-- tests/schemas/ext_ai.esdl | 34 +++++++ tests/test_ext_ai.py | 178 ++++++++++++++++++++++++++++++++++++ tests/test_http_ext_auth.py | 3 +- 5 files changed, 246 insertions(+), 13 deletions(-) create mode 100644 tests/schemas/ext_ai.esdl create mode 100644 tests/test_ext_ai.py diff --git a/edb/testbase/http.py b/edb/testbase/http.py index 900acf2df484..6b50c23283e6 100644 --- a/edb/testbase/http.py +++ b/edb/testbase/http.py @@ -22,6 +22,7 @@ Any, Callable, Optional, + Type, ) import http.server @@ -312,27 +313,39 @@ def assert_graphql_query_result( class MockHttpServerHandler(http.server.BaseHTTPRequestHandler): + def get_server_and_path(self) -> tuple[str, str]: + server = f'http://{self.headers.get("Host")}' + return server, self.path + def do_GET(self): self.close_connection = False - server, path = self.path.lstrip('/').split('/', 1) - server = urllib.parse.unquote(server) + server, path = self.get_server_and_path() self.server.owner.handle_request('GET', server, path, self) def do_POST(self): self.close_connection = False - server, path = self.path.lstrip('/').split('/', 1) - server = urllib.parse.unquote(server) + server, path = self.get_server_and_path() self.server.owner.handle_request('POST', server, path, self) def log_message(self, *args): pass +class MultiHostMockHttpServerHandler(MockHttpServerHandler): + def get_server_and_path(self) -> tuple[str, str]: + server, path = self.path.lstrip('/').split('/', 1) + server = urllib.parse.unquote(server) + return server, path + + ResponseType = tuple[str, int] | tuple[str, int, dict[str, str]] class MockHttpServer: - def __init__(self) -> None: + def __init__( + self, + handler_type: Type[MockHttpServerHandler] = MockHttpServerHandler, + ) -> None: self.has_started = 
threading.Event() self.routes: dict[ tuple[str, str, str], @@ -340,6 +353,7 @@ def __init__(self) -> None: ] = {} self.requests: dict[tuple[str, str, str], list[dict[str, Any]]] = {} self.url: Optional[str] = None + self.handler_type = handler_type def get_base_url(self) -> str: if self.url is None: @@ -466,7 +480,7 @@ def __enter__(self): def _http_worker(self): self._http_server = http.server.HTTPServer( - ('localhost', 0), MockHttpServerHandler + ('localhost', 0), self.handler_type ) self._http_server.owner = self self._address = self._http_server.server_address diff --git a/edb/testbase/server.py b/edb/testbase/server.py index ce0e77f02945..81e0a152e560 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -1127,9 +1127,9 @@ def look(obj): class DatabaseTestCase(ConnectedTestCase): - SETUP: Optional[Union[str, List[str]]] = None + SETUP: Optional[str | pathlib.Path | list[str] | list[pathlib.Path]] = None TEARDOWN: Optional[str] = None - SCHEMA: Optional[Union[str, pathlib.Path]] = None + SCHEMA: Optional[str | pathlib.Path] = None DEFAULT_MODULE: str = 'default' EXTENSIONS: List[str] = [] @@ -1300,13 +1300,19 @@ def get_setup_script(cls): for scr in scripts: has_nontrivial_script = True - if '\n' not in scr and os.path.exists(scr): + is_path = ( + isinstance(scr, pathlib.Path) + or '\n' not in scr and os.path.exists(scr) + ) + + if is_path: with open(scr, 'rt') as f: - setup = f.read() + setup_text = f.read() else: - setup = scr + assert isinstance(scr, str) + setup_text = scr - script += '\n' + setup + script += '\n' + setup_text # If the SETUP script did a SET MODULE, make sure it is cleared # (since in some modes we keep using the same connection) diff --git a/tests/schemas/ext_ai.esdl b/tests/schemas/ext_ai.esdl new file mode 100644 index 000000000000..ebaf390432b3 --- /dev/null +++ b/tests/schemas/ext_ai.esdl @@ -0,0 +1,34 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2023-present MagicStack Inc. 
and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +type TestEmbeddingModel + extending ext::ai::EmbeddingModel +{ + annotation ext::ai::model_name := "text-embedding-test"; + annotation ext::ai::model_provider := "custom::test"; + annotation ext::ai::embedding_model_max_input_tokens := "8191"; + annotation ext::ai::embedding_model_max_output_dimensions := "10"; + annotation ext::ai::embedding_model_supports_shortening := "true"; +}; + +type Astronomy { + content: str; + deferred index ext::ai::index(embedding_model := 'text-embedding-test') + on (.content); +}; diff --git a/tests/test_ext_ai.py b/tests/test_ext_ai.py new file mode 100644 index 000000000000..330c3f42e27c --- /dev/null +++ b/tests/test_ext_ai.py @@ -0,0 +1,178 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2023-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import json +import pathlib + +from edb.testbase import http as tb + + +class TestExtAI(tb.BaseHttpExtensionTest): + EXTENSIONS = ['pgvector', 'ai'] + BACKEND_SUPERUSER = True + TRANSACTION_ISOLATION = False + PARALLELISM_GRANULARITY = 'suite' + + SCHEMA = pathlib.Path(__file__).parent / 'schemas' / 'ext_ai.esdl' + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.mock_server = tb.MockHttpServer() + cls.mock_server.start() + base_url = cls.mock_server.get_base_url().rstrip("/") + + cls.mock_server.register_route_handler( + "POST", + base_url, + "/v1/embeddings", + )(cls.mock_api_embeddings) + + async def _setup(): + await cls.con.execute( + f""" + CONFIGURE CURRENT DATABASE + INSERT ext::ai::CustomProviderConfig {{ + name := 'custom::test', + secret := 'very secret', + api_url := '{base_url}/v1', + api_style := ext::ai::ProviderAPIStyle.OpenAI, + }}; + + CONFIGURE CURRENT DATABASE + SET ext::ai::Config::indexer_naptime := '100ms'; + """, + ) + + await cls._wait_for_db_config('ext::ai::Config::providers') + + cls.loop.run_until_complete(_setup()) + + @classmethod + def tearDownClass(cls): + cls.mock_server.stop() + + @classmethod + def get_setup_script(cls): + res = super().get_setup_script() + + # HACK: As a debugging cycle hack, when RELOAD is true, we reload the + # extension package from the file, so we can test without a bootstrap. 
+ RELOAD = False + + if RELOAD: + root = pathlib.Path(__file__).parent.parent + with open(root / 'edb/lib/ext/ai.edgeql') as f: + contents = f.read() + to_add = ( + ''' + drop extension package ai version '1.0'; + create extension ai; + ''' + + contents + ) + splice = '__internal_testmode := true;' + res = res.replace(splice, splice + to_add) + + return res + + @classmethod + def mock_api_embeddings( + cls, + handler: tb.MockHttpServerHandler, + ) -> tb.ResponseType: + return ( + json.dumps({ + "object": "list", + "data": [{ + "object": "embedding", + "index": 0, + "embedding": [ + 1.0, + -2.0, + 3.0, + -4.0, + 5.0, + -6.0, + 7.0, + -8.0, + 9.0, + -10.0, + ], + }], + }), + 200, + ) + + async def test_ext_ai_indexing(self): + await self.con.execute( + """ + insert Astronomy { + content := 'Skies on Mars are red' + }; + insert Astronomy { + content := 'Skies on Earth are blue' + }; + """, + ) + + async for tr in self.try_until_succeeds( + ignore=(AssertionError,), + timeout=10.0, + ): + async with tr: + await self.assert_query_result( + r''' + with + result := ext::ai::search( + Astronomy, <array<float32>>$qv) + select + result.object { + content, + distance := result.distance, + } + order by + result.distance asc empty last + then result.object.content + ''', + [ + { + 'content': 'Skies on Earth are blue', + 'distance': 0, + }, + { + 'content': 'Skies on Mars are red', + 'distance': 0, + }, + ], + variables={ + "qv": [ + 1.0, + -2.0, + 3.0, + -4.0, + 5.0, + -6.0, + 7.0, + -8.0, + 9.0, + -10.0, + ], + } + ) diff --git a/tests/test_http_ext_auth.py b/tests/test_http_ext_auth.py index e58057dd57dd..2093b7578aee 100644 --- a/tests/test_http_ext_auth.py +++ b/tests/test_http_ext_auth.py @@ -319,7 +319,8 @@ def setUpClass(cls): mock_provider: tb.MockHttpServer def setUp(self): - self.mock_provider = tb.MockHttpServer() + self.mock_provider = tb.MockHttpServer( + handler_type=tb.MultiHostMockHttpServerHandler) self.mock_provider.start() 
HTTP_TEST_PORT.set(self.mock_provider.get_base_url())