From 412f50ec6ae851af266c799429699ea3027d65c2 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sat, 7 Dec 2024 12:22:40 +0100 Subject: [PATCH 1/3] Fix serialization of extension attributes --- spacy_layout/layout.py | 20 +++++++--------- spacy_layout/types.py | 13 ++++++++++ spacy_layout/util.py | 54 ++++++++++++++++++++++++++++++++++++++++++ tests/test_general.py | 45 ++++++++++++++++++++++++++++++++++- 4 files changed, 119 insertions(+), 13 deletions(-) create mode 100644 spacy_layout/util.py diff --git a/spacy_layout/layout.py b/spacy_layout/layout.py index 8b4c578..cd9e396 100644 --- a/spacy_layout/layout.py +++ b/spacy_layout/layout.py @@ -2,24 +2,30 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, Iterable, Iterator +import srsly from docling.datamodel.base_models import DocumentStream from docling.document_converter import DocumentConverter -from docling_core.types.doc.base import CoordOrigin from docling_core.types.doc.labels import DocItemLabel from spacy.tokens import Doc, Span, SpanGroup from .types import Attrs, DocLayout, DoclingItem, PageLayout, SpanLayout +from .util import decode_df, decode_obj, encode_df, encode_obj, get_bounding_box if TYPE_CHECKING: from docling.datamodel.base_models import InputFormat from docling.document_converter import ConversionResult, FormatOption - from docling_core.types.doc.base import BoundingBox from pandas import DataFrame from spacy.language import Language TABLE_PLACEHOLDER = "TABLE" +# Register msgpack encoders and decoders for custom types +srsly.msgpack_encoders.register("spacy-layout.dataclass", func=encode_obj) +srsly.msgpack_decoders.register("spacy-layout.dataclass", func=decode_obj) +srsly.msgpack_encoders.register("spacy-layout.dataframe", func=encode_df) +srsly.msgpack_decoders.register("spacy-layout.dataframe", func=decode_df) + class spaCyLayout: def __init__( @@ -181,13 +187,3 @@ def get_tables(self, doc: Doc) -> list[Span]: for span in doc.spans[self.attrs.span_group] if span.label_ == DocItemLabel.TABLE ] - - -def get_bounding_box( - bbox: "BoundingBox", page_height: float -) -> tuple[float, float, float, float]: - is_bottom = bbox.coord_origin == CoordOrigin.BOTTOMLEFT - y = page_height - bbox.t if is_bottom else bbox.t - height = bbox.t - bbox.b if is_bottom else bbox.b - bbox.t - width = bbox.r - bbox.l - return (bbox.l, y, width, height) diff --git a/spacy_layout/types.py b/spacy_layout/types.py index 6aa5f3e..9bb89f7 100644 --- a/spacy_layout/types.py +++ b/spacy_layout/types.py @@ -29,6 +29,10 @@ class PageLayout: width: float height: float + @classmethod + def from_dict(cls, data: dict) -> "PageLayout": + return cls(**data) + @dataclass class DocLayout: @@ -36,6 +40,11 @@ class DocLayout: pages: list[PageLayout] + @classmethod + def from_dict(cls, data: dict) -> "DocLayout": + pages = [PageLayout.from_dict(page) for page in data.get("pages", [])] + return cls(pages=pages) + @dataclass class SpanLayout: @@ -46,3 +55,7 @@ class SpanLayout: width: float height: float page_no: int + + @classmethod + def from_dict(cls, data: dict) -> "SpanLayout": + return cls(**data) diff --git a/spacy_layout/util.py b/spacy_layout/util.py new file mode 100644 index 0000000..d479202 --- /dev/null +++ b/spacy_layout/util.py @@ -0,0 +1,54 @@ +import dataclasses +from typing import TYPE_CHECKING, Callable + +from docling_core.types.doc.base import CoordOrigin +from pandas import DataFrame + +from .types import DocLayout, PageLayout, SpanLayout + +if TYPE_CHECKING: + from docling_core.types.doc.base import BoundingBox + +TYPE_ATTR = "__type__" +OBJ_TYPES = {"SpanLayout": SpanLayout, "DocLayout": DocLayout, "PageLayout": PageLayout} + + +def encode_obj(obj, chain: Callable | None = None): + """Convert custom dataclass to dict for serialization.""" + if isinstance(obj, tuple(OBJ_TYPES.values())): + result = dataclasses.asdict(obj) + result[TYPE_ATTR] = type(obj).__name__ + return result + return obj if chain is None else chain(obj) + + +def decode_obj(obj, chain: Callable | None = None): + """Load custom dataclass from serialized dict.""" + if isinstance(obj, dict) and obj.get(TYPE_ATTR) in OBJ_TYPES: + obj_type = obj.pop(TYPE_ATTR) + return OBJ_TYPES[obj_type].from_dict(obj) + return obj if chain is None else chain(obj) + + +def encode_df(obj, chain: Callable | None = None): + """Convert pandas.DataFrame for serialization.""" + if isinstance(obj, DataFrame): + return {"data": obj.to_dict(), TYPE_ATTR: "DataFrame"} + return obj if chain is None else chain(obj) + + +def decode_df(obj, chain: Callable | None = None): + """Load pandas.DataFrame from serialized data.""" + if isinstance(obj, dict) and obj.get(TYPE_ATTR) == "DataFrame": + return DataFrame(obj["data"]) + return obj if chain is None else chain(obj) + + +def get_bounding_box( + bbox: "BoundingBox", page_height: float +) -> tuple[float, float, float, float]: + is_bottom = bbox.coord_origin == CoordOrigin.BOTTOMLEFT + y = page_height - bbox.t if is_bottom else bbox.t + height = bbox.t - bbox.b if is_bottom else bbox.b - bbox.t + width = bbox.r - bbox.l + return (bbox.l, y, width, height) diff --git a/tests/test_general.py b/tests/test_general.py index acdfcbf..64140dc 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -2,12 +2,16 @@ import pytest import spacy +import srsly from docling_core.types.doc.base import BoundingBox, CoordOrigin from docling_core.types.doc.labels import DocItemLabel +from pandas import DataFrame +from pandas.testing import assert_frame_equal +from spacy.tokens import DocBin from spacy_layout import spaCyLayout from spacy_layout.layout import TABLE_PLACEHOLDER, get_bounding_box -from spacy_layout.types import DocLayout, SpanLayout +from spacy_layout.types import DocLayout, PageLayout, SpanLayout PDF_STARCRAFT = Path(__file__).parent / "data" / "starcraft.pdf" PDF_SIMPLE = Path(__file__).parent / "data" / "simple.pdf" @@ -118,3 +122,42 @@ def test_bounding_box(box, page_height, expected): top, bottom, left, right, origin = box bbox = BoundingBox(t=top, b=bottom, l=left, r=right, coord_origin=origin) assert get_bounding_box(bbox, page_height) == expected + + +def test_serialize_objects(): + span_layout = SpanLayout(x=10, y=20, width=30, height=40, page_no=1) + doc_layout = DocLayout(pages=[PageLayout(page_no=1, width=500, height=600)]) + bytes_data = srsly.msgpack_dumps({"span": span_layout, "doc": doc_layout}) + data = srsly.msgpack_loads(bytes_data) + assert isinstance(data, dict) + assert data["span"] == span_layout + assert data["doc"] == doc_layout + df = DataFrame(data={"col1": [1, 2], "col2": [3, 4]}) + bytes_data = srsly.msgpack_dumps({"df": df}) + data = srsly.msgpack_loads(bytes_data) + assert isinstance(data, dict) + assert_frame_equal(df, data["df"]) + + +@pytest.mark.parametrize("path", [PDF_SIMPLE, PDF_TABLE]) +def test_serialize_roundtrip(path, nlp): + layout = spaCyLayout(nlp) + doc = layout(path) + doc_bin = DocBin(store_user_data=True) + doc_bin.add(doc) + bytes_data = doc_bin.to_bytes() + new_doc_bin = DocBin().from_bytes(bytes_data) + new_doc = list(new_doc_bin.get_docs(nlp.vocab))[0] + layout_spans = new_doc.spans[layout.attrs.span_group] + assert len(layout_spans) == len(doc.spans[layout.attrs.span_group]) + assert all( + isinstance(span._.get(layout.attrs.span_layout), SpanLayout) + for span in layout_spans + ) + assert isinstance(new_doc._.get(layout.attrs.doc_layout), DocLayout) + tables = doc._.get(layout.attrs.doc_tables) + new_tables = new_doc._.get(layout.attrs.doc_tables) + for before, after in zip(tables, new_tables): + table_before = before._.get(layout.attrs.span_data) + table_after = after._.get(layout.attrs.span_data) + assert_frame_equal(table_before, table_after) From e84650cc5cf7233464593ced3c1958f37998b21e Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sat, 7 Dec 2024 12:26:54 +0100 Subject: [PATCH 2/3] Update requirements Define all packages that are imported directly, even if we know they'll be installed --- requirements.txt | 2 ++ setup.cfg | 2 ++ 2 files changed, 4 insertions(+) diff --git a/requirements.txt b/requirements.txt index 212e684..a907b55 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ spacy>=3.7.5 docling>=2.5.2 +pandas +srsly # Dev requirements pytest diff --git a/setup.cfg b/setup.cfg index 4a624cc..cb75f46 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,6 +31,8 @@ python_requires = >=3.10 install_requires = spacy>=3.7.5 docling>=2.5.2 + pandas + srsly [bdist_wheel] universal = true From e7631f99b4e0ade609f71e17b33793f9dba69150 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sat, 7 Dec 2024 12:35:45 +0100 Subject: [PATCH 3/3] Add comments to requirements --- requirements.txt | 4 ++-- setup.cfg | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index a907b55..2f8e4ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ spacy>=3.7.5 docling>=2.5.2 -pandas -srsly +pandas # version range set by Docling +srsly # version range set by spaCy # Dev requirements pytest diff --git a/setup.cfg b/setup.cfg index cb75f46..d9adc14 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,8 +31,8 @@ python_requires = >=3.10 install_requires = spacy>=3.7.5 docling>=2.5.2 - pandas - srsly + pandas # version range set by Docling + srsly # version range set by spaCy [bdist_wheel] universal = true