Skip to content

Commit

Permalink
Rename AI to RAGClient and add compat names (#578)
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix authored Feb 14, 2025
1 parent 023697a commit e5aae27
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 25 deletions.
18 changes: 18 additions & 0 deletions edgedb/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,30 @@
TYPE_CHECKING = False
if TYPE_CHECKING:
from gel.ai import * # noqa
create_ai = create_rag_client # noqa
EdgeDBAI = RAGClient # noqa
create_async_ai = create_async_rag_client # noqa
AsyncEdgeDBAI = AsyncRAGClient # noqa
AIOptions = RAGOptions # noqa
import gel.ai as _mod
import sys as _sys
_cur = _sys.modules['edgedb.ai']
for _k in vars(_mod):
if not _k.startswith('__') or _k in ('__all__', '__doc__'):
setattr(_cur, _k, getattr(_mod, _k))
_cur.create_ai = _mod.create_rag_client
_cur.EdgeDBAI = _mod.RAGClient
_cur.create_async_ai = _mod.create_async_rag_client
_cur.AsyncEdgeDBAI = _mod.AsyncRAGClient
_cur.AIOptions = _mod.RAGOptions
if hasattr(_cur, '__all__'):
_cur.__all__ = _cur.__all__ + [
'create_ai',
'EdgeDBAI',
'create_async_ai',
'AsyncEdgeDBAI',
'AIOptions',
]
del _cur
del _sys
del _mod
Expand Down
16 changes: 8 additions & 8 deletions gel/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
# limitations under the License.
#

from .types import AIOptions, ChatParticipantRole, Prompt, QueryContext
from .core import create_ai, EdgeDBAI
from .core import create_async_ai, AsyncEdgeDBAI
from .types import RAGOptions, ChatParticipantRole, Prompt, QueryContext
from .core import create_rag_client, RAGClient
from .core import create_async_rag_client, AsyncRAGClient

__all__ = [
"AIOptions",
"RAGOptions",
"ChatParticipantRole",
"Prompt",
"QueryContext",
"create_ai",
"EdgeDBAI",
"create_async_ai",
"AsyncEdgeDBAI",
"create_rag_client",
"RAGClient",
"create_async_rag_client",
"AsyncRAGClient",
]
20 changes: 10 additions & 10 deletions gel/ai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,27 @@
from . import types


def create_ai(client: gel.Client, **kwargs) -> EdgeDBAI:
def create_rag_client(client: gel.Client, **kwargs) -> RAGClient:
client.ensure_connected()
return EdgeDBAI(client, types.AIOptions(**kwargs))
return RAGClient(client, types.RAGOptions(**kwargs))


async def create_async_ai(
async def create_async_rag_client(
client: gel.AsyncIOClient, **kwargs
) -> AsyncEdgeDBAI:
) -> AsyncRAGClient:
await client.ensure_connected()
return AsyncEdgeDBAI(client, types.AIOptions(**kwargs))
return AsyncRAGClient(client, types.RAGOptions(**kwargs))


class BaseEdgeDBAI:
options: types.AIOptions
class BaseRAGClient:
options: types.RAGOptions
context: types.QueryContext
client_cls = NotImplemented

def __init__(
self,
client: typing.Union[gel.Client, gel.AsyncIOClient],
options: types.AIOptions,
options: types.RAGOptions,
**kwargs,
):
pool = client._impl
Expand Down Expand Up @@ -103,7 +103,7 @@ def _make_rag_request(
)


class EdgeDBAI(BaseEdgeDBAI):
class RAGClient(BaseRAGClient):
client: httpx.Client

def _init_client(self, **kwargs):
Expand Down Expand Up @@ -146,7 +146,7 @@ def generate_embeddings(self, *inputs: str, model: str) -> list[float]:
return resp.json()["data"][0]["embedding"]


class AsyncEdgeDBAI(BaseEdgeDBAI):
class AsyncRAGClient(BaseRAGClient):
client: httpx.AsyncClient

def _init_client(self, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions gel/ai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ class Prompt:


@dc.dataclass
class AIOptions:
class RAGOptions:
model: str
prompt: typing.Optional[Prompt] = None

def derive(self, kwargs):
return AIOptions(**{**dc.asdict(self), **kwargs})
return RAGOptions(**{**dc.asdict(self), **kwargs})


@dc.dataclass
Expand Down
4 changes: 2 additions & 2 deletions tools/gen_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
if __name__ == '__main__':
this = pathlib.Path(__file__)

errors_fn = this.parent.parent / 'edgedb' / 'errors' / '__init__.py'
init_fn = this.parent.parent / 'edgedb' / '__init__.py'
errors_fn = this.parent.parent / 'gel' / 'errors' / '__init__.py'
init_fn = this.parent.parent / 'gel' / '__init__.py'

with open(errors_fn, 'rt') as f:
errors_txt = f.read()
Expand Down
32 changes: 29 additions & 3 deletions tools/make_import_shims.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import os
import sys

MODS = sorted(['gel', 'gel._taskgroup', 'gel._version', 'gel.abstract', 'gel.ai', 'gel.ai.core', 'gel.ai.types', 'gel.asyncio_client', 'gel.base_client', 'gel.blocking_client', 'gel.codegen', 'gel.color', 'gel.con_utils', 'gel.credentials', 'gel.datatypes', 'gel.datatypes.datatypes', 'gel.datatypes.range', 'gel.describe', 'gel.enums', 'gel.errors', 'gel.errors._base', 'gel.errors.tags', 'gel.introspect', 'gel.options', 'gel.pgproto', 'gel.pgproto.pgproto', 'gel.pgproto.types', 'gel.platform', 'gel.protocol', 'gel.protocol.asyncio_proto', 'gel.protocol.blocking_proto', 'gel.protocol.protocol', 'gel.scram', 'gel.scram.saslprep', 'gel.transaction'])

COMPAT = {
'gel.ai': {
'create_ai': 'create_rag_client',
'EdgeDBAI': 'RAGClient',
'create_async_ai': 'create_async_rag_client',
'AsyncEdgeDBAI': 'AsyncRAGClient',
'AIOptions': 'RAGOptions',
},
}


def main():
Expand All @@ -12,7 +19,10 @@ def main():
nmod = 'edgedb' + mod[len('gel'):]
slash_name = nmod.replace('.', '/')
if is_package:
os.mkdir(slash_name)
try:
os.mkdir(slash_name)
except FileExistsError:
pass
fname = slash_name + '/__init__.py'
else:
fname = slash_name + '.py'
Expand All @@ -25,12 +35,28 @@ def main():
TYPE_CHECKING = False
if TYPE_CHECKING:
from {mod} import * # noqa
''')
if mod in COMPAT:
for k, v in COMPAT[mod].items():
f.write(f' {k} = {v} # noqa\n')
f.write(f'''\
import {mod} as _mod
import sys as _sys
_cur = _sys.modules['{nmod}']
for _k in vars(_mod):
if not _k.startswith('__') or _k in ('__all__', '__doc__'):
setattr(_cur, _k, getattr(_mod, _k))
''')
if mod in COMPAT:
for k, v in COMPAT[mod].items():
f.write(f"_cur.{k} = _mod.{v}\n")
f.write(f'''\
if hasattr(_cur, '__all__'):
_cur.__all__ = _cur.__all__ + [
{',\n '.join(repr(k) for k in COMPAT[mod])},
]
''')
f.write(f'''\
del _cur
del _sys
del _mod
Expand Down

0 comments on commit e5aae27

Please sign in to comment.