Initial implementation of manual ngram-based search in MongoDB #993

Open · wants to merge 5 commits into main
1 change: 1 addition & 0 deletions pydatalab/src/pydatalab/main.py
@@ -206,6 +206,7 @@ def create_app(
extension.init_app(app)

pydatalab.mongo.create_default_indices()
pydatalab.mongo.create_ngram_item_index()

if CONFIG.FILE_DIRECTORY is not None:
pathlib.Path(CONFIG.FILE_DIRECTORY).mkdir(parents=False, exist_ok=True)
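Calling the new helper at app startup means the whole trigram index is rebuilt from every item each time the server boots. For an existing deployment the same helper can also be run from a one-off script; a minimal sketch, assuming a local MongoDB instance at a placeholder URI:

```python
# Not part of the PR: backfill/rebuild the ngram index against an existing
# database. The connection string is an assumption for illustration only.
import pymongo

from pydatalab.mongo import create_ngram_item_index

client = pymongo.MongoClient("mongodb://localhost:27017/datalab")
# Regenerates the items_fts documents for every item and (re)creates the
# compound index on ("_fts_ngrams", "type").
create_ngram_item_index(client=client, background=True)
```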
93 changes: 93 additions & 0 deletions pydatalab/src/pydatalab/mongo.py
@@ -1,3 +1,4 @@
import collections
from typing import List, Optional

# Must be imported in this way to allow for easy patching with mongomock
@@ -12,6 +13,7 @@
"flask_mongo",
"check_mongo_connection",
"create_default_indices",
"create_ngram_item_index",
"_get_active_mongo_client",
"insert_pydantic_model_fork_safe",
"ITEMS_FTS_FIELDS",
@@ -193,3 +195,94 @@ def create_user_fts():
ret += create_user_fts()

return ret


def create_ngram_item_index(
client: Optional[pymongo.MongoClient] = None,
background: bool = False,
filter_top_ngrams: float | None = 0.5,
target_index_name: str = "ngram_fts_index",
):
    """Build a manual trigram index over the searchable fields of every item.

    For each item, the set of trigrams generated from its FTS fields is stored
    in a companion ``items_fts`` collection, which is then indexed on
    ``("_fts_ngrams", "type")`` under ``target_index_name``. The
    ``filter_top_ngrams`` threshold is intended to drop trigrams that appear in
    a large proportion of items, but that filtering step is currently disabled.
    """
    from bson import ObjectId

    if client is None:
        client = _get_active_mongo_client()
db = client.get_database()

# construct manual ngram index
ngram_index: dict[ObjectId, set[str]] = {}
type_index: dict[ObjectId, str] = {}
item_count: int = 0
global_ngram_count: dict[str, int] = collections.defaultdict(int)
for item in db.items.find({}):
item_count += 1
ngrams: dict[str, int] = _generate_item_ngrams(item, ITEMS_FTS_FIELDS)
ngram_index[item["_id"]] = set(ngrams)
type_index[item["_id"]] = item["type"]
for g in ngrams:
global_ngram_count[g] += ngrams[g]

    # filter out common ngrams that are found in more than `filter_top_ngrams`
    # proportion of entries (currently disabled):
    # if filter_top_ngrams is not None:
    #     for ngram, count in global_ngram_count.items():
    #         if count / item_count > filter_top_ngrams:
    #             for item in ngram_index:
    #                 ngram_index[item].discard(ngram)

for _id, _ngrams in ngram_index.items():
db.items_fts.update_one(
{"_id": _id},
{"$set": {"type": type_index[_id], "_fts_ngrams": list(_ngrams)}},
upsert=True,
)

try:
result = db.items_fts.create_index(
[("_fts_ngrams", pymongo.ASCENDING), ("type", pymongo.ASCENDING)],
name=target_index_name,
background=background,
)
    except pymongo.errors.OperationFailure:
        # the index already exists with different options: drop the stale
        # index from items_fts and recreate it
        db.items_fts.drop_index(target_index_name)
result = db.items_fts.create_index(
[("_fts_ngrams", pymongo.ASCENDING), ("type", pymongo.ASCENDING)],
name=target_index_name,
background=background,
)

return result


def _generate_ngrams(value: str, n: int = 3) -> dict[str, int]:
import re

ngrams: dict[str, int] = collections.defaultdict(int)

if not value or len(value) < n:
return ngrams

# first, tokenize by whitespace and punctuation (a la normal mongodb fts)
toks = re.split(r"[\s.,!?@#$%^&*()[\]{}\-_+=;:\'\"/<>]+", value)

# then loop over tokens and ngrammify
    for tok in toks:
        if len(tok) < n:
            continue
        for v in (tok[i : i + n].lower() for i in range(len(tok) - (n - 1))):
            ngrams[v] += 1

return ngrams


def _generate_item_ngrams(item: dict, fts_fields: set[str], n: int = 3) -> dict[str, int]:
ngrams: dict[str, int] = collections.defaultdict(int)
for field in fts_fields:
value = item.get(field, None)
if value:
if field == "refcode" and ":" in value:
value = value.split(":")[1]
            field_ngrams = _generate_ngrams(value, n)
for k in field_ngrams:
ngrams[k] += field_ngrams[k]

return ngrams
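For reference, the two helpers above return plain trigram-count dictionaries; a quick sketch of the expected output (field values invented for illustration):

```python
from pydatalab.mongo import _generate_item_ngrams, _generate_ngrams

# Tokenised on whitespace/punctuation, lower-cased, then split into trigrams
print(dict(_generate_ngrams("NaCoO2 battery")))
# {'nac': 1, 'aco': 1, 'coo': 1, 'oo2': 1,
#  'bat': 1, 'att': 1, 'tte': 1, 'ter': 1, 'ery': 1}

# Refcodes are stripped of everything before the ':' so that the instance
# prefix itself never enters the index
print(dict(_generate_item_ngrams({"refcode": "grey:TEST01"}, {"refcode"})))
# {'tes': 1, 'est': 1, 'st0': 1, 't01': 1}
```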
103 changes: 96 additions & 7 deletions pydatalab/src/pydatalab/routes/v0_1/items.py
@@ -6,6 +6,7 @@
from flask import Blueprint, jsonify, redirect, request
from flask_login import current_user
from pydantic import ValidationError
from pymongo import ReturnDocument
from pymongo.command_cursor import CommandCursor

from pydatalab.blocks import BLOCK_TYPES
@@ -15,7 +16,7 @@
from pydatalab.models.items import Item
from pydatalab.models.relationships import RelationshipType
from pydatalab.models.utils import generate_unique_refcode
from pydatalab.mongo import flask_mongo
from pydatalab.mongo import ITEMS_FTS_FIELDS, _generate_item_ngrams, flask_mongo
from pydatalab.permissions import PUBLIC_USER_ID, active_users_or_get_only, get_default_permissions

ITEMS = Blueprint("items", __name__)
@@ -283,6 +284,72 @@ def get_samples():
return jsonify({"status": "success", "samples": list(get_samples_summary())})


@ITEMS.route("/search-items-ngram/", methods=["GET"])
def search_items_ngram():
"""Perform n-gram-based free text search on items and return the top results.

GET parameters:
query: String with the search terms.
        nresults: Maximum number of results to return (default 100).
        types: If omitted, search all types of items. Otherwise, a
            comma-separated list of item types to include in the search
            (e.g., "samples,starting_materials").

    Returns:
        A list of dictionaries describing the matching items.
"""

query = request.args.get("query", type=str)
nresults = request.args.get("nresults", default=100, type=int)
types = request.args.get("types", default=None)
if isinstance(types, str):
# should figure out how to parse as list automatically
types = types.split(",")

    if not query:
        return jsonify({"status": "error", "message": "Must provide a search query."}), 400

    # split search string into trigrams
    query = query.lower()
    if len(query) < 3:
        # query too short to split into full trigrams; match the fragment directly
        trigrams = [query]
    else:
        trigrams = [query[i : i + 3] for i in range(len(query) - 2)]

match_obj = {
"_fts_ngrams": {"$in": trigrams},
**get_default_permissions(user_only=False),
}

if types is not None:
match_obj["type"] = {"$in": types}

cursor = flask_mongo.db.items_fts.aggregate(
[
{"$match": match_obj},
{"$limit": nresults},
{
"$lookup": {
"from": "items",
"localField": "_id",
"foreignField": "_id",
"as": "items",
}
},
{"$unwind": "$items"},
{"$replaceRoot": {"newRoot": {"$mergeObjects": ["$items"]}}},
{
"$project": {
"_id": 0,
"type": 1,
"item_id": 1,
"name": 1,
"chemform": 1,
"refcode": 1,
}
},
]
)

return jsonify({"status": "success", "items": list(cursor)}), 200


@ITEMS.route("/search-items/", methods=["GET"])
def search_items():
"""Perform free text search on items and return the top results.
@@ -511,6 +578,16 @@ def _create_sample(
400,
)

# Update ngram index, if configured
ngrams = _generate_item_ngrams(
flask_mongo.db.items.find_one(result.inserted_id), ITEMS_FTS_FIELDS
)
flask_mongo.db.items_fts.update_one(
{"_id": result.inserted_id},
{"$set": {"type": data_model.type, "_fts_ngrams": list(ngrams)}},
upsert=True,
)

sample_list_entry = {
"refcode": data_model.refcode,
"item_id": data_model.item_id,
@@ -608,11 +685,11 @@ def delete_sample():
request_json = request.get_json() # noqa: F821 pylint: disable=undefined-variable
item_id = request_json["item_id"]

-    result = flask_mongo.db.items.delete_one(
-        {"item_id": item_id, **get_default_permissions(user_only=True)}
+    deleted_doc = flask_mongo.db.items.find_one_and_delete(
+        {"item_id": item_id, **get_default_permissions(user_only=True)}, projection={"_id": 1}
    )

-    if result.deleted_count != 1:
+    if deleted_doc is None:
return (
jsonify(
{
),
401,
)

# Update ngram index, if configured
flask_mongo.db.items_fts.delete_one({"_id": deleted_doc["_id"]})

return (
jsonify(
{
@@ -870,21 +951,29 @@ def save_item():
item.pop("collections")
item.pop("creators")

-    result = flask_mongo.db.items.update_one(
+    updated_doc = flask_mongo.db.items.find_one_and_update(
        {"item_id": item_id},
        {"$set": item},
+        return_document=ReturnDocument.AFTER,
    )

-    if result.matched_count != 1:
+    if updated_doc is None:
        return (
            jsonify(
                status="error",
                message=f"{item_id} item update failed. no subdocument matched",
-                output=result.raw_result,
),
400,
)

# Update ngram index, if configured
ngrams = _generate_item_ngrams(updated_doc, ITEMS_FTS_FIELDS)
flask_mongo.db.items_fts.update_one(
{"_id": updated_doc["_id"]},
{"$set": {"type": updated_doc["type"], "_fts_ngrams": list(ngrams)}},
upsert=True,
)

return jsonify(status="success", last_modified=updated_data["last_modified"]), 200


105 changes: 105 additions & 0 deletions pydatalab/tests/server/test_ngram_fts.py
@@ -0,0 +1,105 @@
"""This module tests fundamental routines around
the n-gram FTS.

"""

from pydatalab.mongo import _generate_item_ngrams, _generate_ngrams, create_ngram_item_index


def test_ngram_single_field():
field = "ABCDEF"
ngrams = _generate_ngrams(field, 3)
expected = ["abc", "bcd", "cde", "def"]
assert list(ngrams) == expected
assert all([ngrams[e] == 1 for e in expected])

field = "ABC"
ngrams = _generate_ngrams(field, 3)
assert ngrams == {"abc": 1}

field = "some: punctuation"
ngrams = _generate_ngrams(field, 3)
expected = ["som", "ome", "pun", "unc", "nct", "ctu", "tua", "uat", "ati", "tio", "ion"]
assert list(ngrams) == expected

field = "What about a full sentence? or: even, two?"
ngrams = _generate_ngrams(field, 3)
expected = [
"wha",
"hat",
"abo",
"bou",
"out",
"ful",
"ull",
"sen",
"ent",
"nte",
"ten",
"enc",
"nce",
"eve",
"ven",
"two",
]
assert list(ngrams) == expected
assert all([ngrams[e] == 1 for e in expected])


def test_ngram_item():
item = {"refcode": "ABCDEF"}
assert _generate_item_ngrams(item, {"refcode"}, n=3) == {"abc": 1, "bcd": 1, "cde": 1, "def": 1}


def test_ngram_fts_route(client, default_sample_dict, real_mongo_client, database):
default_sample_dict["item_id"] = "ABCDEF"
response = client.post("/new-sample/", json=default_sample_dict)
assert response.status_code == 201

# Check that creating the ngram index with existing items works
create_ngram_item_index(real_mongo_client, background=False, filter_top_ngrams=None)

doc = database.items_fts.find_one({})
ngrams = set(doc["_fts_ngrams"])
for ng in ["abc", "bcd", "cde", "def", "sam", "ple"]:
assert ng in ngrams
assert doc["type"] == "samples"

query_strings = ("ABC", "ABCDEF", "abcd", "cdef")

for q in query_strings:
response = client.get(f"/search-items-ngram/?query={q}&types=samples")
assert response.status_code == 200
assert response.json["status"] == "success"
assert len(response.json["items"]) == 1
assert response.json["items"][0]["item_id"] == "ABCDEF"

# Check that new items are added to the ngram index
default_sample_dict["item_id"] = "ABCDEF2"
response = client.post("/new-sample/", json=default_sample_dict)
assert response.status_code == 201

for q in query_strings:
response = client.get(f"/search-items-ngram/?query={q}&types=samples")
assert response.status_code == 200
assert response.json["status"] == "success"
assert len(response.json["items"]) == 2
assert response.json["items"][0]["item_id"] == "ABCDEF"
assert response.json["items"][1]["item_id"] == "ABCDEF2"

# Check that updates are reflected in the ngram index
# This test also makes sure that the string 'test' is not picked up from the refcode,
# which has an explicit carve out
default_sample_dict["description"] = "test string with punctuation"
update_req = {"item_id": "ABCDEF2", "data": default_sample_dict}
response = client.post("/save-item/", json=update_req)
assert response.status_code == 200

query_strings = ("test", "punctuation")

for q in query_strings:
response = client.get(f"/search-items-ngram/?query={q}&types=samples")
assert response.status_code == 200
assert response.json["status"] == "success"
assert len(response.json["items"]) == 1
assert response.json["items"][0]["item_id"] == "ABCDEF2"