diff --git a/requirements.txt b/requirements.txt index 212e684..2f8e4ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ spacy>=3.7.5 docling>=2.5.2 +pandas # version range set by Docling +srsly # version range set by spaCy # Dev requirements pytest diff --git a/setup.cfg b/setup.cfg index 4a624cc..d9adc14 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,6 +31,8 @@ python_requires = >=3.10 install_requires = spacy>=3.7.5 docling>=2.5.2 + pandas # version range set by Docling + srsly # version range set by spaCy [bdist_wheel] universal = true diff --git a/spacy_layout/layout.py b/spacy_layout/layout.py index 8b4c578..cd9e396 100644 --- a/spacy_layout/layout.py +++ b/spacy_layout/layout.py @@ -2,24 +2,30 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, Iterable, Iterator +import srsly from docling.datamodel.base_models import DocumentStream from docling.document_converter import DocumentConverter -from docling_core.types.doc.base import CoordOrigin from docling_core.types.doc.labels import DocItemLabel from spacy.tokens import Doc, Span, SpanGroup from .types import Attrs, DocLayout, DoclingItem, PageLayout, SpanLayout +from .util import decode_df, decode_obj, encode_df, encode_obj, get_bounding_box if TYPE_CHECKING: from docling.datamodel.base_models import InputFormat from docling.document_converter import ConversionResult, FormatOption - from docling_core.types.doc.base import BoundingBox from pandas import DataFrame from spacy.language import Language TABLE_PLACEHOLDER = "TABLE" +# Register msgpack encoders and decoders for custom types +srsly.msgpack_encoders.register("spacy-layout.dataclass", func=encode_obj) +srsly.msgpack_decoders.register("spacy-layout.dataclass", func=decode_obj) +srsly.msgpack_encoders.register("spacy-layout.dataframe", func=encode_df) +srsly.msgpack_decoders.register("spacy-layout.dataframe", func=decode_df) + class spaCyLayout: def __init__( @@ -181,13 +187,3 @@ def get_tables(self, doc: Doc) -> list[Span]: for span in doc.spans[self.attrs.span_group] if span.label_ == DocItemLabel.TABLE ] - - -def get_bounding_box( - bbox: "BoundingBox", page_height: float -) -> tuple[float, float, float, float]: - is_bottom = bbox.coord_origin == CoordOrigin.BOTTOMLEFT - y = page_height - bbox.t if is_bottom else bbox.t - height = bbox.t - bbox.b if is_bottom else bbox.b - bbox.t - width = bbox.r - bbox.l - return (bbox.l, y, width, height) diff --git a/spacy_layout/types.py b/spacy_layout/types.py index 6aa5f3e..9bb89f7 100644 --- a/spacy_layout/types.py +++ b/spacy_layout/types.py @@ -29,6 +29,10 @@ class PageLayout: width: float height: float + @classmethod + def from_dict(cls, data: dict) -> "PageLayout": + return cls(**data) + @dataclass class DocLayout: @@ -36,6 +40,11 @@ class DocLayout: pages: list[PageLayout] + @classmethod + def from_dict(cls, data: dict) -> "DocLayout": + pages = [PageLayout.from_dict(page) for page in data.get("pages", [])] + return cls(pages=pages) + @dataclass class SpanLayout: @@ -46,3 +55,7 @@ class SpanLayout: width: float height: float page_no: int + + @classmethod + def from_dict(cls, data: dict) -> "SpanLayout": + return cls(**data) diff --git a/spacy_layout/util.py b/spacy_layout/util.py new file mode 100644 index 0000000..d479202 --- /dev/null +++ b/spacy_layout/util.py @@ -0,0 +1,54 @@ +import dataclasses +from typing import TYPE_CHECKING, Callable + +from docling_core.types.doc.base import CoordOrigin +from pandas import DataFrame + +from .types import DocLayout, PageLayout, SpanLayout + +if TYPE_CHECKING: + from docling_core.types.doc.base import BoundingBox + +TYPE_ATTR = "__type__" +OBJ_TYPES = {"SpanLayout": SpanLayout, "DocLayout": DocLayout, "PageLayout": PageLayout} + + +def encode_obj(obj, chain: Callable | None = None): + """Convert custom dataclass to dict for serialization.""" + if isinstance(obj, tuple(OBJ_TYPES.values())): + result = dataclasses.asdict(obj) + result[TYPE_ATTR] = type(obj).__name__ + return result + return obj if chain is None else chain(obj) + + +def decode_obj(obj, chain: Callable | None = None): + """Load custom dataclass from serialized dict.""" + if isinstance(obj, dict) and obj.get(TYPE_ATTR) in OBJ_TYPES: + obj_type = obj.pop(TYPE_ATTR) + return OBJ_TYPES[obj_type].from_dict(obj) + return obj if chain is None else chain(obj) + + +def encode_df(obj, chain: Callable | None = None): + """Convert pandas.DataFrame for serialization.""" + if isinstance(obj, DataFrame): + return {"data": obj.to_dict(), TYPE_ATTR: "DataFrame"} + return obj if chain is None else chain(obj) + + +def decode_df(obj, chain: Callable | None = None): + """Load pandas.DataFrame from serialized data.""" + if isinstance(obj, dict) and obj.get(TYPE_ATTR) == "DataFrame": + return DataFrame(obj["data"]) + return obj if chain is None else chain(obj) + + +def get_bounding_box( + bbox: "BoundingBox", page_height: float +) -> tuple[float, float, float, float]: + is_bottom = bbox.coord_origin == CoordOrigin.BOTTOMLEFT + y = page_height - bbox.t if is_bottom else bbox.t + height = bbox.t - bbox.b if is_bottom else bbox.b - bbox.t + width = bbox.r - bbox.l + return (bbox.l, y, width, height) diff --git a/tests/test_general.py b/tests/test_general.py index acdfcbf..64140dc 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -2,12 +2,16 @@ import pytest import spacy +import srsly from docling_core.types.doc.base import BoundingBox, CoordOrigin from docling_core.types.doc.labels import DocItemLabel +from pandas import DataFrame +from pandas.testing import assert_frame_equal +from spacy.tokens import DocBin from spacy_layout import spaCyLayout from spacy_layout.layout import TABLE_PLACEHOLDER, get_bounding_box -from spacy_layout.types import DocLayout, SpanLayout +from spacy_layout.types import DocLayout, PageLayout, SpanLayout PDF_STARCRAFT = Path(__file__).parent / "data" / "starcraft.pdf" PDF_SIMPLE = Path(__file__).parent / "data" / "simple.pdf" @@ -118,3 +122,42 @@ def test_bounding_box(box, page_height, expected): top, bottom, left, right, origin = box bbox = BoundingBox(t=top, b=bottom, l=left, r=right, coord_origin=origin) assert get_bounding_box(bbox, page_height) == expected + + +def test_serialize_objects(): + span_layout = SpanLayout(x=10, y=20, width=30, height=40, page_no=1) + doc_layout = DocLayout(pages=[PageLayout(page_no=1, width=500, height=600)]) + bytes_data = srsly.msgpack_dumps({"span": span_layout, "doc": doc_layout}) + data = srsly.msgpack_loads(bytes_data) + assert isinstance(data, dict) + assert data["span"] == span_layout + assert data["doc"] == doc_layout + df = DataFrame(data={"col1": [1, 2], "col2": [3, 4]}) + bytes_data = srsly.msgpack_dumps({"df": df}) + data = srsly.msgpack_loads(bytes_data) + assert isinstance(data, dict) + assert_frame_equal(df, data["df"]) + + +@pytest.mark.parametrize("path", [PDF_SIMPLE, PDF_TABLE]) +def test_serialize_roundtrip(path, nlp): + layout = spaCyLayout(nlp) + doc = layout(path) + doc_bin = DocBin(store_user_data=True) + doc_bin.add(doc) + bytes_data = doc_bin.to_bytes() + new_doc_bin = DocBin().from_bytes(bytes_data) + new_doc = list(new_doc_bin.get_docs(nlp.vocab))[0] + layout_spans = new_doc.spans[layout.attrs.span_group] + assert len(layout_spans) == len(doc.spans[layout.attrs.span_group]) + assert all( + isinstance(span._.get(layout.attrs.span_layout), SpanLayout) + for span in layout_spans + ) + assert isinstance(new_doc._.get(layout.attrs.doc_layout), DocLayout) + tables = doc._.get(layout.attrs.doc_tables) + new_tables = new_doc._.get(layout.attrs.doc_tables) + for before, after in zip(tables, new_tables): + table_before = before._.get(layout.attrs.span_data) + table_after = after._.get(layout.attrs.span_data) + assert_frame_equal(table_before, table_after)