From 7c29a9be2a67b8b21bcada6b2d61af218ddf4fb8 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Sun, 21 Apr 2024 20:07:03 -0700 Subject: [PATCH] Make ai::search have integrated sort and hit indexes (#7242) Tweak ai::search codegen to make it hit the index reliably even with filtering NULLs out. It seems that postgres *can* sometimes manage to use an ORDER BY index even when the function call isn't directly in the ORDER BY, but it is much more fragile (broken by adding the NULL check in #7223, for one). Making ai::search return sorted output makes it easy to hit the indexes and improves ergonomics. Also: * Compile the arguments in the enclosing scope, which helps us hit the index in more complex scenarios (like a cast from json) * Make sure to export a source rvar for `.object` --- edb/pgsql/compiler/relgen.py | 42 +++++++++++++++++--- tests/schemas/ext_ai.esdl | 6 +++ tests/test_edgeql_data_migration.py | 17 ++++---- tests/test_ext_ai.py | 60 +++++++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 14 deletions(-) diff --git a/edb/pgsql/compiler/relgen.py b/edb/pgsql/compiler/relgen.py index 529f21f1a09..537e0b43dd2 100644 --- a/edb/pgsql/compiler/relgen.py +++ b/edb/pgsql/compiler/relgen.py @@ -47,6 +47,8 @@ from edb.schema import objects as s_obj from edb.schema import name as sn +from edb.edgeql import ast as qlast + from edb.ir import ast as irast from edb.ir import typeutils as irtyputils from edb.ir import utils as irutils @@ -3228,6 +3230,7 @@ def _compile_call_args( ir_set: irast.Set, *, skip: Collection[int] = (), + no_subquery_args: bool = False, ctx: context.CompilerContextLevel, ) -> List[pgast.BaseExpr]: """ @@ -3263,6 +3266,7 @@ def _compile_call_args( and ir_arg.cardinality.is_single() and (arg_typeref.is_scalar or arg_typeref.collection) and not _needs_arg_null_check(expr, ir_arg, typemod, ctx=ctx) + and not no_subquery_args ) if make_subquery: @@ -3991,9 +3995,28 @@ def _ext_ai_search_inner_pgvector( ], ) + # Install the filter 
directly in newctx.rel. We could return it + # and have it put in inner_ctx.rel, and that does seem to work, + # but seems weirder. valid = pgast.NullTest(arg=embedding, negated=True) + newctx.rel.where_clause = astutils.extend_binop( + newctx.rel.where_clause, valid + ) - return similarity, valid + # Do an integrated sort. This ensures we can hit the index, and is + # more ergonomic anyway. Having the ORDER BY operate directly on + # the function call is not the *only* way to have it work, but it + # is the most reliable. + sort_by = pgast.SortBy( + node=similarity, + dir=qlast.SortOrder.Asc, + nulls=qlast.NonesOrder.Last, + ) + if newctx.rel.sort_clause is None: + newctx.rel.sort_clause = [] + newctx.rel.sort_clause.append(sort_by) + + return similarity, None def _process_set_as_object_search( @@ -4003,6 +4026,15 @@ def _process_set_as_object_search( ctx: context.CompilerContextLevel, ) -> SetRVars: func_call = ir_set.expr + + # We skip the object, as it has to be compiled as rvar source. + # + # Also, disable subquery args. ai::search needs it for its + # scoping effects, but we don't need to use it here, since + # it can cause the ai search to duplicate arguments. 
+ args_pg = _compile_call_args( + ir_set, skip={0}, no_subquery_args=True, ctx=ctx) + with ctx.subrel() as newctx: newctx.expr_exposed = False @@ -4010,10 +4042,6 @@ def _process_set_as_object_search( obj_id = obj_ir.path_id obj_rvar = ensure_source_rvar(obj_ir, newctx.rel, ctx=newctx) - # we skip the object, as it has to be compiled as rvar source - args_pg = _compile_call_args( - ir_set, skip={0}, ctx=newctx) - out_obj_id, out_score_id = func_call.tuple_path_ids with newctx.subrel() as inner_ctx: @@ -4082,13 +4110,15 @@ def _process_set_as_object_search( pathctx.put_path_id_map(newctx.rel, out_obj_id, obj_id) - aspects = {'value'} + aspects = {'value', 'source'} func_rvar = relctx.new_rel_rvar(ir_set, newctx.rel, ctx=ctx) relctx.include_rvar( ctx.rel, func_rvar, ir_set.path_id, aspects=aspects, ctx=ctx ) + pathctx.put_path_rvar(ctx.rel, out_obj_id, func_rvar, aspect='source') + return new_stmt_set_rvar(ir_set, ctx.rel, aspects=aspects, ctx=ctx) diff --git a/tests/schemas/ext_ai.esdl b/tests/schemas/ext_ai.esdl index ef0171ee923..f596e7544ed 100644 --- a/tests/schemas/ext_ai.esdl +++ b/tests/schemas/ext_ai.esdl @@ -42,3 +42,9 @@ type Stuff extending Astronomy { type Star extending Astronomy; type Supernova extending Star; + +function _set_seqscan(val: std::str) -> std::str { + using sql $$ + select set_config('enable_seqscan', val, true) + $$; +}; diff --git a/tests/test_edgeql_data_migration.py b/tests/test_edgeql_data_migration.py index a70619aa614..768cb5d28ed 100644 --- a/tests/test_edgeql_data_migration.py +++ b/tests/test_edgeql_data_migration.py @@ -12542,12 +12542,13 @@ async def test_edgeql_migration_ai_08(self): }; ''', explicit_modules=True) + arg = [0.0] * 1536 await self.con.query(''' select { - base := ext::ai::search(Base, <array<float32>>[1]), - sub := ext::ai::search(Sub, <array<float32>>[1]), + base := ext::ai::search(Base, <array<float32>>$0), + sub := ext::ai::search(Sub, <array<float32>>$0), } - ''') + ''', arg) await self.migrate(''' using extension ai; @@ -12571,10 +12572,10 @@ async def 
test_edgeql_migration_ai_08(self): await self.con.query(''' select { - base := ext::ai::search(Base, <array<float32>>[1]), - sub := ext::ai::search(Sub, <array<float32>>[1]), + base := ext::ai::search(Base, <array<float32>>$0), + sub := ext::ai::search(Sub, <array<float32>>$0), } - ''') + ''', arg) await self.migrate(''' using extension ai; @@ -12596,9 +12597,9 @@ async def test_edgeql_migration_ai_08(self): # Base lost the index, just select Sub await self.con.query(''' select { - sub := ext::ai::search(Sub, <array<float32>>[1]), + sub := ext::ai::search(Sub, <array<float32>>$0), } - ''') + ''', arg) class EdgeQLMigrationRewriteTestCase(EdgeQLDataMigrationTestCase): diff --git a/tests/test_ext_ai.py b/tests/test_ext_ai.py index 12a089681d1..e56a31d1b97 100644 --- a/tests/test_ext_ai.py +++ b/tests/test_ext_ai.py @@ -305,3 +305,63 @@ async def test_ext_ai_indexing_03(self): ], } ) + + async def _assert_index_use(self, query, *args): + def look(obj): + if isinstance(obj, dict) and obj.get('plan_type') == "IndexScan": + return any( + prop['title'] == 'index_name' + and f'ai::index' in prop['value'] + for prop in obj.get('properties', []) + ) + + if isinstance(obj, dict): + return any([look(v) for v in obj.values()]) + elif isinstance(obj, list): + return any(look(v) for v in obj) + else: + return False + + async with self._run_and_rollback(): + await self.con.execute('select _set_seqscan("off");') + plan = await self.con.query_json(f'analyze {query};', *args) + if not look(json.loads(plan)): + raise AssertionError(f'query did not use ext::ai::index index') + + async def test_ext_ai_indexing_04(self): + qv = [1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0] + + await self._assert_index_use( + f''' + with vector := <array<float32>>$0 + select ext::ai::search(Stuff, vector) limit 5; + ''', + qv, + ) + await self._assert_index_use( + f''' + with vector := <array<float32>>$0 + select ext::ai::search(Stuff, vector).object limit 5; + ''', + qv, + ) + await self._assert_index_use( + f''' + select ext::ai::search(Stuff, <array<float32>>$0) limit 5; + ''', + qv, + ) + + await self._assert_index_use( + 
f''' + with vector := <array<float32>><json>$0 + select ext::ai::search(Stuff, vector) limit 5; + ''', + json.dumps(qv), + ) + await self._assert_index_use( + f''' + select ext::ai::search(Stuff, <array<float32>><json>$0) limit 5; + ''', + json.dumps(qv), + )