Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix serialization of extension attributes #14

Merged
merged 3 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
spacy>=3.7.5
docling>=2.5.2
pandas # version range set by Docling
srsly # version range set by spaCy
# Dev requirements
pytest
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ python_requires = >=3.10
install_requires =
spacy>=3.7.5
docling>=2.5.2
pandas # version range set by Docling
srsly # version range set by spaCy

[bdist_wheel]
universal = true
Expand Down
20 changes: 8 additions & 12 deletions spacy_layout/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,30 @@
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Iterable, Iterator

import srsly
from docling.datamodel.base_models import DocumentStream
from docling.document_converter import DocumentConverter
from docling_core.types.doc.base import CoordOrigin
from docling_core.types.doc.labels import DocItemLabel
from spacy.tokens import Doc, Span, SpanGroup

from .types import Attrs, DocLayout, DoclingItem, PageLayout, SpanLayout
from .util import decode_df, decode_obj, encode_df, encode_obj, get_bounding_box

if TYPE_CHECKING:
from docling.datamodel.base_models import InputFormat
from docling.document_converter import ConversionResult, FormatOption
from docling_core.types.doc.base import BoundingBox
from pandas import DataFrame
from spacy.language import Language


TABLE_PLACEHOLDER = "TABLE"

# Register msgpack encoders and decoders for custom types.
# Each hook is registered with srsly under a "spacy-layout.*" name: one
# encoder/decoder pair for the layout dataclasses (encode_obj / decode_obj)
# and one pair for pandas DataFrames (encode_df / decode_df).
srsly.msgpack_encoders.register("spacy-layout.dataclass", func=encode_obj)
srsly.msgpack_decoders.register("spacy-layout.dataclass", func=decode_obj)
srsly.msgpack_encoders.register("spacy-layout.dataframe", func=encode_df)
srsly.msgpack_decoders.register("spacy-layout.dataframe", func=decode_df)


class spaCyLayout:
def __init__(
Expand Down Expand Up @@ -181,13 +187,3 @@ def get_tables(self, doc: Doc) -> list[Span]:
for span in doc.spans[self.attrs.span_group]
if span.label_ == DocItemLabel.TABLE
]


def get_bounding_box(
    bbox: "BoundingBox", page_height: float
) -> tuple[float, float, float, float]:
    """Return the box as an ``(x, y, width, height)`` tuple.

    When the box's ``coord_origin`` is ``CoordOrigin.BOTTOMLEFT``, ``y`` is
    ``page_height - bbox.t`` and the height is ``bbox.t - bbox.b``. For any
    other origin, ``y`` is ``bbox.t`` and the height is ``bbox.b - bbox.t``.
    ``x`` is always ``bbox.l`` and the width is ``bbox.r - bbox.l``.
    """
    if bbox.coord_origin == CoordOrigin.BOTTOMLEFT:
        top = page_height - bbox.t
        box_height = bbox.t - bbox.b
    else:
        top = bbox.t
        box_height = bbox.b - bbox.t
    box_width = bbox.r - bbox.l
    return (bbox.l, top, box_width, box_height)
13 changes: 13 additions & 0 deletions spacy_layout/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,22 @@ class PageLayout:
width: float
height: float

@classmethod
def from_dict(cls, data: dict) -> "PageLayout":
    """Create an instance by passing the entries of ``data`` as keyword arguments."""
    instance = cls(**data)
    return instance


@dataclass
class DocLayout:
    """Document layout features added to Doc object"""

    pages: list[PageLayout]

    @classmethod
    def from_dict(cls, data: dict) -> "DocLayout":
        """Create a DocLayout from a plain dict.

        Each entry under the "pages" key is converted with
        ``PageLayout.from_dict``; a missing "pages" key yields no pages.
        """
        page_dicts = data.get("pages", [])
        page_layouts = list(map(PageLayout.from_dict, page_dicts))
        return cls(pages=page_layouts)


@dataclass
class SpanLayout:
Expand All @@ -46,3 +55,7 @@ class SpanLayout:
width: float
height: float
page_no: int

@classmethod
def from_dict(cls, data: dict) -> "SpanLayout":
    """Create an instance by passing the entries of ``data`` as keyword arguments."""
    instance = cls(**data)
    return instance
54 changes: 54 additions & 0 deletions spacy_layout/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import dataclasses
from typing import TYPE_CHECKING, Callable

from docling_core.types.doc.base import CoordOrigin
from pandas import DataFrame

from .types import DocLayout, PageLayout, SpanLayout

if TYPE_CHECKING:
from docling_core.types.doc.base import BoundingBox

# Key written into each serialized dict to record which type it was encoded from.
TYPE_ATTR = "__type__"
# Dataclasses handled by encode_obj / decode_obj, keyed by the class name that
# encode_obj stores under TYPE_ATTR.
OBJ_TYPES = {"SpanLayout": SpanLayout, "DocLayout": DocLayout, "PageLayout": PageLayout}


def encode_obj(obj, chain: Callable | None = None):
    """Convert a custom layout dataclass to a dict for serialization.

    Instances of the classes in ``OBJ_TYPES`` are turned into a dict via
    ``dataclasses.asdict`` with the class name stored under ``TYPE_ATTR``.
    Any other object is handed to ``chain`` if one is given, otherwise it is
    returned unchanged.
    """
    known_types = tuple(OBJ_TYPES.values())
    if not isinstance(obj, known_types):
        if chain is None:
            return obj
        return chain(obj)
    payload = dataclasses.asdict(obj)
    payload[TYPE_ATTR] = type(obj).__name__
    return payload


def decode_obj(obj, chain: Callable | None = None):
"""Load custom dataclass from serialized dict."""
if isinstance(obj, dict) and obj.get(TYPE_ATTR) in OBJ_TYPES:
obj_type = obj.pop(TYPE_ATTR)
return OBJ_TYPES[obj_type].from_dict(obj)
return obj if chain is None else chain(obj)


def encode_df(obj, chain: Callable | None = None):
"""Convert pandas.DataFrame for serialization."""
if isinstance(obj, DataFrame):
return {"data": obj.to_dict(), TYPE_ATTR: "DataFrame"}
return obj if chain is None else chain(obj)


def decode_df(obj, chain: Callable | None = None):
"""Load pandas.DataFrame from serialized data."""
if isinstance(obj, dict) and obj.get(TYPE_ATTR) == "DataFrame":
return DataFrame(obj["data"])
return obj if chain is None else chain(obj)


def get_bounding_box(
    bbox: "BoundingBox", page_height: float
) -> tuple[float, float, float, float]:
    """Return the box as an ``(x, y, width, height)`` tuple.

    When the box's ``coord_origin`` is ``CoordOrigin.BOTTOMLEFT``, ``y`` is
    ``page_height - bbox.t`` and the height is ``bbox.t - bbox.b``. For any
    other origin, ``y`` is ``bbox.t`` and the height is ``bbox.b - bbox.t``.
    ``x`` is always ``bbox.l`` and the width is ``bbox.r - bbox.l``.
    """
    if bbox.coord_origin == CoordOrigin.BOTTOMLEFT:
        top = page_height - bbox.t
        box_height = bbox.t - bbox.b
    else:
        top = bbox.t
        box_height = bbox.b - bbox.t
    box_width = bbox.r - bbox.l
    return (bbox.l, top, box_width, box_height)
45 changes: 44 additions & 1 deletion tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import pytest
import spacy
import srsly
from docling_core.types.doc.base import BoundingBox, CoordOrigin
from docling_core.types.doc.labels import DocItemLabel
from pandas import DataFrame
from pandas.testing import assert_frame_equal
from spacy.tokens import DocBin

from spacy_layout import spaCyLayout
from spacy_layout.layout import TABLE_PLACEHOLDER, get_bounding_box
from spacy_layout.types import DocLayout, SpanLayout
from spacy_layout.types import DocLayout, PageLayout, SpanLayout

PDF_STARCRAFT = Path(__file__).parent / "data" / "starcraft.pdf"
PDF_SIMPLE = Path(__file__).parent / "data" / "simple.pdf"
Expand Down Expand Up @@ -118,3 +122,42 @@ def test_bounding_box(box, page_height, expected):
top, bottom, left, right, origin = box
bbox = BoundingBox(t=top, b=bottom, l=left, r=right, coord_origin=origin)
assert get_bounding_box(bbox, page_height) == expected


def test_serialize_objects():
    """Round-trip the layout dataclasses and a DataFrame through srsly msgpack."""
    # Layout dataclasses should compare equal to the originals after loading.
    span_layout = SpanLayout(x=10, y=20, width=30, height=40, page_no=1)
    page_layout = PageLayout(page_no=1, width=500, height=600)
    doc_layout = DocLayout(pages=[page_layout])
    packed = srsly.msgpack_dumps({"span": span_layout, "doc": doc_layout})
    restored = srsly.msgpack_loads(packed)
    assert isinstance(restored, dict)
    assert restored["span"] == span_layout
    assert restored["doc"] == doc_layout
    # A DataFrame should come back as an equal frame.
    frame = DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
    packed = srsly.msgpack_dumps({"df": frame})
    restored = srsly.msgpack_loads(packed)
    assert isinstance(restored, dict)
    assert_frame_equal(frame, restored["df"])


@pytest.mark.parametrize("path", [PDF_SIMPLE, PDF_TABLE])
def test_serialize_roundtrip(path, nlp):
    """Check that layout data survives a DocBin to_bytes/from_bytes round trip.

    The document is stored in a DocBin with user data, reloaded from bytes,
    and the layout spans, span/doc layout attributes and table data are
    compared against the original document.
    """
    layout = spaCyLayout(nlp)
    doc = layout(path)
    doc_bin = DocBin(store_user_data=True)
    doc_bin.add(doc)
    bytes_data = doc_bin.to_bytes()
    new_doc_bin = DocBin().from_bytes(bytes_data)
    new_doc = list(new_doc_bin.get_docs(nlp.vocab))[0]
    layout_spans = new_doc.spans[layout.attrs.span_group]
    assert len(layout_spans) == len(doc.spans[layout.attrs.span_group])
    assert all(
        isinstance(span._.get(layout.attrs.span_layout), SpanLayout)
        for span in layout_spans
    )
    assert isinstance(new_doc._.get(layout.attrs.doc_layout), DocLayout)
    tables = doc._.get(layout.attrs.doc_tables)
    new_tables = new_doc._.get(layout.attrs.doc_tables)
    # strict=True makes the test fail if the reloaded doc has a different
    # number of tables; a plain zip() would stop at the shorter sequence and
    # let missing tables go unnoticed.
    for before, after in zip(tables, new_tables, strict=True):
        table_before = before._.get(layout.attrs.span_data)
        table_after = after._.get(layout.attrs.span_data)
        assert_frame_equal(table_before, table_after)
Loading