Skip to content

Commit

Permalink
mypy fixes, universal_flatten minor change
Browse files Browse the repository at this point in the history
universal_flatten now always returns a numeric sequence:
if the input is a scalar (or at least not a sequence) it will be put into a list
  • Loading branch information
mivanit committed Jul 22, 2023
1 parent 05656e1 commit 40fe8f0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions muutils/json_serialize/json_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
try:
from muutils.json_serialize.array import ArrayMode, serialize_array
except ImportError as e:
ArrayMode = None
serialize_array = None
ArrayMode = str # type: ignore[misc]
serialize_array = lambda *args, **kwargs: None
warnings.warn(
f"muutils.json_serialize.array could not be imported probably because missing numpy, array serialization will not work: \n{e}",
ImportWarning,
Expand Down
14 changes: 7 additions & 7 deletions muutils/statcounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import cached_property
from itertools import chain
from types import NoneType
from typing import Callable, Iterable, Optional, Sequence, Union
from typing import Callable, Optional, Sequence, Union

# _GeneralArray = Union[np.ndarray, "torch.Tensor"]
NumericSequence = Sequence[Union[float, int]]
Expand All @@ -16,23 +16,23 @@


def universal_flatten(
arr: NumericSequence, require_rectangular: bool = True
arr: NumericSequence | float | int, require_rectangular: bool = True
) -> NumericSequence:
"""flattens any iterable"""

# mypy complains that the sequence has no attribute "flatten"
if hasattr(arr, "flatten") and callable(arr.flatten): # type: ignore
return arr.flatten() # type: ignore
elif not isinstance(arr, Iterable):
return arr
else:
elements_iterable: list[bool] = [isinstance(x, Iterable) for x in arr]
elif isinstance(arr, Sequence):
elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr]
if require_rectangular and (all(elements_iterable) != any(elements_iterable)):
raise ValueError("arr contains mixed iterable and non-iterable elements")
if any(elements_iterable):
return list(chain.from_iterable(universal_flatten(x) for x in arr))
return list(chain.from_iterable(universal_flatten(x) for x in arr)) # type: ignore[misc]
else:
return arr
else:
return [arr]


# StatCounter
Expand Down

0 comments on commit 40fe8f0

Please sign in to comment.