Initial implementation of the `ai` extension #7183
Conversation
A first batch of comments on the schema and index commits.
Mostly they are requests for more documentation.
index_id = target_index_metadata.get("id")
if index_id is None:
    raise AssertionError(
        "missing expected index metadata in FunctionCall.extras")
dimensions = target_index_metadata.get("dimensions")
if dimensions is None:
    raise AssertionError(
        "missing expected index metadata in FunctionCall.extras")
df = target_index_metadata.get("distance_function")
if df is None:
    raise AssertionError(
        "missing expected index metadata in FunctionCall.extras")
Maybe extract with a single match?
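For example, something along these lines (just a sketch of the suggestion, reusing the variable from the quoted snippet):

    match target_index_metadata:
        case {
            "id": index_id,
            "dimensions": dimensions,
            "distance_function": df,
        }:
            # all three values are bound here; note that, unlike the .get()
            # checks above, a mapping pattern only guards against *missing*
            # keys, not against keys explicitly set to None
            pass
        case _:
            raise AssertionError(
                "missing expected index metadata in FunctionCall.extras")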
_ctx: context.CompilerContextLevel,
newctx: context.CompilerContextLevel,
_inner_ctx: context.CompilerContextLevel,
This one doesn't need to be addressed yet since it's really an FTS thing, but having three different contexts is quite complex and needs to be explained and given likely better names. (I think it's because the FTS stuff is picky about where stuff goes?)
This is going to need tests. Testing the integrations themselves might not be plausible, but there is a ton of other code that we want to test. Possible approaches:
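One option, sketched below under the assumption that provider base URLs are configurable: stand up a local mock HTTP server that impersonates the model provider's embeddings API and point the provider config at it. The handler name and response shape here are illustrative, not the actual test code.

    import json
    import threading
    from http.server import BaseHTTPRequestHandler, HTTPServer

    class MockEmbeddingsHandler(BaseHTTPRequestHandler):
        # Answers POST /embeddings-style requests with zero vectors, one per input.
        def do_POST(self):
            length = int(self.headers.get("Content-Length", 0))
            request = json.loads(self.rfile.read(length) or b"{}")
            payload = json.dumps({
                "data": [
                    {"index": i, "embedding": [0.0] * 16}
                    for i, _ in enumerate(request.get("input", []))
                ],
            }).encode()
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.end_headers()
            self.wfile.write(payload)

    server = HTTPServer(("localhost", 0), MockEmbeddingsHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    print(f"mock embeddings API listening on http://localhost:{server.server_port}")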
body = json.loads(request.body)
if not isinstance(body, dict):
    raise TypeError(
        'the body of the request must be a JSON object')

context = body.get('context')
if context is None:
    raise TypeError(
        'missing required "context" object in request')
if not isinstance(context, dict):
    raise TypeError(
        '"context" value in request is not a valid JSON object')

ctx_query = context.get("query")
ctx_variables = context.get("variables")
ctx_globals = context.get("globals")
ctx_max_obj_count = context.get("max_object_count")

if not ctx_query:
    raise TypeError(
        'missing required "query" in request "context" object')

if ctx_variables is not None and not isinstance(ctx_variables, dict):
    raise TypeError('"variables" must be a JSON object')

if ctx_globals is not None and not isinstance(ctx_globals, dict):
    raise TypeError('"globals" must be a JSON object')

model = body.get('model')
if not model:
    raise TypeError(
        'missing required "model" in request')

query = body.get('query')
if not query:
    raise TypeError(
        'missing required "query" in request')

stream = body.get('stream')
if stream is None:
    stream = False
elif not isinstance(stream, bool):
    raise TypeError('"stream" must be a boolean')

if ctx_max_obj_count is None:
    ctx_max_obj_count = 5
elif not isinstance(ctx_max_obj_count, int) or ctx_max_obj_count <= 0:
    raise TypeError(
        '"context.max_object_count" must be a positive integer')
Not a comment for now, but I think we're going to want a lightweight HTTP API framework at some point...
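For illustration, even a tiny helper along these lines could collapse most of the repeated checks above (a rough sketch only, not a proposed API):

    def get_field(obj: dict, key: str, typ: type, *,
                  required: bool = True, default=None):
        """Fetch obj[key], enforcing presence and JSON type in one place."""
        value = obj.get(key)
        if value is None:
            if not required:
                return default
            raise TypeError(f'missing required "{key}" in request')
        if not isinstance(value, typ):
            raise TypeError(f'"{key}" must be a {typ.__name__}')
        return value

    # e.g. the checks above would reduce to:
    #   model = get_field(body, "model", str)
    #   stream = get_field(body, "stream", bool, required=False, default=False)
    #   ctx = get_field(body, "context", dict)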
if not ctx_query:
    raise TypeError(
        'missing required "query" in request "context" object')
Should we make sure that it parses as a standalone fragment?
why? If it parses while wrapped then that's good enough, no?
I worry that some sort of injection might be possible.
Another thing we need to handle is comments in the query.
For that it probably suffices to append a newline after?
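For example (an illustrative wrapping; the actual query the endpoint builds may differ, and `Document` is a made-up type name):

    ctx_query = "select Document  # trailing comment supplied by the caller"

    # Without the newline the closing paren is swallowed by the comment:
    #     f"select ({ctx_query}) limit 5"
    # With a newline appended, the fragment terminates cleanly:
    wrapped = f"select ({ctx_query}\n) limit 5"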
The API here isn't intended to be used by untrusted parties, it is similar in that regard to edgeql+http or graphql. We also explicitly disable all capabilities, including mutation.
Maybe my answer is just "it squicks me out", then
I think there are queries with unbalanced parens that would work
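For example (contrived, with made-up type names, but it illustrates the concern):

    # This fragment is not a valid standalone query (unbalanced parens),
    # yet it parses fine once wrapped by the server:
    ctx_query = "Foo) union (select Bar"
    wrapped = f"select ({ctx_query})"
    # -> "select ((Foo) union (select Bar))"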
Really impressive work. I might still have some more questions about the integrations.
Pushed a bunch of fixup commits. I realized that I initially misunderstood how embedding shortening works, so I reworked the model metadata around that and also removed creation of discrete subvector columns for shortening-capable embeddings as it should be possible to create expression indexes on subvectors directly (with the yet-unreleased
Oh, and I fixed handling of
Looking good, I think!
The implementation merged in #7174 is buggy; fix it.
Add abstract schema definitions for the new `ai` extension:

1. Provider config objects.
2. Abstract model types intended to house model metadata via annotations.
3. The `ext::ai::index` abstract object-level index similar to `fts::index`.
4. The `ext::ai::search` function similar to `fts::search`.
5. The `ext::ai::to_context` function used to "stringify" objects returned by `ai::search` (or other searches) for the purposes of generating text context for submission to an LLM.
6. The `ext::ai::ChatPrompt` type used to structurally define LLM chat prompts.
The `ext::ai::index` is (currently) an always-deferred index. Under the hood it adds several new columns to the relation of the object type it is declared on: one for each embedding vector variant (if the text extraction model supports outputs of varying dimensionality, a.k.a. Matryoshka Representation), and one to denote whether the embeddings are up-to-date with respect to the object content (maintained by a trigger). Declaration of `ext::ai::index` indexes also results in population of several internal views that expose objects-to-be-indexed for the deferred indexing process to consume.

The text extraction model used to convert the index expression to an embedding is passed as a keyword argument when declaring a concrete index. Model metadata (name, dimensionality, limits, etc.) is then persisted on the index as internal annotations for ease of access.
The function is compiled into a vector distance search against the generated embeddings columns. The distance function used is determined by the arguments of `ext::ai::index` defined on the object type being searched.
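For illustration, a search might then be issued roughly like this (the exact signature and return shape of `ext::ai::search`, the `Document` type, and the vector size are all assumptions here):

    import edgedb

    client = edgedb.create_client()

    # The query vector would normally come from the model's embeddings API;
    # a placeholder vector is used here.
    results = client.query(
        """
        select ext::ai::search(Document, <array<float32>>$qv)
        order by .distance empty last
        limit 5
        """,
        qv=[0.0] * 1536,
    )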
The `ext::ai::to_context` takes an object and returns the result of evaluation of an `ext::ai::index` expression defined on it. It is an error to pass an object that does not have an `ext::ai::index`.
@scotttrinh, added you as a reviewer as I'm refactoring the auth ext tests a bit in the last few commits.
That test looks good. Do we need some more for various schema manipulations? RAG?
Test changes look good!
This implements the auto-vectorization of content indexed with `ext::ai::index` via communication with the corresponding model API. The code here is generic; there are no concrete model API implementations yet.
This is a convenience proxy to the upstream model `/embeddings` API.
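For example, a call to the proxy might look roughly like this (the URL path and the exact request/response shapes are assumptions; the body mirrors the OpenAI-style /embeddings request):

    import requests

    resp = requests.post(
        # the path prefix depends on the instance and database; this one is a guess
        "http://localhost:5656/db/edgedb/ext/ai/embeddings",
        json={
            "model": "text-embedding-3-small",
            "input": ["some text to vectorize"],
        },
        # authentication is omitted; it depends on the server configuration
    )
    print(resp.json())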
This `POST` request expects a JSON object in the request body, containing the following:

* `model`: the name of the text generation model to use. Must match exactly one `ext::ai::model_name` annotation on an `ext::ai::TextGenerationModel` subtype.
* `query`: user text input, used both to rank selected context objects and as the user prompt to the AI assistant.
* `context`:
  - `query`: an arbitrary non-DML EdgeQL query returning a set of objects whose type must be indexed with `ext::ai::index`.
  - `variables`: values for any EdgeQL query variables in `query`.
  - `globals`: values for any EdgeQL globals that the `query` might depend on.
  - `max_object_count`: maximum count of objects to include in the prompt context after running the similarity search against `query`.
* `prompt`:
  - `id`: ID of an existing `ext::ai::ChatPrompt` object containing prompt configuration for this request;
  - `name`: name (`.name`) of an existing `ext::ai::ChatPrompt` object containing prompt configuration for this request (mutually exclusive with `prompt.id`);
  - `custom`: a list of `{"role": ..., "content": ...}` prompt messages to add to the pre-defined prompt (if `prompt.id` or `prompt.name` is specified), or to use as the whole prompt if neither `prompt.id` nor `prompt.name` is specified. The `role` may be either `system` (to configure the general parameters of the chat), `user` (the user prompt) or `assistant` (constrains or prefixes the LLM response).
* `stream`: if `true`, the response will be streamed as server-sent events (`text/event-stream`); otherwise the entire response is sent all at once as a JSON object.
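Putting the fields above together, a minimal request body could look like this (the model name, query text, context query, and prompt name are placeholders):

    rag_request_body = {
        "model": "gpt-4-turbo",
        "query": "Which objects describe the indexing pipeline?",
        "context": {
            "query": "select Document filter .category = <str>$category",
            "variables": {"category": "indexing"},
            "max_object_count": 3,
        },
        "prompt": {"name": "my-chat-prompt"},
        "stream": False,
    }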
Wire in OpenAI models: `gpt-{3.5,4}-turbo`, as well as text embedding models: `text-embedding-3-{small,large}`.
Wire in Mistral models: `mistral-{small,medium,large}-latest`, as well as the `mistral-embed` text embedding model.
Wire in Anthropic models: `claude-3-{haiku,sonnet,opus}`.
All tests in the test case use it, so it's an appropriate thing to do, and allows us to avoid setting the ContextVar in the mock server guts.
There's nothing specific to the auth extension in the mock HTTP server implementation and it will be useful in tests of other HTTP extensions, so move it to `testbase.http` and rename to `MockHttpServer`.
This adds the `ai` extension to EdgeDB, containing the following functionality:

* The new object-level `ext::ai::index` (similar to `fts::index`) that automatically generates and indexes embeddings from the given index expression.
* A basic RAG interface via the `/ai/rag` HTTP endpoint that takes a query selecting from an `ai::index`-indexed type and uses it as context for a text generation question-answer completion.

The support for the above is implemented here for OpenAI, Mistral and Anthropic, mostly to demonstrate multi-provider support. Model and provider metadata is implemented as annotated abstract types in `ext::ai`, and the intent is that users can define models in their schemas by extending from the appropriate base type in `ext::ai`.

To test this out:
Apply the following schema:
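The schema from the original description is not reproduced here; as an illustrative stand-in (the `Knowledge` type, the `embedding_model` argument name, and the model name are assumptions), something along these lines could be written to `dbschema/default.esdl` and applied with a migration:

    import pathlib
    import subprocess

    # Illustrative SDL only; adjust the type and model to your own schema.
    pathlib.Path("dbschema/default.esdl").write_text("""
    using extension ai;

    module default {
        type Knowledge {
            required content: str;
            deferred index ext::ai::index(
                embedding_model := 'text-embedding-3-small'
            ) on (.content);
        }
    }
    """)

    subprocess.run(["edgedb", "migration", "create", "--non-interactive"], check=True)
    subprocess.run(["edgedb", "migrate"], check=True)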
Configure the OpenAI provider:
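For example (the config type name and field follow the provider config objects described above, but treat the exact spelling as an assumption; the key is a placeholder):

    import edgedb

    client = edgedb.create_client()

    client.execute("""
        configure current database
        insert ext::ai::OpenAIProviderConfig {
            secret := '<your OpenAI API key>',
        };
    """)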
Insert some data:
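For example, using the hypothetical `Knowledge` type from the schema sketch above:

    import edgedb

    client = edgedb.create_client()
    client.execute("""
        insert Knowledge {
            content := 'ext::ai::index keeps embeddings up to date via a trigger.'
        };
        insert Knowledge {
            content := 'The /ai/rag endpoint answers questions using indexed objects as context.'
        };
    """)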
Test RAG:
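For example (the URL path prefix before `/ai/rag` and the model name are assumptions; authentication is omitted and depends on how the instance is configured):

    import requests

    resp = requests.post(
        "http://localhost:5656/db/edgedb/ext/ai/rag",
        json={
            "model": "gpt-4-turbo",
            "query": "How are embeddings kept up to date?",
            "context": {"query": "select Knowledge"},
            "stream": False,
        },
    )
    print(resp.json())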