Skip to content

Commit

Permalink
Element ordering and test improvements (#872)
Browse files Browse the repository at this point in the history
* Improvements

* unit test fixes

* pr comments

* pr comments
  • Loading branch information
baitsguy authored Oct 3, 2024
1 parent 4093134 commit 0a3ad99
Show file tree
Hide file tree
Showing 25 changed files with 195 additions and 104 deletions.
1 change: 1 addition & 0 deletions lib/sycamore/sycamore/data/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def deserialize(raw: bytes) -> "Document":
@staticmethod
def from_row(row: dict[str, bytes]) -> "Document":
    """Rebuild a Document from a Ray row by deserializing its ``doc`` entry."""
    serialized = row["doc"]
    return Document.deserialize(serialized)

def to_row(self) -> dict[str, bytes]:
Expand Down
22 changes: 13 additions & 9 deletions lib/sycamore/sycamore/data/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def __init__(self, element=None, /, **kwargs):
self.data["properties"] = {}

@property
def seq_no(self) -> Optional[int]:
def element_index(self) -> Optional[int]:
"""A unique identifier for the element within a Document. Represents an order within the document"""
return self.data.get("seq_no")
return self.data.get("properties", {}).get("_element_index")

@seq_no.setter
def seq_no(self, value: int) -> None:
@element_index.setter
def element_index(self, value: int) -> None:
"""Set the unique identifier of the element within a Document."""
self.data["seq_no"] = value
self.data["properties"]["_element_index"] = value

@property
def type(self) -> Optional[str]:
Expand Down Expand Up @@ -222,7 +222,8 @@ def text_representation(self, text_representation: str) -> None:
self.data["text_representation"] = text_representation


def create_element(**kwargs) -> Element:
def create_element(element_index: Optional[int] = None, **kwargs) -> Element:
element: Element
if "type" in kwargs and kwargs["type"].lower() == "table":
if "properties" in kwargs:
props = kwargs["properties"]
Expand All @@ -233,7 +234,7 @@ def create_element(**kwargs) -> Element:
table = Table.from_dict(kwargs["table"])
kwargs["table"] = table

return TableElement(**kwargs)
element = TableElement(**kwargs)

elif "type" in kwargs and kwargs["type"].lower() in {"picture", "image", "figure"}:
if "properties" in kwargs:
Expand All @@ -242,7 +243,10 @@ def create_element(**kwargs) -> Element:
kwargs["image_mode"] = props.get("image_mode")
kwargs["image_format"] = props.get("image_format")

return ImageElement(**kwargs)
element = ImageElement(**kwargs)

else:
return Element(**kwargs)
element = Element(**kwargs)
if element_index is not None:
element.element_index = element_index
return element
8 changes: 6 additions & 2 deletions lib/sycamore/sycamore/functions/document.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from io import BytesIO
from typing import Optional

from sycamore.data.document import DocumentPropertyTypes
from sycamore.data.element import TableElement

import pdf2image
Expand Down Expand Up @@ -46,15 +48,17 @@ def split_and_convert_to_image(doc: Document) -> list[Document]:
elements_by_page: dict[int, list[Element]] = {}

for e in doc.elements:
page_number = e.properties["page_number"]
page_number = e.properties[DocumentPropertyTypes.PAGE_NUMBER]
elements_by_page.setdefault(page_number, []).append(e)

new_docs = []
for page, image in enumerate(images):
elements = elements_by_page.get(page + 1, [])
new_doc = Document(binary_representation=image.tobytes(), elements=elements)
new_doc.properties.update(doc.properties)
new_doc.properties.update({"size": list(image.size), "mode": image.mode, "page_number": page + 1})
new_doc.properties.update(
{"size": list(image.size), "mode": image.mode, DocumentPropertyTypes.PAGE_NUMBER: page + 1}
)
new_docs.append(new_doc)
return new_docs

Expand Down
17 changes: 17 additions & 0 deletions lib/sycamore/sycamore/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from pyarrow.fs import LocalFileSystem

from sycamore import ExecMode
from sycamore.data.document import Document


Expand All @@ -13,3 +14,19 @@ def read_local_binary(request) -> Document:
document.binary_representation = input_stream.readall()
document.properties["path"] = path
return document


@pytest.fixture(params=[mode for mode in ExecMode if mode != ExecMode.UNKNOWN])
def exec_mode(request):
    """
    Parametrized fixture that supplies every execution mode except ``ExecMode.UNKNOWN``.

    A test requesting this fixture runs once per mode. The value must be passed on to the
    Context initialization by the test itself.

    Example:
         .. code-block:: python

            def test_example(exec_mode):
                context = sycamore.init(exec_mode=exec_mode)
                ...
    """
    selected_mode = request.param
    return selected_mode
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

from opensearchpy import OpenSearch
import sycamore
from sycamore import ExecMode
from sycamore.connectors.file.file_scan import JsonManifestMetadataProvider
from sycamore.tests.config import TEST_DIR
from sycamore.transforms.embed import SentenceTransformerEmbedder
from sycamore.transforms.partition import HtmlPartitioner


def test_html_to_opensearch():
def test_html_to_opensearch(exec_mode):
os_client_args = {
"hosts": [{"host": "localhost", "port": 9200}],
"http_compress": True,
Expand Down Expand Up @@ -55,7 +54,7 @@ def test_html_to_opensearch():
tmp_manifest.flush()
manifest_path = tmp_manifest.name

context = sycamore.init(exec_mode=ExecMode.LOCAL)
context = sycamore.init(exec_mode=exec_mode)
ds = (
context.read.binary(
base_path, binary_format="html", metadata_provider=JsonManifestMetadataProvider(manifest_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from opensearchpy import OpenSearch

import sycamore
from sycamore import ExecMode
from sycamore.connectors.common import compare_docs
from sycamore.tests.config import TEST_DIR
from sycamore.transforms.partition import UnstructuredPdfPartitioner
Expand Down Expand Up @@ -54,13 +53,13 @@ class TestOpenSearchRead:
"timeout": 120,
}

def test_ingest_and_read(self, setup_index):
def test_ingest_and_read(self, setup_index, exec_mode):
"""
Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents.
"""

path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf")
context = sycamore.init(exec_mode=ExecMode.LOCAL)
context = sycamore.init(exec_mode=exec_mode)
original_docs = (
context.read.binary(path, binary_format="pdf")
.partition(partitioner=UnstructuredPdfPartitioner())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
ARYN_API_KEY = os.environ["ARYN_API_KEY"]


def test_detr_ocr():
def test_aryn_partitioner_w_ocr():
path = TEST_DIR / "resources/data/pdfs/Transformer.pdf"

context = sycamore.init()
Expand All @@ -23,14 +23,17 @@ def test_detr_ocr():
# The test will need to be updated if and when that changes.
docs = (
context.read.binary(paths=[str(path)], binary_format="pdf")
.partition(ArynPartitioner(use_ocr=True))
.partition(ArynPartitioner(aryn_api_key=ARYN_API_KEY, use_ocr=True))
.explode()
.filter(lambda doc: "page_number" in doc.properties and doc.properties["page_number"] == 1)
.filter(lambda doc: doc.type in {"Section-header", "Title"})
.take_all()
)

assert "Attention Is All You Need" in set(str(d.text_representation).strip() for d in docs)
assert all(
docs[i].properties["_element_index"] < docs[i + 1].properties["_element_index"] for i in range(len(docs) - 1)
)


def check_table_extraction(**kwargs):
Expand Down Expand Up @@ -81,6 +84,7 @@ def check_table_extraction(**kwargs):
assert len(docs) == 1
doc = docs[0]
tables = [e for e in doc.elements if e.type == "table"]
assert all(tables[i].element_index < tables[i + 1].element_index for i in range(len(tables) - 1))
assert len(tables) == 1
assert isinstance(tables[0], TableElement)
assert tables[0].table is not None
Expand Down Expand Up @@ -128,6 +132,9 @@ def test_aryn_partitioner():
)

assert "Attention Is All You Need" in set(str(d.text_representation).strip() for d in docs)
assert all(
docs[i].properties["_element_index"] < docs[i + 1].properties["_element_index"] for i in range(len(docs) - 1)
)


def test_table_extraction_with_ocr_batched():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def test_partition(self, mocker) -> None:
expected_json = json.loads(expected_text.read())
partitioner = ArynPDFPartitioner(None)
expected_elements = []
for element_json in expected_json:
element = create_element(**element_json)
for i, element_json in enumerate(expected_json):
element = create_element(i, **element_json)
if element.binary_representation:
element.binary_representation = base64.b64decode(element.binary_representation)
expected_elements.append(element)
Expand All @@ -49,8 +49,8 @@ def test_partition_extract_table_structure(self, mocker) -> None:
expected_json = json.loads(expected_text.read())
partitioner = ArynPDFPartitioner(None)
expected_elements = []
for element_json in expected_json:
element = create_element(**element_json)
for i, element_json in enumerate(expected_json):
element = create_element(i, **element_json)
if element.binary_representation:
element.binary_representation = base64.b64decode(element.binary_representation)
expected_elements.append(element)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_partitioner(self):
"metadata": {"filename": "Bert.pdf", "filetype": "application/pdf", "page_number": 1},
"text": "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding",
}
element = UnstructuredPdfPartitioner.to_element(dict)
element = UnstructuredPdfPartitioner.to_element(dict, element_index=1)
assert element.type == "Title"
assert (
element.text_representation == "BERT: Pre-training of Deep Bidirectional Transformers for"
Expand All @@ -62,6 +62,7 @@ def test_partitioner(self):
"filename": "Bert.pdf",
"filetype": "application/pdf",
"page_number": 1,
"_element_index": 1,
}

@pytest.mark.parametrize(
Expand Down
22 changes: 15 additions & 7 deletions lib/sycamore/sycamore/tests/unit/transforms/test_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,35 @@ def test_transformers_similarity_scorer(self):
{
"doc_id": 1,
"elements": [
{"text_representation": "here is an animal that meows"},
{"text_representation": "here is an animal that meows", "properties": {"_element_index": 1}}
],
},
{
"doc_id": 2,
"elements": [
{"seq_no": 7, "text_representation": "this is a cat"},
{"seq_no": 1, "text_representation": "here is an animal that moos"},
{"properties": {"_element_index": 7}, "text_representation": "this is a cat"},
{"properties": {"_element_index": 1}, "text_representation": "here is an animal that moos"},
],
},
{
"doc_id": 3,
"elements": [
{"text_representation": "here is an animal that moos"},
{"properties": {"_element_index": 1}, "text_representation": "here is an animal that moos"},
],
},
{
"doc_id": 4,
"elements": [
{
"properties": {"_element_index": 1},
"text_representation": "the number of pages in this document are 253",
}
],
},
{"doc_id": 4, "elements": [{"text_representation": "the number of pages in this document are 253"}]},
{ # handle empty element
"doc_id": 5,
"elements": [
{"seq_no": 1},
{"properties": {"_element_index": 1}},
],
},
]
Expand All @@ -49,7 +57,7 @@ def test_transformers_similarity_scorer(self):
result.sort(key=lambda doc: doc.properties.get(score_property_name, float("-inf")), reverse=True)
assert [doc.doc_id for doc in result] == [2, 1, 3, 4, 5]

assert result[0].properties[score_property_name + "_source_element_seq_no"] == 7
assert result[0].properties[score_property_name + "_source_element_index"] == 7

def test_transformers_similarity_scorer_no_doc_structure(self):
similarity_scorer = HuggingFaceTransformersSimilarityScorer(RERANKER_MODEL, ignore_doc_structure=True)
Expand Down
8 changes: 8 additions & 0 deletions lib/sycamore/sycamore/tests/unit/utils/test_bbox_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_elements_basic() -> None:
elems = bbox_sorted_elements(elems)
answer = [e4, e5, e3, e6, e7, e8, e9, e1, e0, e2]
assert elems == answer
assert_element_index_sorted(elems)


def test_document_basic() -> None:
Expand All @@ -114,3 +115,10 @@ def test_document_basic() -> None:
bbox_sort_document(doc)
answer = [e3, e2, e5, e4, e1, e0]
assert doc.elements == answer
assert_element_index_sorted(doc.elements)


def assert_element_index_sorted(elements: list[Element]):
    """Assert that ``element_index`` strictly increases from each element to the next one."""
    # Pair every element with its successor; an empty or single-element list yields no pairs.
    consecutive_pairs = zip(elements, elements[1:])
    assert all(prev.element_index < nxt.element_index for prev, nxt in consecutive_pairs)  # type: ignore
9 changes: 5 additions & 4 deletions lib/sycamore/sycamore/transforms/bbox_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


from sycamore.data import Document, Element
from sycamore.data.document import DocumentPropertyTypes
from sycamore.plan_nodes import Node, SingleThreadUser, NonGPUUser
from sycamore.transforms.map import Map
from sycamore.utils.time_trace import TimeTrace, timetrace
Expand All @@ -26,9 +27,9 @@ def getBboxLeftTop(elem: Element):
def getPageTopLeft(elem: Element):
bbox = elem.data.get("bbox")
if bbox is None:
return (elem.properties["page_number"], 0.0, 0.0)
return (elem.properties[DocumentPropertyTypes.PAGE_NUMBER], 0.0, 0.0)
else:
return (elem.properties["page_number"], bbox[1], bbox[0])
return (elem.properties[DocumentPropertyTypes.PAGE_NUMBER], bbox[1], bbox[0])


def getRow(elem: Element, elements: list[Element]) -> list[Element]:
Expand All @@ -41,7 +42,7 @@ def getRow(elem: Element, elements: list[Element]) -> list[Element]:
top = bbox[1]
right = bbox[2]
bottom = bbox[3]
page = elem.properties["page_number"]
page = elem.properties[DocumentPropertyTypes.PAGE_NUMBER]

# !!! assuming elements are sorted by y-values
n = len(elements)
Expand All @@ -51,7 +52,7 @@ def getRow(elem: Element, elements: list[Element]) -> list[Element]:
while beg < end:
mid = beg + ((end - beg) // 2)
melem = elements[mid]
mpage = melem.properties["page_number"]
mpage = melem.properties[DocumentPropertyTypes.PAGE_NUMBER]
if mpage < page:
beg = mid + 1
idx = mid
Expand Down
Loading

0 comments on commit 0a3ad99

Please sign in to comment.