Skip to content

Commit

Permalink
Add some tests for ext::ai
Browse files Browse the repository at this point in the history
  • Loading branch information
elprans committed Apr 16, 2024
1 parent 45d9358 commit 6c888ab
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 13 deletions.
26 changes: 20 additions & 6 deletions edb/testbase/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Any,
Callable,
Optional,
Type,
)

import http.server
Expand Down Expand Up @@ -312,34 +313,47 @@ def assert_graphql_query_result(


class MockHttpServerHandler(http.server.BaseHTTPRequestHandler):
def get_server_and_path(self) -> tuple[str, str]:
server = f'http://{self.headers.get("Host")}'
return server, self.path

def do_GET(self):
self.close_connection = False
server, path = self.path.lstrip('/').split('/', 1)
server = urllib.parse.unquote(server)
server, path = self.get_server_and_path()
self.server.owner.handle_request('GET', server, path, self)

def do_POST(self):
self.close_connection = False
server, path = self.path.lstrip('/').split('/', 1)
server = urllib.parse.unquote(server)
server, path = self.get_server_and_path()
self.server.owner.handle_request('POST', server, path, self)

def log_message(self, *args):
pass


class MultiHostMockHttpServerHandler(MockHttpServerHandler):
def get_server_and_path(self) -> tuple[str, str]:
server, path = self.path.lstrip('/').split('/', 1)
server = urllib.parse.unquote(server)
return server, path


ResponseType = tuple[str, int] | tuple[str, int, dict[str, str]]


class MockHttpServer:
def __init__(self) -> None:
def __init__(
self,
handler_type: Type[MockHttpServerHandler] = MockHttpServerHandler,
) -> None:
self.has_started = threading.Event()
self.routes: dict[
tuple[str, str, str],
ResponseType | Callable[[MockHttpServerHandler], ResponseType],
] = {}
self.requests: dict[tuple[str, str, str], list[dict[str, Any]]] = {}
self.url: Optional[str] = None
self.handler_type = handler_type

def get_base_url(self) -> str:
if self.url is None:
Expand Down Expand Up @@ -466,7 +480,7 @@ def __enter__(self):

def _http_worker(self):
self._http_server = http.server.HTTPServer(
('localhost', 0), MockHttpServerHandler
('localhost', 0), self.handler_type
)
self._http_server.owner = self
self._address = self._http_server.server_address
Expand Down
18 changes: 12 additions & 6 deletions edb/testbase/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,9 +1127,9 @@ def look(obj):

class DatabaseTestCase(ConnectedTestCase):

SETUP: Optional[Union[str, List[str]]] = None
SETUP: Optional[str | pathlib.Path | list[str] | list[pathlib.Path]] = None
TEARDOWN: Optional[str] = None
SCHEMA: Optional[Union[str, pathlib.Path]] = None
SCHEMA: Optional[str | pathlib.Path] = None
DEFAULT_MODULE: str = 'default'
EXTENSIONS: List[str] = []

Expand Down Expand Up @@ -1300,13 +1300,19 @@ def get_setup_script(cls):
for scr in scripts:
has_nontrivial_script = True

if '\n' not in scr and os.path.exists(scr):
is_path = (
isinstance(scr, pathlib.Path)
or '\n' not in scr and os.path.exists(scr)
)

if is_path:
with open(scr, 'rt') as f:
setup = f.read()
setup_text = f.read()
else:
setup = scr
assert isinstance(scr, str)
setup_text = scr

script += '\n' + setup
script += '\n' + setup_text

# If the SETUP script did a SET MODULE, make sure it is cleared
# (since in some modes we keep using the same connection)
Expand Down
34 changes: 34 additions & 0 deletions tests/schemas/ext_ai.esdl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2023-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


type TestEmbeddingModel
extending ext::ai::EmbeddingModel
{
annotation ext::ai::model_name := "text-embedding-test";
annotation ext::ai::model_provider := "custom::test";
annotation ext::ai::embedding_model_max_input_tokens := "8191";
annotation ext::ai::embedding_model_max_output_dimensions := "10";
annotation ext::ai::embedding_model_supports_shortening := "true";
};

type Astronomy {
content: str;
deferred index ext::ai::index(embedding_model := 'text-embedding-test')
on (.content);
};
178 changes: 178 additions & 0 deletions tests/test_ext_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2023-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import json
import pathlib

from edb.testbase import http as tb


class TestExtAI(tb.BaseHttpExtensionTest):
EXTENSIONS = ['pgvector', 'ai']
BACKEND_SUPERUSER = True
TRANSACTION_ISOLATION = False
PARALLELISM_GRANULARITY = 'suite'

SCHEMA = pathlib.Path(__file__).parent / 'schemas' / 'ext_ai.esdl'

@classmethod
def setUpClass(cls):
super().setUpClass()
cls.mock_server = tb.MockHttpServer()
cls.mock_server.start()
base_url = cls.mock_server.get_base_url().rstrip("/")

cls.mock_server.register_route_handler(
"POST",
base_url,
"/v1/embeddings",
)(cls.mock_api_embeddings)

async def _setup():
await cls.con.execute(
f"""
CONFIGURE CURRENT DATABASE
INSERT ext::ai::CustomProviderConfig {{
name := 'custom::test',
secret := 'very secret',
api_url := '{base_url}/v1',
api_style := ext::ai::ProviderAPIStyle.OpenAI,
}};
CONFIGURE CURRENT DATABASE
SET ext::ai::Config::indexer_naptime := <duration>'100ms';
""",
)

await cls._wait_for_db_config('ext::ai::Config::providers')

cls.loop.run_until_complete(_setup())

@classmethod
def tearDownClass(cls):
cls.mock_server.stop()

@classmethod
def get_setup_script(cls):
res = super().get_setup_script()

# HACK: As a debugging cycle hack, when RELOAD is true, we reload the
# extension package from the file, so we can test without a bootstrap.
RELOAD = False

if RELOAD:
root = pathlib.Path(__file__).parent.parent
with open(root / 'edb/lib/ext/ai.edgeql') as f:
contents = f.read()
to_add = (
'''
drop extension package ai version '1.0';
create extension ai;
'''
+ contents
)
splice = '__internal_testmode := true;'
res = res.replace(splice, splice + to_add)

return res

@classmethod
def mock_api_embeddings(
cls,
handler: tb.MockHttpServerHandler,
) -> tb.ResponseType:
return (
json.dumps({
"object": "list",
"data": [{
"object": "embedding",
"index": 0,
"embedding": [
1.0,
-2.0,
3.0,
-4.0,
5.0,
-6.0,
7.0,
-8.0,
9.0,
-10.0,
],
}],
}),
200,
)

async def test_ext_ai_indexing(self):
await self.con.execute(
"""
insert Astronomy {
content := 'Skies on Mars are red'
};
insert Astronomy {
content := 'Skies on Earth are blue'
};
""",
)

async for tr in self.try_until_succeeds(
ignore=(AssertionError,),
timeout=10.0,
):
async with tr:
await self.assert_query_result(
r'''
with
result := ext::ai::search(
Astronomy, <array<float32>>$qv)
select
result.object {
content,
distance := result.distance,
}
order by
result.distance asc empty last
then result.object.content
''',
[
{
'content': 'Skies on Earth are blue',
'distance': 0,
},
{
'content': 'Skies on Mars are red',
'distance': 0,
},
],
variables={
"qv": [
1.0,
-2.0,
3.0,
-4.0,
5.0,
-6.0,
7.0,
-8.0,
9.0,
-10.0,
],
}
)
3 changes: 2 additions & 1 deletion tests/test_http_ext_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def setUpClass(cls):
mock_provider: tb.MockHttpServer

def setUp(self):
self.mock_provider = tb.MockHttpServer()
self.mock_provider = tb.MockHttpServer(
handler_type=tb.MultiHostMockHttpServerHandler)
self.mock_provider.start()
HTTP_TEST_PORT.set(self.mock_provider.get_base_url())

Expand Down

0 comments on commit 6c888ab

Please sign in to comment.