Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get llm_filter to support document structure + similarity sorting for elements #876

Merged
merged 5 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class OpenSearchReaderQueryResponse(BaseDBReader.QueryResponse):

def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
assert isinstance(query_params, OpenSearchReaderQueryParams)
result = []
result: list[Document] = []
if not query_params.reconstruct_document:
for data in self.output:
doc = Document(
Expand Down Expand Up @@ -141,6 +141,10 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:

result = list(unique_docs.values())

# sort elements per doc
for doc in result:
doc.elements.sort(key=lambda e: e.element_index if e.element_index is not None else float("inf"))
baitsguy marked this conversation as resolved.
Show resolved Hide resolved

return result

def _get_all_elements_for_doc_ids(self, doc_ids: list[str], index: str) -> list[typing.Any]:
Expand Down
62 changes: 47 additions & 15 deletions lib/sycamore/sycamore/docset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pprint
import sys
from typing import Callable, Optional, Any, Iterable, Type, Union, TYPE_CHECKING
import re

from sycamore.context import Context, context_params, OperationTypes
from sycamore.data import Document, Element, MetadataDocument
Expand Down Expand Up @@ -956,13 +957,18 @@ def process_doc(doc: Document) -> Document:
return self.map(process_doc, **resource_args)

@context_params(OperationTypes.BINARY_CLASSIFIER)
@context_params(OperationTypes.TEXT_SIMILARITY)
def llm_filter(
self,
llm: LLM,
new_field: str,
prompt: Union[list[dict], str],
field: str = "text_representation",
threshold: int = 3,
keep_none: bool = False,
use_elements: bool = False,
similarity_query: Optional[str] = None,
similarity_scorer: Optional[SimilarityScorer] = None,
**resource_args,
) -> "DocSet":
"""
Expand All @@ -974,29 +980,55 @@ def llm_filter(
new_field: The field that will be added to the DocSet with the outputs.
prompt: LLM prompt.
field: Document field to filter based on.
threshold: Cutoff that determines whether or not to keep document.
threshold: If the value of the computed result is an integer value greater than or equal to this threshold,
the document will be kept.
keep_none: keep records with a None value for the provided field to filter on.
Warning: using this might hide data corruption issues.
use_elements: use contents of a document's elements to filter as opposed to document level contents.
similarity_query: query string to compute similarity against. Also requires a 'similarity_scorer'.
similarity_scorer: scorer used to generate similarity scores used in element sorting.
Also requires a 'similarity_query'.
**resource_args

Returns:
A filtered DocSet.
"""

def threshold_filter(doc: Document, threshold) -> bool:
try:
return_value = int(doc.properties[new_field]) >= threshold
except Exception:
# accounts for llm output errors
return_value = False

return return_value

docset = self.filter(lambda doc: doc.field_to_value(field) is not None and doc.field_to_value(field) != "None")

entity_extractor = OpenAIEntityExtractor(
entity_name=new_field, llm=llm, use_elements=False, prompt=prompt, field=field
)
docset = docset.extract_entity(entity_extractor=entity_extractor)
docset = docset.filter(lambda doc: threshold_filter(doc, threshold), **resource_args)

def threshold_filter(doc: Document, threshold) -> bool:
if not use_elements:
if doc.field_to_value(field) is None:
return keep_none
doc = entity_extractor.extract_entity(doc)
# todo: move data extraction and validation to entity extractor
return int(re.findall(r"\d+", doc.properties[new_field])[0]) >= threshold
baitsguy marked this conversation as resolved.
Show resolved Hide resolved

if similarity_query or similarity_scorer:
assert similarity_scorer is not None, "Similarity sorting requires a scorer"
assert similarity_query is not None, "Similarity sorting requires a string query"
score_property_name = f"{field}_similarity_score"
doc = similarity_scorer.generate_similarity_scores(
doc_batch=[doc], query=similarity_query, score_property_name=score_property_name
)[0]
doc.elements.sort(key=lambda e: e.properties.get(score_property_name, float("-inf")), reverse=True)
evaluated_elements = 0
for element in doc.elements:
e_doc = Document(element.data)
baitsguy marked this conversation as resolved.
Show resolved Hide resolved
if e_doc.field_to_value(field) is None:
continue
e_doc = entity_extractor.extract_entity(e_doc)
element.properties[new_field] = e_doc.properties[new_field]
# todo: move data extraction and validation to entity extractor
if int(re.findall(r"\d+", element.properties[new_field])[0]) >= threshold:
baitsguy marked this conversation as resolved.
Show resolved Hide resolved
return True
evaluated_elements += 1
if evaluated_elements == 0: # no elements found for property
return keep_none
return False

docset = self.filter(lambda doc: threshold_filter(doc, threshold), **resource_args)

return docset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,6 @@ def test_ingest_and_read(self, setup_index, exec_mode):
assert len(retrieved_materialized_reconstructed) == 1
doc = retrieved_materialized_reconstructed[0]
assert len(doc.elements) == len(retrieved_materialized) - 1 # drop the document parent record

for i in range(len(doc.elements) - 1):
assert doc.elements[i].element_index < doc.elements[i + 1].element_index
162 changes: 156 additions & 6 deletions lib/sycamore/sycamore/tests/unit/test_docset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import random
import string
from typing import Callable, Optional
from unittest.mock import MagicMock

import pytest

import sycamore
from sycamore import DocSet, Context
from sycamore.context import OperationTypes
from sycamore.context import OperationTypes, ExecMode
from sycamore.data import Document, Element
from sycamore.llms.prompts.default_prompts import (
LlmClusterEntityAssignGroupsMessagesPrompt,
Expand Down Expand Up @@ -45,9 +46,9 @@ def __init__(self):

def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None):
if prompt_kwargs == {"messages": [{"role": "user", "content": "test1"}]} and llm_kwargs == {}:
return 4
return "4"
elif prompt_kwargs == {"messages": [{"role": "user", "content": "test2"}]} and llm_kwargs == {}:
return 2
return "2"

elif (
prompt_kwargs["messages"]
Expand Down Expand Up @@ -76,6 +77,15 @@ def is_chat_mode(self):
return True


class TestSimilarityScorer(SimilarityScorer):
    """Deterministic stub scorer for unit tests.

    Scores each (query, content) pair purely on the element content:
    exactly "test2" scores 1.0, anything else 0.0. The query half of
    each pair is ignored.
    """

    def score(self, inputs: list[tuple[str, str]]) -> list[float]:
        # Only the content (second tuple item) drives the score.
        return [1.0 if content == "test2" else 0.0 for _, content in inputs]


class TestDocSet:
@pytest.fixture
def number_docset(self) -> DocSet:
Expand Down Expand Up @@ -371,10 +381,137 @@ def test_count_distinct(self):
docset = context.read.document(docs)
assert docset.count_distinct("doc_id") == 9

def test_llm_filter(self):
def test_llm_filter_with_doc_structure(self):
    """llm_filter with use_elements=True keeps a document as soon as any
    element's extracted value meets the threshold, and stops evaluating
    further elements of that document.

    MockLLM scores text "test1" as 4 and "test2" as 2.
    """

    def make_doc(doc_id: str, texts: list[str]) -> Document:
        # Build a document whose elements carry only a text_representation.
        return Document(doc_id=doc_id, elements=[Element(text_representation=t) for t in texts])

    docs = [
        make_doc("doc_1", ["test1", "test1"]),  # both elements score 4
        make_doc("doc_2", ["test2", "test1"]),  # scores 2 then 4
        make_doc("doc_3", ["test2"]),  # single element scores 2
        Document(doc_id="doc_4", text_representation="empty elements, maybe an exploded doc", elements=[]),
    ]
    llm = MockLLM()
    llm.generate = MagicMock(wraps=llm.generate)
    ctx = sycamore.init(params={OperationTypes.BINARY_CLASSIFIER: {"llm": llm}}, exec_mode=ExecMode.LOCAL)
    docset = ctx.read.document(docs)
    new_field = "_autogen_LLMFilterOutput"

    # Threshold 4: doc_1 passes on its first element (1 call), doc_2 needs
    # both elements (2 calls), doc_3 fails its only element (1 call) -> 4 calls.
    taken = docset.llm_filter(
        new_field=new_field, prompt=[], field="text_representation", threshold=4, use_elements=True
    ).take()
    assert [d.doc_id for d in taken] == ["doc_1", "doc_2"]
    assert llm.generate.call_count == 4

    # Threshold 2: the first element of every non-empty doc already passes,
    # so exactly one extra call per doc -> 3 more calls.
    taken = docset.llm_filter(
        new_field=new_field, prompt=[], field="text_representation", threshold=2, use_elements=True
    ).take()
    assert llm.generate.call_count == (4 + 3)
    assert [d.doc_id for d in taken] == ["doc_1", "doc_2", "doc_3"]

def test_llm_filter_with_doc_structure_with_similarity_sorting(self):
    """Like test_llm_filter_with_doc_structure, but a similarity scorer
    re-orders each document's elements before LLM evaluation.

    TestSimilarityScorer ranks "test2" elements (MockLLM score 2) ahead of
    "test1" elements (MockLLM score 4), changing how many LLM calls each
    threshold requires.
    """

    def make_doc(doc_id: str, indexed_texts: list[tuple[int, str]]) -> Document:
        # Elements carry an explicit _element_index alongside their text.
        return Document(
            doc_id=doc_id,
            elements=[
                Element(properties={"_element_index": idx}, text_representation=text)
                for idx, text in indexed_texts
            ],
        )

    docs = [
        make_doc("doc_1", [(1, "test1"), (2, "test1")]),  # both score 4
        make_doc("doc_2", [(4, "test2"), (9, "test1")]),  # scores 2 and 4
        make_doc("doc_3", [(1, "test2")]),  # single element scores 2
        Document(doc_id="doc_4", text_representation="empty elements, maybe an exploded doc", elements=[]),
    ]
    llm = MockLLM()
    scorer = TestSimilarityScorer()
    llm.generate = MagicMock(wraps=llm.generate)
    ctx = sycamore.init(
        params={
            OperationTypes.BINARY_CLASSIFIER: {"llm": llm},
            OperationTypes.TEXT_SIMILARITY: {"similarity_scorer": scorer},
        },
        exec_mode=ExecMode.LOCAL,
    )
    docset = ctx.read.document(docs)
    new_field = "_autogen_LLMFilterOutput"

    # "test2" elements sort to the front, so doc_2 misses threshold 4 on its
    # first element and needs 2 calls; doc_1 and doc_3 need 1 call each -> 4.
    taken = docset.llm_filter(
        new_field=new_field,
        prompt=[],
        field="text_representation",
        threshold=4,
        use_elements=True,
        similarity_scorer=scorer,
        similarity_query="this is an unused query because unit test",
    ).take()
    assert [d.doc_id for d in taken] == ["doc_1", "doc_2"]
    assert llm.generate.call_count == 4

    # Threshold 2: the leading element of every non-empty doc passes -> one
    # call per doc, 3 more calls total.
    taken = docset.llm_filter(
        new_field=new_field,
        prompt=[],
        field="text_representation",
        threshold=2,
        use_elements=True,
        similarity_scorer=scorer,
        similarity_query="this is an unused query because unit test",
    ).take()
    assert llm.generate.call_count == (4 + 3)
    assert [d.doc_id for d in taken] == ["doc_1", "doc_2", "doc_3"]

def test_llm_filter(self):
doc_list = [Document(text_representation="test1"), Document(text_representation="test2")]
context = sycamore.init(params={OperationTypes.BINARY_CLASSIFIER: {"llm": MockLLM()}})
context = sycamore.init(params={OperationTypes.BINARY_CLASSIFIER: {"llm": MockLLM()}}, exec_mode=ExecMode.LOCAL)
docset = context.read.document(doc_list)
new_field = "_autogen_LLMFilterOutput"

Expand All @@ -395,8 +532,21 @@ def test_llm_filter(self):
elif doc.text_representation == "test2":
assert int(doc.properties[new_field]) == 2

def test_groupby_count(self, fruits_docset):
def test_llm_filter_with_keep_none(self):
    """With keep_none=True, documents whose filter field is absent (resolves
    to None) are retained instead of being dropped."""
    docs = [Document(text_representation="test1"), Document(text_representation="test2")]
    ctx = sycamore.init(params={OperationTypes.BINARY_CLASSIFIER: {"llm": MockLLM()}}, exec_mode=ExecMode.LOCAL)
    new_field = "_autogen_LLMFilterOutput"

    # "missing_field" is not set on any document, so every doc hits the
    # keep_none path and survives the filter regardless of threshold.
    kept = (
        ctx.read.document(docs)
        .llm_filter(new_field=new_field, prompt=[], field="missing_field", threshold=5, keep_none=True)
        .take()
    )

    assert [d.text_representation for d in kept] == ["test1", "test2"]

def test_groupby_count(self, fruits_docset):
grouped_docset = fruits_docset.groupby_count(field="text_representation")
assert grouped_docset.count() == 3
for doc in grouped_docset.take():
Expand Down
Loading