@pytest.mark.parametrize("distance_metric", ["", "l2sq", "cos"])
def test_weighted_vector_search(primary, distance_metric):
    """Check lantern's weighted vector search over a sparse and a dense column.

    The test fills ``small_world`` with the eight corners of the unit cube,
    stored both as a dense ``VECTOR(3)`` (column ``v``) and as a
    ``SPARSEVEC(3)`` (column ``s``).  It then calls
    ``lantern.weighted_vector_search[_<metric>]`` with weight 0.9 on the
    sparse column and 0.1 on the dense column, and compares the first three
    rows (id, weighted distance rounded to 2 decimals) with the expected
    values for the metric.

    Args:
        primary: Fixture exposing ``execute(dbname, sql)``; defined elsewhere
            in this file.
        distance_metric: ``""`` (the un-suffixed function, exercised with the
            l2sq operator), ``"l2sq"`` or ``"cos"``.
    """
    # NOTE(review): the table is created with IF NOT EXISTS and never
    # truncated, so if `primary` is shared between the three parametrized
    # runs the rows are inserted three times -- confirm the fixture scope.
    primary.execute(
        "testdb",
        "CREATE TABLE IF NOT EXISTS small_world "
        "(id VARCHAR(3), b BOOLEAN, v VECTOR(3), s SPARSEVEC(3));",
    )
    primary.execute("testdb", """
        INSERT INTO small_world VALUES
        ('000', TRUE,  '[0,0,0]', '{}/3'),
        ('001', TRUE,  '[0,0,1]', '{3:1}/3'),
        ('010', FALSE, '[0,1,0]', '{2:1}/3'),
        ('011', TRUE,  '[0,1,1]', '{2:1,3:1}/3'),
        ('100', FALSE, '[1,0,0]', '{1:1}/3'),
        ('101', FALSE, '[1,0,1]', '{1:1,3:1}/3'),
        ('110', FALSE, '[1,1,0]', '{1:1,2:1}/3'),
        ('111', TRUE,  '[1,1,1]', '{1:1,2:1,3:1}/3');
    """)

    # The un-suffixed function is exercised with the l2sq operator.
    operators_by_metric = {"l2sq": "<->", "cos": "<=>", "hamming": "<+>"}
    operator = operators_by_metric[distance_metric or "l2sq"]

    query_s = "{1:0.4,2:0.3,3:0.2}/3"
    query_v = "[-0.5,-0.1,-0.3]"
    if distance_metric:
        function = f"weighted_vector_search_{distance_metric}"
    else:
        function = "weighted_vector_search"

    # The vector literals are inlined by the f-string, so they are written as
    # plain quoted SQL literals.  The psql-only `:'name'` interpolation form
    # must not be used here: a bare ':' before a literal is a syntax error
    # once the text reaches the server.
    # There is deliberately no ';' after the closing parenthesis of the
    # function call: LIMIT has to belong to the same SELECT statement.
    query = f"""
        SELECT
            id,
            round(cast(0.9 * (s {operator} '{query_s}'::sparsevec) + 0.1 * (v {operator} '{query_v}'::vector) as numeric), 2) as dist
        FROM lantern.{function}(CAST(NULL as "small_world"), operator=>'{operator}',
            w1=> 0.9, col1=>'s'::text, vec1=>'{query_s}'::sparsevec,
            w2=> 0.1, col2=>'v'::text, vec2=>'{query_v}'::vector
        )
        LIMIT 3;
    """
    res = primary.execute("testdb", query)

    # Depending on the DB driver, `numeric` values may come back as
    # decimal.Decimal, which never compares equal to an inexact float literal
    # such as 0.22.  Converting to float keeps the comparison meaningful and
    # is a no-op when the driver already returns floats.
    actual = [(row_id, float(dist)) for row_id, dist in res]

    expected_results_cos = [("111", 0.22), ("110", 0.24), ("101", 0.39)]
    expected_results_l2sq = [("000", 0.54), ("100", 0.78), ("010", 0.87)]
    if distance_metric == "cos":
        expected = expected_results_cos
    else:
        expected = expected_results_l2sq
    assert actual == expected


# fixture to handle external index server setup
@pytest.fixture def external_index(request): diff --git a/sql/updates/0.3.2--0.3.3.sql b/sql/updates/0.3.2--0.3.3.sql index e52c6c0a..131cecdd 100644 --- a/sql/updates/0.3.2--0.3.3.sql +++ b/sql/updates/0.3.2--0.3.3.sql @@ -19,8 +19,8 @@ DECLARE -- function suffix, function default operator utility_functions text[2][] := ARRAY[ ARRAY['', '<->'], - ARRAY['_cos', '<->'], - ARRAY['_l2sq', '<=>'] + ARRAY['_cos', '<=>'], + ARRAY['_l2sq', '<->'] ]; BEGIN -- Check if the vector type from pgvector exists