Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial working jaxtyping serializing/deserializing #576

Merged
merged 5 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions examples/type_numpy_jaxtyping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import numpy
from jaxtyping import (
Float,
Float16,
Float32,
Float64,
Inexact,
Int,
Int8,
Int16,
Int32,
Int64,
Integer,
UInt,
UInt8,
UInt16,
UInt32,
UInt64,
)
from serde import serde
from serde.json import from_json, to_json


@serde
class Foo:
float_: Float[numpy.ndarray, "3 3"]
float16: Float16[numpy.ndarray, "3 3"]
float32: Float32[numpy.ndarray, "3 3"]
float64: Float64[numpy.ndarray, "3 3"]
inexact: Inexact[numpy.ndarray, "3 3"]
int_: Int[numpy.ndarray, "3 3"]
int8: Int8[numpy.ndarray, "3 3"]
int16: Int16[numpy.ndarray, "3 3"]
int32: Int32[numpy.ndarray, "3 3"]
int64: Int64[numpy.ndarray, "3 3"]
integer: Integer[numpy.ndarray, "3 3"]
uint: UInt[numpy.ndarray, "3 3"]
uint8: UInt8[numpy.ndarray, "3 3"]
uint16: UInt16[numpy.ndarray, "3 3"]
uint32: UInt32[numpy.ndarray, "3 3"]
uint64: UInt64[numpy.ndarray, "3 3"]


def main() -> None:
foo = Foo(
float_=numpy.zeros((3, 3), dtype=float),
float16=numpy.zeros((3, 3), dtype=numpy.float16),
float32=numpy.zeros((3, 3), dtype=numpy.float32),
float64=numpy.zeros((3, 3), dtype=numpy.float64),
inexact=numpy.zeros((3, 3), dtype=numpy.inexact),
int_=numpy.zeros((3, 3), dtype=int),
int8=numpy.zeros((3, 3), dtype=numpy.int8),
int16=numpy.zeros((3, 3), dtype=numpy.int16),
int32=numpy.zeros((3, 3), dtype=numpy.int32),
int64=numpy.zeros((3, 3), dtype=numpy.int64),
integer=numpy.zeros((3, 3), dtype=numpy.integer),
uint=numpy.zeros((3, 3), dtype=numpy.uint),
uint8=numpy.zeros((3, 3), dtype=numpy.uint8),
uint16=numpy.zeros((3, 3), dtype=numpy.uint16),
uint32=numpy.zeros((3, 3), dtype=numpy.uint32),
uint64=numpy.zeros((3, 3), dtype=numpy.uint64),
)

print(f"Into Json: {to_json(foo)}")

s = """
{
"float_": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
"float16": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
"float32": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
"float64": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
"inexact": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
"int_": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"int8": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"int16": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"int32": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"int64": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"integer": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"uint": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"uint8": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"uint16": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"uint32": [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
"uint64": [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
}
"""
print(f"From Json: {from_json(Foo, s)}")


if __name__ == "__main__":
main()
60 changes: 36 additions & 24 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ packages = [
{ include = "serde" },
]
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]

[tool.poetry.dependencies]
python = "^3.9.0"
Expand All @@ -33,11 +33,12 @@ tomli = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional
tomli-w = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional = true }
pyyaml = { version = "*", markers = "extra == 'yaml' or extra == 'all'", optional = true }
numpy = [
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true },
pablovela5620 marked this conversation as resolved.
Show resolved Hide resolved
]
jaxtyping = { version = "*", markers = "extra == 'jaxtyping' or extra == 'all'", optional = true }
orjson = { version = "*", markers = "extra == 'orjson' or extra == 'all'", optional = true }
plum-dispatch = ">=2,<2.3"
beartype = ">=0.18.4"
Expand All @@ -49,10 +50,10 @@ tomli = { version = "*", markers = "python_version <= '3.11.0'" }
tomli-w = "*"
msgpack = "*"
numpy = [
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12'" },
{ version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11'" },
{ version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12'" },
]
mypy = "==1.10.1"
pytest = "*"
Expand All @@ -68,6 +69,7 @@ types-PyYAML = "^6.0.9"
msgpack-types = "^0.3"
envclasses = "^0.3.1"
jedi = "*"
jaxtyping = "*"

[tool.poetry.extras]
msgpack = ["msgpack"]
Expand All @@ -76,7 +78,8 @@ toml = ["tomli", "tomli-w"]
yaml = ["pyyaml"]
orjson = ["orjson"]
sqlalchemy = ["sqlalchemy"]
all = ["msgpack", "tomli", "tomli-w", "pyyaml", "numpy", "orjson", "sqlalchemy"]
jaxtyping = ["jaxtyping"]
all = ["msgpack", "tomli", "tomli-w", "pyyaml", "numpy", "orjson", "sqlalchemy", "jaxtyping"]

[build-system]
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning"]
Expand Down Expand Up @@ -145,16 +148,25 @@ exclude = [
"tests/test_sqlalchemy.py",
]

[[tool.mypy.overrides]]
# to avoid complaints about generic type ndarray
module = "examples.type_numpy_jaxtyping"
ignore_errors = true

[tool.ruff]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"C", # flake8-comprehensions
"B", # flake8-bugbear
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"C", # flake8-comprehensions
"B", # flake8-bugbear
]
ignore = ["B904"]
line-length = 100

[tool.ruff.lint.mccabe]
max-complexity = 30

[tool.ruff.per-file-ignores]
# https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error
"examples/type_numpy_jaxtyping.py" = ["F722"]
5 changes: 5 additions & 0 deletions serde/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@
deserialize_numpy_array,
deserialize_numpy_scalar,
deserialize_numpy_array_direct,
deserialize_numpy_jaxtyping_array,
is_numpy_array,
is_numpy_jaxtyping,
is_numpy_scalar,
)

Expand Down Expand Up @@ -749,6 +751,9 @@ def render(self, arg: DeField[Any]) -> str:
elif is_numpy_array(arg.type):
self.import_numpy = True
res = deserialize_numpy_array(arg)
elif is_numpy_jaxtyping(arg.type):
self.import_numpy = True
res = deserialize_numpy_jaxtyping_array(arg)
elif is_union(arg.type):
res = self.union_func(arg)
elif is_str_serializable(arg.type):
Expand Down
19 changes: 19 additions & 0 deletions serde/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@
typ = origin
return typ is np.ndarray

def is_numpy_jaxtyping(typ) -> bool:
try:
origin = get_origin(typ)
if origin is not None:
typ = origin
return typ is not np.ndarray and issubclass(typ, np.ndarray)
except TypeError:
return False

def serialize_numpy_array(arg) -> str:
return f"{arg.varname}.tolist()"

Expand All @@ -86,6 +95,10 @@
dtype = fullname(arg[1][0].type)
return f"numpy.array({arg.data}, dtype={dtype})"

def deserialize_numpy_jaxtyping_array(arg) -> str:
dtype = f"numpy.{arg.type.dtypes[-1]}"
return f"numpy.array({arg.data}, dtype={dtype})"

def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
if is_bare_numpy_array(typ):
return np.array(arg)
Expand All @@ -111,6 +124,9 @@
def is_numpy_array(typ) -> bool:
return False

def is_numpy_jaxtyping(typ) -> bool:
return False

Check warning on line 128 in serde/numpy.py

View check run for this annotation

Codecov / codecov/patch

serde/numpy.py#L127-L128

Added lines #L127 - L128 were not covered by tests

def serialize_numpy_array(arg) -> str:
return ""

Expand All @@ -120,5 +136,8 @@
def deserialize_numpy_array(arg) -> str:
return ""

def deserialize_numpy_jaxtyping_array(arg) -> str:
return ""

Check warning on line 140 in serde/numpy.py

View check run for this annotation

Codecov / codecov/patch

serde/numpy.py#L139-L140

Added lines #L139 - L140 were not covered by tests

def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any:
return arg
3 changes: 3 additions & 0 deletions serde/se.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
)
from .numpy import (
is_numpy_array,
is_numpy_jaxtyping,
is_numpy_datetime,
is_numpy_scalar,
serialize_numpy_array,
Expand Down Expand Up @@ -751,6 +752,8 @@ def render(self, arg: SeField[Any]) -> str:
res = serialize_numpy_scalar(arg)
elif is_numpy_array(arg.type):
res = serialize_numpy_array(arg)
elif is_numpy_jaxtyping(arg.type):
res = serialize_numpy_array(arg)
elif is_primitive(arg.type):
res = self.primitive(arg)
elif is_union(arg.type):
Expand Down
61 changes: 61 additions & 0 deletions tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import numpy.typing as npt
import jaxtyping
import pytest

import serde
Expand Down Expand Up @@ -89,6 +90,66 @@ class NumpyDate:

assert de(NumpyDate, se(date_test)) == date_test

@serde.serde(**opt)
class NumpyJaxtyping:
float_: jaxtyping.Float[np.ndarray, "2 2"] # noqa: F722
float16: jaxtyping.Float16[np.ndarray, "2 2"] # noqa: F722
float32: jaxtyping.Float32[np.ndarray, "2 2"] # noqa: F722
float64: jaxtyping.Float64[np.ndarray, "2 2"] # noqa: F722
inexact: jaxtyping.Inexact[np.ndarray, "2 2"] # noqa: F722
int_: jaxtyping.Int[np.ndarray, "2 2"] # noqa: F722
int8: jaxtyping.Int8[np.ndarray, "2 2"] # noqa: F722
int16: jaxtyping.Int16[np.ndarray, "2 2"] # noqa: F722
int32: jaxtyping.Int32[np.ndarray, "2 2"] # noqa: F722
int64: jaxtyping.Int64[np.ndarray, "2 2"] # noqa: F722
integer: jaxtyping.Integer[np.ndarray, "2 2"] # noqa: F722
uint: jaxtyping.UInt[np.ndarray, "2 2"] # noqa: F722
uint8: jaxtyping.UInt8[np.ndarray, "2 2"] # noqa: F722
uint16: jaxtyping.UInt16[np.ndarray, "2 2"] # noqa: F722
uint32: jaxtyping.UInt32[np.ndarray, "2 2"] # noqa: F722
uint64: jaxtyping.UInt64[np.ndarray, "2 2"] # noqa: F722

def __eq__(self, other):
return (
(self.float_ == other.float_).all()
and (self.float16 == other.float16).all()
and (self.float32 == other.float32).all()
and (self.float64 == other.float64).all()
and (self.inexact == other.inexact).all()
and (self.int_ == other.int_).all()
and (self.int8 == other.int8).all()
and (self.int16 == other.int16).all()
and (self.int32 == other.int32).all()
and (self.int64 == other.int64).all()
and (self.integer == other.integer).all()
and (self.uint == other.uint).all()
and (self.uint8 == other.uint8).all()
and (self.uint16 == other.uint16).all()
and (self.uint32 == other.uint32).all()
and (self.uint64 == other.uint64).all()
)

jaxtyping_test = NumpyJaxtyping(
float_=np.array([[1, 2], [3, 4]], dtype=np.float_),
float16=np.array([[5, 6], [7, 8]], dtype=np.float16),
float32=np.array([[9, 10], [11, 12]], dtype=np.float32),
float64=np.array([[13, 14], [15, 16]], dtype=np.float64),
inexact=np.array([[17, 18], [19, 20]], dtype=np.float_),
int_=np.array([[21, 22], [23, 24]], dtype=np.int_),
int8=np.array([[25, 26], [27, 28]], dtype=np.int8),
int16=np.array([[29, 30], [31, 32]], dtype=np.int16),
int32=np.array([[33, 34], [35, 36]], dtype=np.int32),
int64=np.array([[37, 38], [39, 40]], dtype=np.int64),
integer=np.array([[41, 42], [43, 44]], dtype=np.int_),
uint=np.array([[45, 46], [47, 48]], dtype=np.uint),
uint8=np.array([[49, 50], [51, 52]], dtype=np.uint8),
uint16=np.array([[53, 54], [55, 56]], dtype=np.uint16),
uint32=np.array([[57, 58], [59, 60]], dtype=np.uint32),
uint64=np.array([[61, 62], [63, 64]], dtype=np.uint64),
)

assert de(NumpyJaxtyping, se(jaxtyping_test)) == jaxtyping_test


@pytest.mark.parametrize("opt", opt_case, ids=opt_case_ids())
@pytest.mark.parametrize("se,de", format_json + format_msgpack)
Expand Down