Skip to content

Commit

Permalink
working on more comprehensive tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mivanit committed Jun 9, 2024
1 parent 11c5eb3 commit 4b87c17
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 18 deletions.
6 changes: 3 additions & 3 deletions docs/coverage/coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 12 additions & 9 deletions docs/coverage/coverage.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
Name Stmts Miss Cover Missing
---------------------------------------------------------------------------------------------------------------
muutils\__init__.py 0 0 100%
muutils\dictmagic.py 153 21 86% 14-19, 22-25, 177, 285, 432, 460-472
muutils\dictmagic.py 162 23 86% 14-19, 22-25, 177, 285, 445-449, 453, 486-498
muutils\group_equiv.py 28 0 100%
muutils\json_serialize\__init__.py 5 0 100%
muutils\json_serialize\array.py 80 19 76% 18-21, 74, 102, 114, 118, 121, 125, 129, 136, 143, 155, 174-179, 184, 187
muutils\json_serialize\dataclass_factories.py 115 115 0% 1-303
muutils\json_serialize\json_serialize.py 64 18 72% 10-13, 80, 244, 271-286, 295-298
muutils\json_serialize\serializable_dataclass.py 178 35 80% 69, 85, 141, 147, 153, 160-162, 171-174, 208-223, 233, 237, 240, 243, 249, 259-260, 264, 273-277, 281, 302, 305, 337, 363, 379-380, 403, 430, 436
muutils\json_serialize\util.py 76 41 46% 11-13, 24, 33, 36, 44-50, 59, 67-74, 85-88, 94-104, 116-119, 123-126
muutils\json_serialize\util.py 76 16 79% 11-13, 24, 33, 36, 44-50, 73, 102, 125-126
muutils\jsonlines.py 31 31 0% 1-73
muutils\kappa.py 14 0 100%
muutils\logger\__init__.py 5 0 100%
Expand All @@ -19,25 +18,29 @@ muutils\logger\logger.py
muutils\logger\loggingstream.py 39 12 69% 41-74, 79, 89-90
muutils\logger\simplelogger.py 40 19 52% 14, 18, 22, 26, 53-63, 67-79
muutils\logger\timing.py 39 19 51% 25-28, 41-46, 50-52, 65-68, 79-84
muutils\misc.py 54 43 20% 8-11, 16-28, 38-53, 59-64, 87-97, 103-112, 126-134
muutils\misc.py 164 12 93% 155, 186, 227-229, 288, 297, 300, 320-321, 334, 373-374
muutils\mlutils.py 66 17 74% 16-20, 29, 35-49, 54, 56, 65-70, 141-142, 153
muutils\nbutils\__init__.py 2 0 100%
muutils\nbutils\configure_notebook.py 128 46 64% 15, 22-23, 58-64, 81, 91-92, 97, 100-101, 121-124, 129, 135-144, 151-156, 204-215, 221-223, 249-256, 259
muutils\nbutils\convert_ipynb_to_script.py 118 53 55% 63, 78, 91, 105-139, 169, 212-236, 263, 287-289, 296-349
muutils\nbutils\convert_ipynb_to_script.py 118 41 65% 63, 78, 91, 105-139, 228-230, 236, 263, 287-289, 296-349
muutils\nbutils\mermaid.py 11 7 36% 5-8, 15-18
muutils\nbutils\print_tex.py 10 10 0% 1-19
muutils\nbutils\run_notebook_tests.py 58 20 66% 29, 31, 35, 39, 45, 53, 80-82, 86-90, 97-114
muutils\statcounter.py 87 32 63% 24-35, 50, 70, 98, 110, 120, 136-166, 174, 183, 186, 190-195, 204
muutils\sysinfo.py 71 18 75% 21, 60-61, 78-111, 145, 168
muutils\tensor_utils.py 125 39 69% 76, 79, 84, 96-136, 143, 157-165, 352, 418, 429, 443
muutils\tensor_utils.py 125 19 85% 82, 85, 105, 109, 119, 130, 133-136, 144, 151, 165-173
tests\unit\json_serialize\serializable_dataclass\test_sdc_defaults.py 31 0 100%
tests\unit\json_serialize\serializable_dataclass\test_sdc_properties_nested.py 26 0 100%
tests\unit\json_serialize\serializable_dataclass\test_serializable_dataclass.py 190 0 100%
tests\unit\json_serialize\test_array.py 40 0 100%
tests\unit\json_serialize\test_util.py 49 2 96% 62, 69
tests\unit\logger\test_logger.py 10 0 100%
tests\unit\logger\test_timer_context.py 9 0 100%
tests\unit\misc\test_freeze.py 120 0 100%
tests\unit\misc\test_misc.py 43 0 100%
tests\unit\misc\test_numerical_conversions.py 42 0 100%
tests\unit\nbutils\test_configure_notebook.py 49 0 100%
tests\unit\nbutils\test_conversion.py 20 0 100%
tests\unit\nbutils\test_conversion.py 26 0 100%
tests\unit\test_chunks.py 31 0 100%
tests\unit\test_dictmagic.py 129 0 100%
tests\unit\test_group_equiv.py 12 0 100%
Expand All @@ -46,6 +49,6 @@ tests\unit\test_kappa.py
tests\unit\test_mlutils.py 35 3 91% 31, 35, 43
tests\unit\test_statcounter.py 13 0 100%
tests\unit\test_sysinfo.py 4 0 100%
tests\unit\test_tensor_utils.py 36 0 100%
tests\unit\test_tensor_utils.py 48 0 100%
---------------------------------------------------------------------------------------------------------------
TOTAL 2434 682 72%
TOTAL 2710 483 82%
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any

from muutils.json_serialize import JSONitem
from muutils.json_serialize.dataclass_factories import (
from muutils._wip.dataclass_factories import (
dataclass_loader_factory,
dataclass_serializer_factory,
)
Expand Down
16 changes: 12 additions & 4 deletions muutils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def jaxtype_factory(
default_jax_dtype=jaxtyping.Float,
legacy_mode: typing.Literal["error", "warn", "ignore"] = "warn",
) -> type:
"""usage:
```
ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
```
"""
class _BaseArray:
"""jaxtyping shorthand
(backwards compatible with older versions of muutils.tensor_utils)
Expand All @@ -81,13 +87,15 @@ def __init_subclass__(cls, *args, **kwargs):
@classmethod
def param_info(cls, params) -> str:
"""useful for error printing"""
return str(
return "\n".join(
f"{k} = {v}"
for k, v in
{
"cls.__name__": cls.__name__,
"cls.__doc__": cls.__doc__,
"params": params,
"type(params)": type(params),
}
}.items()
)

@typing._tp_cache # type: ignore
Expand All @@ -99,7 +107,7 @@ def __class_getitem__(cls, params: str | tuple) -> type:
elif isinstance(params, tuple):
if len(params) != 2:
raise Exception(
f"unexpected type for params:\n{cls.param_info(params)}"
f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
)

if isinstance(params[0], str):
Expand All @@ -126,7 +134,7 @@ def __class_getitem__(cls, params: str | tuple) -> type:
shape_anot.append("".join(str(y) for y in x))
else:
raise Exception(
f"unexpected type for params:\n{cls.param_info(params)}"
f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
)

return TYPE_TO_JAX_DTYPE[params[1]][
Expand Down
75 changes: 75 additions & 0 deletions tests/unit/json_serialize/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import NamedTuple
from collections import namedtuple

import pytest

# Module code assumed to be imported from my_module
from muutils.json_serialize.util import (
UniversalContainer,
isinstance_namedtuple,
try_catch,
_recursive_hashify,
SerializationException,
string_as_lines,
safe_getsource,
)

def test_universal_container():
uc = UniversalContainer()
assert 'anything' in uc
assert 123 in uc
assert None in uc

def test_isinstance_namedtuple():
Point = namedtuple('Point', ['x', 'y'])
p = Point(1, 2)
assert isinstance_namedtuple(p)
assert not isinstance_namedtuple((1, 2))

class Point2(NamedTuple):
x: int
y: int
p2 = Point2(1, 2)
assert isinstance_namedtuple(p2)



def test_try_catch():
@try_catch
def raises_value_error():
raise ValueError("test error")

@try_catch
def normal_func(x):
return x

assert raises_value_error() == "ValueError: test error"
assert normal_func(10) == 10

def test_recursive_hashify():
assert _recursive_hashify({"a": [1, 2, 3]}) == (('a', (1, 2, 3)),)
assert _recursive_hashify([1, 2, 3]) == (1, 2, 3)
assert _recursive_hashify(123) == 123
with pytest.raises(ValueError):
_recursive_hashify(object(), force=False)

def test_string_as_lines():
assert string_as_lines("line1\nline2\nline3") == ["line1", "line2", "line3"]
assert string_as_lines(None) == []

def test_safe_getsource():
def sample_func():
pass

source = safe_getsource(sample_func)
print(f"Source of sample_func: {source}")
assert "def sample_func():" in source[0]

def raises_error():
raise Exception("test error")

wrapped_func = try_catch(raises_error)
error_source = safe_getsource(wrapped_func)
print(f"Source of wrapped_func: {error_source}")
# Check for the original function's source since the decorator doesn't change this
assert any("def raises_error():" in line for line in error_source)
25 changes: 24 additions & 1 deletion tests/unit/nbutils/test_conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import itertools

from muutils.nbutils.convert_ipynb_to_script import process_dir
import pytest

from muutils.nbutils.convert_ipynb_to_script import process_dir, process_file
from muutils.nbutils.run_notebook_tests import run_notebook_tests

notebooks_input_dir: str = "tests/input_data/notebooks"
Expand Down Expand Up @@ -37,3 +40,23 @@ def test_run_notebook_tests():
with open(os.path.join(test_output_dir, fname), "r") as f:
actual = f.read()
assert expected == actual

@pytest.mark.parametrize(
"idx, args",
enumerate(itertools.product(
[True, False],
[r"#%%", "#"+"="*50],
[True, False],
["%", ("!", "#"), ("import", "return")],
)),
)
def test_file_conversion(idx, args):
os.makedirs(test_output_dir, exist_ok=True)
process_file(
in_file=os.path.join(notebooks_input_dir, "dummy_notebook.ipynb"),
out_file=os.path.join(test_output_dir, f"dn-test-{idx}.py"),
strip_md_cells=args[0],
header_comment=args[1],
disable_plots=args[2],
filter_out_lines=args[3],
)
21 changes: 21 additions & 0 deletions tests/unit/test_tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from muutils.tensor_utils import (
DTYPE_MAP,
TORCH_DTYPE_MAP,
StateDictKeysError,
StateDictShapeError,
compare_state_dicts,
get_dict_shapes,
jaxtype_factory,
lpad_array,
lpad_tensor,
Expand All @@ -24,6 +27,9 @@ def test_jaxtype_factory():
assert "default_jax_dtype = <class 'jaxtyping.Float'" in ATensor.__doc__
assert "array_type = <class 'torch.Tensor'>" in ATensor.__doc__

x = ATensor[(1, 2, 3), np.float32]
x = ATensor["dim1 dim2", np.float32]


def test_numpy_to_torch_dtype():
assert numpy_to_torch_dtype(np.float32) == torch.float32
Expand Down Expand Up @@ -60,3 +66,18 @@ def test_compare_state_dicts():
d2["a"] = torch.tensor([7, 8, 9])
with pytest.raises(AssertionError):
compare_state_dicts(d1, d2) # This should raise an exception

d2["a"] = torch.tensor([7, 8, 9, 10])
with pytest.raises(StateDictShapeError):
compare_state_dicts(d1, d2) # This should raise an exception

d2["c"] = torch.tensor([10, 11, 12])
with pytest.raises(StateDictKeysError):
compare_state_dicts(d1, d2) # This should raise an exception


def test_get_dict_shapes():

x = {"a": torch.rand(2, 3), "b": torch.rand(1, 3, 5), "c": torch.rand(2)}
x_shapes = get_dict_shapes(x)
assert x_shapes == {"a": (2, 3), "b": (1, 3, 5), "c": (2,)}

0 comments on commit 4b87c17

Please sign in to comment.