Skip to content

Commit

Permalink
fix (#873)
Browse files Browse the repository at this point in the history
  • Loading branch information
baitsguy authored Oct 4, 2024
1 parent 0a3ad99 commit 61b66a2
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 14 deletions.
32 changes: 22 additions & 10 deletions lib/sycamore/sycamore/tests/integration/transforms/test_rerank.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sycamore
from sycamore import ExecMode

from sycamore.data import Document
from sycamore.transforms.similarity import HuggingFaceTransformersSimilarityScorer
Expand All @@ -13,39 +14,47 @@ def test_rerank_docset():
{
"doc_id": 1,
"elements": [
{"text_representation": "here is an animal that meows"},
{"properties": {"_element_index": 1}, "text_representation": "here is an animal that meows"},
],
},
{
"doc_id": 2,
"elements": [
{"id": 7, "text_representation": "this is a cat"},
{"id": 1, "text_representation": "here is an animal that moos"},
{"id": 7, "properties": {"_element_index": 7}, "text_representation": "this is a cat"},
{"id": 1, "properties": {"_element_index": 1}, "text_representation": "here is an animal that moos"},
],
},
{
"doc_id": 3,
"elements": [
{"text_representation": "here is an animal that moos"},
{"properties": {"_element_index": 1}, "text_representation": "here is an animal that moos"},
],
},
{ # handle element with no text
"doc_id": 4,
"elements": [
{"id": 1},
{"id": 1, "properties": {"_element_index": 1}},
],
},
{
"doc_id": 5,
"elements": [
{
"properties": {"_element_index": 1},
"text_representation": "the number of pages in this document are 253",
}
],
},
{"doc_id": 5, "elements": [{"text_representation": "the number of pages in this document are 253"}]},
{ # drop because of limit
"doc_id": 6,
"elements": [
{"id": 1},
{"id": 1, "properties": {"_element_index": 1}},
],
},
]
docs = [Document(item) for item in dicts]

context = sycamore.init()
context = sycamore.init(exec_mode=ExecMode.LOCAL)
doc_set = context.read.document(docs).rerank(
similarity_scorer=similarity_scorer, query="is this a cat?", score_property_name=score_property_name, limit=5
)
Expand All @@ -70,14 +79,17 @@ def test_rerank_docset_exploded():
{
"doc_id": 4,
"elements": [
{"text_representation": "this doc doesn't have a text representation but instead has an element"}
{
"properties": {"_element_index": 1},
"text_representation": "this doc doesn't have a text representation but instead has an element",
}
],
},
{"doc_id": 5, "text_representation": "the number of pages in this document are 253"},
]
docs = [Document(item) for item in dicts]

context = sycamore.init()
context = sycamore.init(exec_mode=ExecMode.LOCAL)
doc_set = context.read.document(docs).rerank(
similarity_scorer=similarity_scorer, query="is this a cat?", score_property_name=score_property_name
)
Expand Down
8 changes: 4 additions & 4 deletions lib/sycamore/sycamore/tests/unit/transforms/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random

import sycamore
from sycamore import DocSet
from sycamore import DocSet, ExecMode
from sycamore.data import Document, MetadataDocument


Expand All @@ -24,9 +24,9 @@ def docs(self) -> list[Document]:
doc.properties.pop("even")
return doc_list

@pytest.fixture()
def docset(self, docs: list[Document]) -> DocSet:
context = sycamore.init()
@pytest.fixture(params=(exec_mode for exec_mode in ExecMode if exec_mode != ExecMode.UNKNOWN))
def docset(self, docs: list[Document], exec_mode) -> DocSet:
context = sycamore.init(exec_mode=exec_mode)
return context.read.document(docs)

def test_sort_descending(self, docset: DocSet):
Expand Down
3 changes: 3 additions & 0 deletions lib/sycamore/sycamore/transforms/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def execute(self, **kwargs) -> "Dataset":
dataset = self.child().execute()
return dataset.limit(self._limit)

def local_execute(self, all_docs: list[Document]) -> list[Document]:
    """Return the leading ``self._limit`` entries of the given document list.

    Args:
        all_docs: The documents handed to this node, as an in-memory list.

    Returns:
        A new list holding the prefix of ``all_docs`` selected by slicing
        up to ``self._limit``; the input list itself is not modified.
    """
    cutoff = self._limit
    return all_docs[:cutoff]


class Filter(MapBatch):
"""
Expand Down
14 changes: 14 additions & 0 deletions lib/sycamore/sycamore/transforms/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ def execute(self, **kwargs) -> "Dataset":
ds = ds.drop_columns(["key"])
return ds

def local_execute(self, all_docs: list[Document]) -> list[Document]:
    """Sort the given documents in memory by the configured field.

    Each document's key is ``doc.field_to_value(self._field)``. When that
    lookup yields ``None``, ``self._default_val`` is used as the key instead.
    The ordering is reversed when ``self._descending`` is truthy.

    Args:
        all_docs: The documents to order.

    Returns:
        A new list with the documents in sorted order.

    Raises:
        ValueError: If a document has no value for the field and
            ``self._default_val`` is also ``None``.
    """

    def key_for(document):
        value = document.field_to_value(self._field)
        fallback = self._default_val
        if value is None:
            # No usable value on this document: the fallback is mandatory.
            if fallback is None:
                raise ValueError("default_value cannot be None")
            return fallback
        return value

    return sorted(all_docs, key=key_for, reverse=self._descending)

def make_map_fn_sort(self):
def ray_callable(input_dict: dict[str, Any]) -> dict[str, Any]:
doc = Document.from_row(input_dict)
Expand Down

0 comments on commit 61b66a2

Please sign in to comment.