# Patch summary: wire up regular-expression document filters for the SQLite-backed metadata path.
#
# Cargo.lock
#   Adds "regex" to one package's dependency list. The package name is outside the hunk;
#   its other listed dependencies include "libsqlite3-sys" and "sqlx-core".
#
# chromadb/test/property/test_filtering.py
#   Imports `Documents` and adds `test_regex`. The test stores the documents "cat", "Cat", "CAT"
#   (ids "1", "2", "3"; metadata test=10/20/30) and checks three `where_document` queries:
#     - {"$regex": "cat"}                                    -> only id "1" (match is case-sensitive)
#     - {"$regex": "(?i)cat"}                                -> ids "1", "2", "3"
#     - {"$regex": "(?i)c(?-i)at"} with where test != 10     -> only id "2"
#
# rust/segment/src/sqlite_metadata.rs  (inside `impl IntoSqliteExpr for DocumentExpression`)
#   The full-text subquery's WHERE clause now depends on the operator:
#     - Contains / NotContains keep the existing `LIKE '%<pattern>%'` form, with any `%` removed
#       from the pattern first.
#     - Regex / NotRegex use the custom expression `? REGEXP ?`, filled with the string-value
#       column and the pattern passed as a value.
#   The outer expression then selects embedding ids with `IN (subquery)` for Contains / Regex and
#   `NOT IN (subquery)` for NotContains / NotRegex. This replaces the two `todo!()` arms.
#
# rust/sqlite/Cargo.toml
#   Enables the "regexp" feature on the `sqlx` dependency and drops a trailing blank line.
#
# rust/sqlite/src/config.rs  (inside `impl Configurable for SqliteDb`)
#   Chains `.with_regexp()` onto the connection options after `busy_timeout`.
#   NOTE(review): presumably this is what makes the `REGEXP` operator usable on these
#   connections - confirm against the sqlx documentation for the version pinned in the workspace.
#
# NOTE(review): the added test covers `$regex` only; no negated-regex query is exercised here.
diff --git a/Cargo.lock b/Cargo.lock index c026864907d..beb86b22404 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7563,6 +7563,7 @@ dependencies = [ "libsqlite3-sys", "log", "percent-encoding", + "regex", "serde", "serde_urlencoded", "sqlx-core", diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index 8fc80969c6b..fb8336ee4f1 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -6,6 +6,7 @@ from chromadb.test.property import invariants from chromadb.api.types import ( Document, + Documents, Embedding, Embeddings, GetResult, @@ -566,3 +567,33 @@ def test_query_ids_filter_property( # Also check that the number of results is reasonable assert len(result_ids) <= n_results assert len(result_ids) <= len(filter_ids_set) + + +def test_regex(client: ClientAPI) -> None: + """Tests that regex works""" + + reset(client) + coll = client.create_collection(name="test") + + test_ids: IDs = ["1", "2", "3"] + test_documents: Documents = ["cat", "Cat", "CAT"] + test_embeddings: Embeddings = [np.array([1, 1]), np.array([2, 2]), np.array([3, 3])] + test_metadatas: Metadatas = [{"test": 10}, {"test": 20}, {"test": 30}] + + coll.add( + ids=test_ids, + documents=test_documents, + embeddings=test_embeddings, + metadatas=test_metadatas, + ) + + res = coll.get(where_document={"$regex": "cat"}) + assert res["ids"] == ["1"] + + res = coll.get(where_document={"$regex": "(?i)cat"}) + assert sorted(res["ids"]) == ["1", "2", "3"] + + res = coll.get( + where={"test": {"$ne": 10}}, where_document={"$regex": "(?i)c(?-i)at"} # type: ignore[dict-item] + ) + assert res["ids"] == ["2"] diff --git a/rust/segment/src/sqlite_metadata.rs b/rust/segment/src/sqlite_metadata.rs index 527e8884e22..c5f630a0d9e 100644 --- a/rust/segment/src/sqlite_metadata.rs +++ b/rust/segment/src/sqlite_metadata.rs @@ -523,16 +523,27 @@ impl IntoSqliteExpr for DocumentExpression { let subq = Query::select() 
.column(EmbeddingFulltextSearch::Rowid) .from(EmbeddingFulltextSearch::Table) - .and_where( - Expr::col(EmbeddingFulltextSearch::StringValue) - .like(format!("%{}%", self.pattern.replace("%", ""))), - ) + .and_where(match self.operator { + DocumentOperator::Contains | DocumentOperator::NotContains => { + Expr::col(EmbeddingFulltextSearch::StringValue) + .like(format!("%{}%", self.pattern.replace("%", ""))) + } + DocumentOperator::Regex | DocumentOperator::NotRegex => Expr::cust_with_exprs( + "? REGEXP ?", + [ + Expr::col(EmbeddingFulltextSearch::StringValue).into(), + Expr::value(&self.pattern), + ], + ), + }) .to_owned(); match self.operator { - DocumentOperator::Contains => Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq), - DocumentOperator::NotContains => Expr::col((Embeddings::Table, Embeddings::Id)).not_in_subquery(subq), - DocumentOperator::Regex => todo!("Implement Regex matching. The result must be a not-nullable boolean (use `.is(true)`)"), - DocumentOperator::NotRegex => todo!("Implement negated Regex matching. 
This must be exact opposite of Regex matching (use `.not()`)"), + DocumentOperator::Contains | DocumentOperator::Regex => { + Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq) + } + DocumentOperator::NotContains | DocumentOperator::NotRegex => { + Expr::col((Embeddings::Table, Embeddings::Id)).not_in_subquery(subq) + } } } } diff --git a/rust/sqlite/Cargo.toml b/rust/sqlite/Cargo.toml index 6ca50a4fb6c..319d5f4a2ef 100644 --- a/rust/sqlite/Cargo.toml +++ b/rust/sqlite/Cargo.toml @@ -9,7 +9,7 @@ regex = { workspace = true } sea-query = { workspace = true, features = ["derive"] } sea-query-binder = { workspace = true, features = ["sqlx-sqlite"] } sha2 = { workspace = true } -sqlx = { workspace = true } +sqlx = { workspace = true, features = ["regexp"] } tempfile = { workspace = true } pyo3 = { workspace = true, optional = true } thiserror = { workspace = true } @@ -22,4 +22,3 @@ rust-embed = { workspace = true } chroma-error = { workspace = true, features = ["sqlx"] } chroma-types = { workspace = true } chroma-config = { workspace = true } - diff --git a/rust/sqlite/src/config.rs b/rust/sqlite/src/config.rs index abb8b142919..baaf753c236 100644 --- a/rust/sqlite/src/config.rs +++ b/rust/sqlite/src/config.rs @@ -104,7 +104,8 @@ impl Configurable for SqliteDb { // we turn it off .pragma("foreign_keys", "OFF") .pragma("case_sensitive_like", "ON") - .busy_timeout(Duration::from_secs(1000)); + .busy_timeout(Duration::from_secs(1000)) + .with_regexp(); let conn = if let Some(url) = &config.url { let path = Path::new(url); if let Some(parent) = path.parent() {