Skip to content

Commit

Permalink
Merge pull request #10 from lancedb/addFts
Browse files Browse the repository at this point in the history
Add Full text search query and query with filter
  • Loading branch information
LuQQiu authored Nov 6, 2024
2 parents 9cd26e2 + caf3773 commit 76a94a0
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 5 deletions.
25 changes: 21 additions & 4 deletions bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from datasets import load_dataset, DownloadConfig

from cloud.benchmark.util import await_indices, BenchmarkResults
from cloud.benchmark.query import QueryType, VectorQuery, FTSQuery


def add_benchmark_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -49,6 +50,13 @@ def add_benchmark_args(parser: argparse.ArgumentParser):
default=1000,
help="number of queries to run against each table",
)
parser.add_argument(
"--query-type",
type=str,
choices=[qt.value for qt in QueryType],
default=QueryType.VECTOR.value,
help="type of query to run",
)
parser.add_argument(
"--ingest",
type=bool,
Expand Down Expand Up @@ -87,6 +95,7 @@ def __init__(
num_tables: int,
batch_size: int,
num_queries: int,
query_type: str,
ingest: bool,
index: bool,
prefix: str,
Expand All @@ -108,6 +117,13 @@ def __init__(
region=os.getenv("LANCEDB_REGION", "us-east-1"),
)

if query_type == QueryType.VECTOR.value:
self.query_obj = VectorQuery()
elif query_type == QueryType.VECTOR_WITH_FILTER.value:
self.query_obj = VectorQuery(filter=True)
elif query_type == QueryType.FTS.value:
self.query_obj = FTSQuery()

self.tables: List[RemoteTable] = []
self.results = BenchmarkResults()
self.results.tables = num_tables
Expand Down Expand Up @@ -357,11 +373,9 @@ def _query_table(self, table: RemoteTable, warmup_queries=100):
self._add_percentiles("query", diffs)
return qps

def _query(self, table: RemoteTable, nprobes=1):
def _query(self, table: RemoteTable):
try:
table.search(np.random.standard_normal(1536)).metric("cosine").nprobes(
nprobes
).select(["openai", "title"]).to_arrow()
self.query_obj.query(table)
except Exception as e:
print(f"{table.name}: error during query: {e}")

Expand Down Expand Up @@ -406,6 +420,7 @@ def run_multi_benchmark(
num_tables: int,
batch_size: int,
num_queries: int,
query_type: str,
ingest: bool,
index: bool,
prefix: str,
Expand All @@ -421,6 +436,7 @@ def run_multi_benchmark(
"num_tables": num_tables,
"batch_size": batch_size,
"num_queries": num_queries,
"query_type": query_type,
"ingest": ingest,
"index": index,
"prefix": prefix, # Base prefix, will be modified per process
Expand Down Expand Up @@ -499,6 +515,7 @@ def main():
args.tables,
args.batch,
args.queries,
args.query_type,
args.ingest,
args.index,
args.prefix,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "lancedb-cloud-benchmarks"
version = "0.1.1"
version = "0.1.2"
description = ""
authors = [{ name = "LanceDB Devs", email = "[email protected]" }]
readme = "README.md"
Expand Down
120 changes: 120 additions & 0 deletions src/cloud/benchmark/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from abc import ABC, abstractmethod
from enum import Enum
import random
from typing import Any, List

from lancedb.remote.table import RemoteTable
import numpy as np

# Default vocabulary for the query classes in this module: when a caller
# supplies no word list, one entry is picked at random per query to serve as
# the search term or filter word.
QUERY_WORDS = [
    # Common nouns
    "University", "Institute", "School", "Museum", "Library",
    "History", "Science", "Art", "Literature", "Philosophy",
    # Locations
    "America", "Europe", "Asia", "China", "India",
    "Japan", "Russia", "Germany", "France", "England",
    # Organizations
    "Company", "Corporation", "Association", "Society", "Foundation",
    # Fields
    "Technology", "Engineering", "Medicine", "Economics", "Politics",
]


class QueryType(Enum):
    """Closed set of query kinds a benchmark run can be asked to execute.

    Each member's value is the plain string form of that kind.
    """

    # Plain vector search.
    VECTOR = "vector"
    # Vector search combined with a filter.
    VECTOR_WITH_FILTER = "vector_with_filter"
    # Full-text search.
    FTS = "fts"
    # NOTE(review): no Query subclass in this module implements hybrid
    # search yet -- confirm callers handle this value before offering it.
    HYBRID = "hybrid"


class Query(ABC):
    """Common interface for the benchmark's query implementations.

    Subclasses hold their own configuration and implement :meth:`query`,
    which runs a single query against a table.
    """

    @abstractmethod
    def query(self, table: RemoteTable, **kwargs) -> Any:
        """Run one query against ``table`` and return whatever it yields.

        Args:
            table: The remote table to search.
            **kwargs: Extra options; their meaning is left to each subclass.

        Returns:
            The query result, in whatever form the subclass produces.
        """


class VectorQuery(Query):
    """Vector search using a freshly generated random query vector.

    When ``filter`` is enabled, each query also carries a
    ``text LIKE '%<word>%'`` predicate whose word is picked at random
    from ``words``.
    """

    def __init__(
        self,
        words: List[str] = None,
        dim: int = 1536,
        metric: str = "cosine",
        nprobes: int = 1,
        selected_columns: List[str] = None,
        limit: int = 1,
        filter: bool = False,
    ):
        """Store the search configuration.

        Args:
            words: Candidate filter words; ``None`` or an empty list falls
                back to ``QUERY_WORDS``.
            dim: Length of the random query vector.
            metric: Value handed to the query builder's ``metric()``.
            nprobes: Value handed to the query builder's ``nprobes()``.
            selected_columns: Columns to return; ``None`` or an empty list
                falls back to ``["openai", "title"]``.
            limit: Value handed to the query builder's ``limit()``.
            filter: Whether to add the ``LIKE`` predicate to every query.
        """
        self.words = words or QUERY_WORDS
        self.dim = dim
        self.metric = metric
        self.nprobes = nprobes
        self.selected_columns = selected_columns or ["openai", "title"]
        self.limit = limit
        self.filter = filter

    def query(self, table: RemoteTable, **kwargs) -> Any:
        """Run one vector search against ``table``.

        Args:
            table: The remote table to search.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The result of the builder's ``to_arrow()`` call.
        """
        # A new standard-normal vector is drawn for every call.
        search = table.search(np.random.standard_normal(self.dim))

        if self.filter:
            # Double any single quote so a caller-supplied word such as
            # "Children's" cannot terminate the SQL string literal early
            # and produce a malformed predicate.
            needle = str(random.choice(self.words)).replace("'", "''")
            search = search.where(f"text LIKE '%{needle}%'", prefilter=True)

        return (
            search.metric(self.metric)
            .nprobes(self.nprobes)
            .select(self.selected_columns)
            .limit(self.limit)
            .to_arrow()
        )


class FTSQuery(Query):
    """Full-text search that looks up one randomly chosen word per query."""

    def __init__(
        self,
        words: List[str] = None,
        column: str = "title",
        selected_columns: List[str] = None,
        limit: int = 1,
    ):
        """Store the search configuration.

        Args:
            words: Candidate search terms; ``None`` or an empty list falls
                back to ``QUERY_WORDS``.
            column: Name of a text column. NOTE(review): it is stored on
                the instance but ``query`` never passes it to the search
                call -- confirm whether it should restrict the FTS columns.
            selected_columns: Columns to return; ``None`` or an empty list
                falls back to ``["title"]``.
            limit: Value handed to the query builder's ``limit()``.
        """
        self.words = words if words else QUERY_WORDS
        self.column = column
        self.selected_columns = selected_columns if selected_columns else ["title"]
        self.limit = limit

    def query(self, table: RemoteTable, **kwargs) -> Any:
        """Run one full-text search against ``table``.

        Args:
            table: The remote table to search.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The result of the builder's ``to_arrow()`` call.
        """
        term = random.choice(self.words)
        search = table.search(term, query_type="fts")
        search = search.select(self.selected_columns)
        search = search.limit(self.limit)
        return search.to_arrow()

0 comments on commit 76a94a0

Please sign in to comment.