@context_params(OperationTypes.BINARY_CLASSIFIER)
@context_params(OperationTypes.TEXT_SIMILARITY)
def llm_filter(
    self,
    llm: LLM,
    new_field: str,
    prompt: Union[list[dict], str],
    field: str = "text_representation",
    threshold: int = 3,
    keep_none: bool = False,
    use_elements: bool = False,
    similarity_query: Optional[str] = None,
    similarity_scorer: Optional[SimilarityScorer] = None,
    **resource_args,
) -> "DocSet":
    """
    Filter the DocSet, keeping documents whose LLM-produced score is at least ``threshold``.

    An ``OpenAIEntityExtractor`` built from ``llm``, ``prompt`` and ``field`` is run per document
    (or per element when ``use_elements`` is set). The first run of digits in its output is
    interpreted as the score.

    Args:
        llm: LLM handed to the entity extractor.
        new_field: Property name under which the extractor output is stored.
        prompt: LLM prompt.
        field: Document (or element) field to filter based on.
        threshold: A document is kept when its parsed integer score is greater than or equal
            to this value. Output with no digits in it counts as "below threshold".
        keep_none: Result to return for records that have a None value for ``field``
            (document mode), or for documents where no element has a value for ``field``
            (element mode).
            Warning: using this might hide data corruption issues.
        use_elements: Evaluate the document's elements one by one instead of the document
            itself; the document is kept as soon as one element reaches ``threshold``.
        similarity_query: Query string to compute similarity against. Also requires a
            'similarity_scorer'. Only consulted when ``use_elements`` is set.
        similarity_scorer: Scorer used to generate the similarity scores by which elements are
            sorted (highest first) before evaluation. Also requires a 'similarity_query'.
        **resource_args: Forwarded to ``self.filter``.

    Returns:
        A filtered DocSet.

    Raises:
        ValueError: While filtering in element mode, if only one of ``similarity_query`` /
            ``similarity_scorer`` is available.
    """
    entity_extractor = OpenAIEntityExtractor(
        entity_name=new_field, llm=llm, use_elements=False, prompt=prompt, field=field
    )

    def parse_score(raw: Any) -> Optional[int]:
        # todo: move data extraction and validation to entity extractor
        # Returns None rather than raising when the LLM output is missing or has no digits,
        # so one malformed answer does not abort the whole filter.
        if raw is None:
            return None
        found = re.search(r"\d+", str(raw))
        return int(found.group()) if found else None

    def meets_threshold(raw: Any, threshold) -> bool:
        score = parse_score(raw)
        return score is not None and score >= threshold

    def threshold_filter(doc: Document, threshold) -> bool:
        if not use_elements:
            if doc.field_to_value(field) is None:
                return keep_none
            doc = entity_extractor.extract_entity(doc)
            return meets_threshold(doc.properties.get(new_field), threshold)

        if similarity_query or similarity_scorer:
            # Explicit raises instead of asserts: asserts disappear under `python -O`.
            if similarity_scorer is None:
                raise ValueError("Similarity sorting requires a scorer")
            if similarity_query is None:
                raise ValueError("Similarity sorting requires a string query")
            score_property_name = f"{field}_similarity_score"
            doc = similarity_scorer.generate_similarity_scores(
                doc_batch=[doc], query=similarity_query, score_property_name=score_property_name
            )[0]
            # Most similar elements first; elements without a score go last.
            doc.elements.sort(key=lambda e: e.properties.get(score_property_name, float("-inf")), reverse=True)

        evaluated_elements = 0
        for element in doc.elements:
            e_doc = Document(element.data)
            if e_doc.field_to_value(field) is None:
                continue
            e_doc = entity_extractor.extract_entity(e_doc)
            element.properties[new_field] = e_doc.properties.get(new_field)
            if meets_threshold(element.properties[new_field], threshold):
                # Short-circuit: later elements are never sent to the LLM.
                return True
            evaluated_elements += 1
        if evaluated_elements == 0:  # no elements found for property
            return keep_none
        return False

    docset = self.filter(lambda doc: threshold_filter(doc, threshold), **resource_args)

    return docset
class TestSimilarityScorer(SimilarityScorer):
    """Deterministic scorer stub: a pair scores 1.0 when its second item is "test2", else 0.0."""

    # The "Test" prefix would otherwise make pytest consider this helper for collection.
    __test__ = False

    def score(self, inputs: list[tuple[str, str]]) -> list[float]:
        """Return one score per input pair, in input order."""
        return [1.0 if content == "test2" else 0.0 for _, content in inputs]
def test_llm_filter_with_doc_structure_with_similarity_sorting(self):
    """Element-level llm_filter with similarity sorting: check kept docs and LLM call counts."""

    def make_element(index: int, text: str) -> Element:
        return Element(properties={"_element_index": index}, text_representation=text)

    # llm_filter result per element text: "test1" -> 4, "test2" -> 2
    doc_list = [
        Document(doc_id="doc_1", elements=[make_element(1, "test1"), make_element(2, "test1")]),
        Document(doc_id="doc_2", elements=[make_element(4, "test2"), make_element(9, "test1")]),
        Document(doc_id="doc_3", elements=[make_element(1, "test2")]),
        Document(doc_id="doc_4", text_representation="empty elements, maybe an exploded doc", elements=[]),
    ]

    similarity_scorer = TestSimilarityScorer()
    mock_llm = MockLLM()
    mock_llm.generate = MagicMock(wraps=mock_llm.generate)
    context = sycamore.init(
        params={
            OperationTypes.BINARY_CLASSIFIER: {"llm": mock_llm},
            OperationTypes.TEXT_SIMILARITY: {"similarity_scorer": similarity_scorer},
        },
        exec_mode=ExecMode.LOCAL,
    )
    docset = context.read.document(doc_list)
    new_field = "_autogen_LLMFilterOutput"

    # Arguments common to both filter runs; only the threshold differs.
    shared_kwargs = {
        "new_field": new_field,
        "prompt": [],
        "field": "text_representation",
        "use_elements": True,
        "similarity_scorer": similarity_scorer,
        "similarity_query": "this is an unused query because unit test",
    }

    # "test2" elements will be in front, resulting in 2 llm calls for doc_2
    # (first element threshold miss), 1 each for other 2.
    taken = docset.llm_filter(threshold=4, **shared_kwargs).take()
    assert len(taken) == 2
    assert taken[0].doc_id == "doc_1"
    assert taken[1].doc_id == "doc_2"
    assert mock_llm.generate.call_count == 4

    # "test2" elements will be in front, resulting in 1 llm call for doc_2
    # (threshold matches), 1 for other 2.
    taken = docset.llm_filter(threshold=2, **shared_kwargs).take()
    assert mock_llm.generate.call_count == (4 + 3)
    assert len(taken) == 3
    assert taken[0].doc_id == "doc_1"
    assert taken[1].doc_id == "doc_2"
    assert taken[2].doc_id == "doc_3"
def test_llm_filter_with_keep_none(self):
    """With keep_none=True, documents lacking the filtered field are all retained."""
    expected_texts = ["test1", "test2"]
    doc_list = [Document(text_representation=text) for text in expected_texts]
    context = sycamore.init(params={OperationTypes.BINARY_CLASSIFIER: {"llm": MockLLM()}}, exec_mode=ExecMode.LOCAL)
    docset = context.read.document(doc_list)
    new_field = "_autogen_LLMFilterOutput"

    # "missing_field" is absent on every document, so only keep_none decides the outcome.
    kept_docs = docset.llm_filter(
        new_field=new_field, prompt=[], field="missing_field", threshold=5, keep_none=True
    ).take()

    assert len(kept_docs) == 2
    assert [kept.text_representation for kept in kept_docs] == expected_texts