Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
mivanit committed Jun 11, 2024
1 parent 48e936b commit 2615fa3
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 70 deletions.
8 changes: 7 additions & 1 deletion muutils/json_serialize/serializable_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,11 @@ def dc_eq(

T = TypeVar("T")


class ZanjMissingWarning(UserWarning):
pass


_zanj_loading_needs_import: bool = True


Expand All @@ -275,7 +280,8 @@ def zanj_register_loader_serializable_dataclass(cls: Type[T]):
)
except ImportError:
warnings.warn(
"ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ"
"ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
ZanjMissingWarning,
)
return

Expand Down
19 changes: 17 additions & 2 deletions muutils/nbutils/configure_notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@

import matplotlib.pyplot as plt # type: ignore[import]


class PlotlyNotInstalledWarning(UserWarning):
pass


# handle plotly importing
PLOTLY_IMPORTED: bool
try:
import plotly.io as pio # type: ignore[import]
except ImportError:
warnings.warn("Plotly not installed. Plotly plots will not be available.")
warnings.warn(
"Plotly not installed. Plotly plots will not be available.",
PlotlyNotInstalledWarning,
)
PLOTLY_IMPORTED = False
else:
PLOTLY_IMPORTED = True
Expand Down Expand Up @@ -40,13 +48,20 @@
TIKZPLOTLIB_FORMATS = ["tex", "tikz"]


class UnknownFigureFormatWarning(UserWarning):
pass


def universal_savefig(fname: str, fmt: str | None = None) -> None:
# try to infer format from fname
if fmt is None:
fmt = fname.split(".")[-1]

if not (fmt in MATPLOTLIB_FORMATS or fmt in TIKZPLOTLIB_FORMATS):
warnings.warn(f"Unknown format '{fmt}', defaulting to '{FIG_OUTPUT_FMT}'")
warnings.warn(
f"Unknown format '{fmt}', defaulting to '{FIG_OUTPUT_FMT}'",
UnknownFigureFormatWarning,
)
fmt = FIG_OUTPUT_FMT

# not sure why linting is throwing an error here
Expand Down
4 changes: 2 additions & 2 deletions muutils/sysinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import typing

from pip._internal.operations.freeze import freeze
from pip._internal.operations.freeze import freeze as pip_freeze


def _popen(cmd: list[str], split_out: bool = False) -> dict[str, typing.Any]:
Expand Down Expand Up @@ -46,7 +46,7 @@ def python() -> dict:
@staticmethod
def pip() -> dict:
"""installed packages info"""
pckgs: list[str] = [x for x in freeze(local_only=True)]
pckgs: list[str] = [x for x in pip_freeze(local_only=True)]
return {
"n_packages": len(pckgs),
"packages": pckgs,
Expand Down
147 changes: 84 additions & 63 deletions tests/unit/json_serialize/serializable_dataclass/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,9 @@
import dataclasses
from dataclasses import dataclass

import numpy as np
import pytest
import torch

from muutils.json_serialize.serializable_dataclass import (
array_safe_eq,
dc_eq,
serializable_field,
)


@dataclasses.dataclass(eq=False)
class TestClass:
a: int
b: np.ndarray = serializable_field()
c: torch.Tensor = serializable_field()
e: list[int] = serializable_field()
f: dict[str, int] = serializable_field()


instance1 = TestClass(
a=1,
b=np.array([1, 2, 3]),
c=torch.tensor([1, 2, 3]),
e=[1, 2, 3],
f={"key1": 1, "key2": 2},
)

instance2 = TestClass(
a=1,
b=np.array([1, 2, 3]),
c=torch.tensor([1, 2, 3]),
e=[1, 2, 3],
f={"key1": 1, "key2": 2},
)
from muutils.json_serialize.serializable_dataclass import array_safe_eq, dc_eq


def test_array_safe_eq():
Expand All @@ -44,33 +13,85 @@ def test_array_safe_eq():
assert not array_safe_eq(torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]))


@pytest.mark.parametrize(
"instance1, instance2, expected",
[
(instance1, instance2, True),
(
instance1,
TestClass(
a=1,
b=np.array([4, 5, 6]),
c=torch.tensor([1, 2, 3]),
e=[1, 2, 3],
f={"key1": 1, "key2": 2},
),
False,
),
(
instance1,
TestClass(
a=2,
b=np.array([1, 2, 3]),
c=torch.tensor([1, 2, 3]),
e=[1, 2, 3],
f={"key1": 1, "key2": 2},
),
False,
),
],
)
def test_dc_eq(instance1, instance2, expected):
assert dc_eq(instance1, instance2) == expected
def test_dc_eq_case1():
@dataclass(eq=False)
class TestClass:
a: int
b: np.ndarray
c: torch.Tensor
e: list[int]
f: dict[str, int]

instance1 = TestClass(
a=1,
b=np.array([1, 2, 3]),
c=torch.tensor([1, 2, 3]),
e=[1, 2, 3],
f={"key1": 1, "key2": 2},
)

instance2 = TestClass(
a=1,
b=np.array([1, 2, 3]),
c=torch.tensor([1, 2, 3]),
e=[1, 2, 3],
f={"key1": 1, "key2": 2},
)

assert dc_eq(instance1, instance2)


def test_dc_eq_case2():
@dataclass(eq=False)
class TestClass:
a: int
b: np.ndarray
c: torch.Tensor
e: list[int]
f: dict[str, int]

instance1 = TestClass(
a=1,
b=np.array([1, 2, 3]),
c=torch.tensor([1, 2, 3]),
e=[1, 2, 3],
f={"key1": 1, "key2": 2},
)

instance2 = TestClass(
a=1,
b=np.array([4, 5, 6]),
c=torch.tensor([1, 2, 3]),
e=[1, 2, 3],
f={"key1": 1, "key2": 2},
)

assert not dc_eq(instance1, instance2)


def test_dc_eq_case3():
@dataclass(eq=False)
class TestClass:
a: int
b: np.ndarray
c: torch.Tensor
e: list[int]
f: dict[str, int]

instance1 = TestClass(
a=1,
b=np.array([1, 2, 3]),
c=torch.tensor([1, 2, 3]),
e=[1, 2, 3],
f={"key1": 1, "key2": 2},
)

instance2 = TestClass(
a=2,
b=np.array([1, 2, 3]),
c=torch.tensor([1, 2, 3]),
e=[1, 2, 3],
f={"key1": 1, "key2": 2},
)

assert not dc_eq(instance1, instance2)
31 changes: 30 additions & 1 deletion tests/unit/nbutils/test_configure_notebook.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import os
import warnings

import matplotlib.pyplot as plt # type: ignore[import]
import pytest
import torch

from muutils.nbutils.configure_notebook import configure_notebook, plotshow, setup_plots
from muutils.nbutils.configure_notebook import (
UnknownFigureFormatWarning,
configure_notebook,
plotshow,
setup_plots,
)

JUNK_DATA_PATH: str = "tests/junk_data/test_cfg_notebook"

Expand Down Expand Up @@ -73,6 +79,29 @@ def test_plotshow_save_mixed():
assert os.path.exists(os.path.join(JUNK_DATA_PATH, "mixedfig-3.pdf"))


def test_warn_unknown_format():
with pytest.warns(UnknownFigureFormatWarning):
setup_plots(
plot_mode="save",
fig_basepath=JUNK_DATA_PATH,
fig_numbered_fname="mixedfig-{num}",
)
plt.plot([1, 2, 3], [1, 2, 3])
plotshow()


def test_no_warn_pdf_format():
with warnings.catch_warnings():
warnings.simplefilter("error")
setup_plots(
plot_mode="save",
fig_basepath="JUNK_DATA_PATH",
fig_numbered_fname="fig-{num}.pdf",
)
plt.plot([1, 2, 3], [1, 2, 3])
plotshow()


def test_plotshow_ignore():
setup_plots(plot_mode="ignore")
plt.plot([1, 2, 3], [1, 2, 3])
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_statcounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def _compute_err(a: float, b: float, /) -> dict[str, float]:
num_a=float(a),
num_b=float(b),
diff=float(b - a),
frac_err=float((b - a) / a),
# frac_err=float((b - a) / a), # this causes division by zero, whatever
)


Expand Down Expand Up @@ -47,6 +47,7 @@ def test_statcounter() -> None:
# arrs.append(np.random.randint(i, j, size=1000))

for a in arrs:

r = _compare_np_custom(a)

assert all(
Expand Down

0 comments on commit 2615fa3

Please sign in to comment.