Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mivanit committed Jun 18, 2024
1 parent d50a38d commit 695aed1
Show file tree
Hide file tree
Showing 31 changed files with 93 additions and 52 deletions.
2 changes: 1 addition & 1 deletion muutils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from __future__ import annotations
from __future__ import annotations
15 changes: 14 additions & 1 deletion muutils/dictmagic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
from __future__ import annotations

import typing
import warnings
from collections import defaultdict
from typing import Any, Callable, Generic, Hashable, Iterable, Literal, TypeVar, Dict, Union, Optional, Tuple
from typing import (
Any,
Callable,
Dict,
Generic,
Hashable,
Iterable,
Literal,
Optional,
Tuple,
TypeVar,
Union,
)

_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
Expand Down
1 change: 1 addition & 0 deletions muutils/group_equiv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

from itertools import chain
from typing import Callable, Sequence, TypeVar

Expand Down
1 change: 1 addition & 0 deletions muutils/json_serialize/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

from muutils.json_serialize.array import arr_metadata, load_array
from muutils.json_serialize.json_serialize import (
BASE_HANDLERS,
Expand Down
1 change: 1 addition & 0 deletions muutils/json_serialize/array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import typing
import warnings
from typing import Any, Iterable, Literal, Optional, Sequence
Expand Down
6 changes: 3 additions & 3 deletions muutils/json_serialize/json_serialize.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import inspect
import types
import warnings
from dataclasses import dataclass, is_dataclass
from pathlib import Path
from typing import Any, Callable, Iterable, Mapping, Dict, Set, Union
from typing import Any, Callable, Dict, Iterable, Mapping, Set, Union

try:
from muutils.json_serialize.array import ArrayMode, serialize_array
Expand Down Expand Up @@ -54,7 +54,7 @@
"<class 'torch.dtype'>",
}

ObjectPath = MonoTuple[Union[str,int]]
ObjectPath = MonoTuple[Union[str, int]]


@dataclass
Expand Down
20 changes: 8 additions & 12 deletions muutils/json_serialize/serializable_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import abc
import dataclasses
import json
Expand Down Expand Up @@ -38,8 +39,7 @@ def __init__(
self,
default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
default_factory: Union[
Callable[[], Any],
dataclasses._MISSING_TYPE
Callable[[], Any], dataclasses._MISSING_TYPE
] = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
Expand Down Expand Up @@ -74,11 +74,9 @@ def __init__(
super_kwargs["metadata"] = types.MappingProxyType({})

# special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
if (sys.version_info[1] < 9):
if (super_kwargs["kw_only"] == True): # noqa: E712
raise ValueError(
"kw_only is not supported in python >=3.9"
)
if sys.version_info[1] < 9:
if super_kwargs["kw_only"] == True: # noqa: E712
raise ValueError("kw_only is not supported in python >=3.9")
else:
del super_kwargs["kw_only"]

Expand Down Expand Up @@ -433,12 +431,10 @@ def wrap(cls: Type[T]) -> Type[T]:
setattr(cls, field_name, field_value)

# special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
if (sys.version_info[1] < 9):
if sys.version_info[1] < 9:
if "kw_only" in kwargs:
if (kwargs["kw_only"] == True): # noqa: E712
raise ValueError(
"kw_only is not supported in python >=3.9"
)
if kwargs["kw_only"] == True: # noqa: E712
raise ValueError("kw_only is not supported in python >=3.9")
else:
del kwargs["kw_only"]

Expand Down
8 changes: 3 additions & 5 deletions muutils/json_serialize/util.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import functools
import inspect
import sys
import types
import typing
import warnings
from typing import Any, Callable, Iterable, Literal, Union, Dict
from typing import Any, Callable, Dict, Iterable, Literal, Union

_NUMPY_WORKING: bool
try:
Expand Down Expand Up @@ -54,9 +54,7 @@ def __class_getitem__(cls, params):
elif len(params) == 1:
return typing.GenericAlias(tuple, (params[0], Ellipsis))
else:
raise TypeError(
f"MonoTuple expects 1 type argument, got {params = }"
)
raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")


class UniversalContainer:
Expand Down
1 change: 1 addition & 0 deletions muutils/logger/headerfuncs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import json
from typing import Any, Mapping, Protocol

Expand Down
1 change: 1 addition & 0 deletions muutils/logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
this was mostly made with training models in mind and storing both metadata and loss
- `TimerContext` is a context manager that can be used to time the duration of a block of code
"""

from __future__ import annotations

import json
Expand Down
1 change: 1 addition & 0 deletions muutils/logger/loggingstream.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import time
from dataclasses import dataclass, field
from typing import Any, Callable
Expand Down
1 change: 1 addition & 0 deletions muutils/logger/simplelogger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import json
import sys
import time
Expand Down
1 change: 1 addition & 0 deletions muutils/logger/timing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import time
from typing import Literal

Expand Down
46 changes: 26 additions & 20 deletions muutils/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import hashlib
import typing

Expand Down Expand Up @@ -65,37 +66,37 @@ def list_join(lst: list, factory: typing.Callable) -> list:
# name stuff
# ================================================================================


def sanitize_name(
name: str | None,
additional_allowed_chars: str = "",
replace_invalid: str = "",
when_none: str | None = "_None_",
leading_digit_prefix: str = "",
) -> str:
name: str | None,
additional_allowed_chars: str = "",
replace_invalid: str = "",
when_none: str | None = "_None_",
leading_digit_prefix: str = "",
) -> str:
"""sanitize a string, leaving only alphanumerics and `additional_allowed_chars`
# Parameters:
- `name : str | None`
- `name : str | None`
input string
- `additional_allowed_chars : str`
- `additional_allowed_chars : str`
additional characters to allow, none by default
(defaults to `""`)
- `replace_invalid : str`
- `replace_invalid : str`
character to replace invalid characters with
(defaults to `""`)
- `when_none : str | None`
- `when_none : str | None`
string to return if `name` is `None`. if `None`, raises an exception
(defaults to `"_None_"`)
- `leading_digit_prefix : str`
- `leading_digit_prefix : str`
character to prefix the string with if it starts with a digit
(defaults to `""`)
# Returns:
- `str`
- `str`
sanitized string
"""
"""


if name is None:
if when_none is None:
raise ValueError("name is None")
Expand All @@ -110,26 +111,31 @@ def sanitize_name(
sanitized += char
else:
sanitized += replace_invalid

if sanitized[0].isdigit():
sanitized = leading_digit_prefix + sanitized

return sanitized


def sanitize_fname(fname: str | None, **kwargs) -> str:
"""sanitize a filename to posix standards
- leave only alphanumerics, `_` (underscore), '-' (dash) and `.` (period)
"""
return sanitize_name(fname, additional_allowed_chars="._-", **kwargs)


def sanitize_identifier(fname: str | None, **kwargs) -> str:
"""sanitize an identifier (variable or function name)
- leave only alphanumerics and `_` (underscore)
- prefix with `_` if it starts with a digit
"""
return sanitize_name(fname, additional_allowed_chars="_", leading_digit_prefix="_", **kwargs)
return sanitize_name(
fname, additional_allowed_chars="_", leading_digit_prefix="_", **kwargs
)


def dict_to_filename(
data: dict,
Expand Down
6 changes: 4 additions & 2 deletions muutils/mlutils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import json
import os
import random
import typing
import warnings
from itertools import islice
from pathlib import Path
from typing import Any, Callable, TypeVar, Union, Optional
from typing import Any, Callable, Optional, TypeVar, Union

ARRAY_IMPORTS: bool
try:
Expand Down Expand Up @@ -138,12 +139,13 @@ def register_method(
def decorator(method: F) -> F:
method_name: str
if custom_name is None:
method_name: str|None = getattr(method, "__name__", None)
method_name: str | None = getattr(method, "__name__", None)
if method_name is None:
warnings.warn(
f"Method {method} does not have a name, using sanitized repr"
)
from muutils.misc import sanitize_identifier

method_name = sanitize_identifier(repr(method))
else:
method_name = custom_name
Expand Down
1 change: 1 addition & 0 deletions muutils/nbutils/configure_notebook.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import os
import typing
import warnings
Expand Down
1 change: 1 addition & 0 deletions muutils/nbutils/convert_ipynb_to_script.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import argparse
import json
import os
Expand Down
3 changes: 2 additions & 1 deletion muutils/statcounter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import json
import math
from collections import Counter
Expand Down Expand Up @@ -54,7 +55,7 @@ def min(self):

def max(self):
return max(x for x, v in self.items() if v > 0)

def total(self):
"""Sum of the counts"""
return sum(self.values())
Expand Down
4 changes: 3 additions & 1 deletion muutils/sysinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from pip._internal.operations.freeze import freeze as pip_freeze


def _popen(cmd: typing.List[str], split_out: bool = False) -> typing.Dict[str, typing.Any]:
def _popen(
cmd: typing.List[str], split_out: bool = False
) -> typing.Dict[str, typing.Any]:
p: subprocess.Popen = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
Expand Down
4 changes: 3 additions & 1 deletion muutils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,9 @@ def rpad_array(
return pad_array(array, pad_length, pad_value, rpad=True)


def get_dict_shapes(d: typing.Dict[str, "torch.Tensor"]) -> typing.Dict[str, typing.Tuple[int, ...]]:
def get_dict_shapes(
d: typing.Dict[str, "torch.Tensor"]
) -> typing.Dict[str, typing.Tuple[int, ...]]:
"""given a state dict or cache dict, compute the shapes and put them in a nested dict"""
return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass

import numpy as np
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict

from muutils.json_serialize import (
JsonSerializer,
SerializableDataclass,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import sys

import pytest
Expand All @@ -9,6 +10,7 @@

print(f"{SUPPORS_KW_ONLY = }")


@serializable_dataclass
class Person(SerializableDataclass):
first_name: str
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

from typing import Any

import pytest
Expand Down
1 change: 1 addition & 0 deletions tests/unit/misc/test_freeze.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import pytest

from muutils.misc import freeze
Expand Down
Loading

0 comments on commit 695aed1

Please sign in to comment.