Skip to content

Commit

Permalink
add NumpyEncoder when using json.dump in UnifyConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
kervias committed Dec 4, 2023
1 parent b01e624 commit 4368f53
Showing 1 changed file with 37 additions and 2 deletions.
39 changes: 37 additions & 2 deletions edustudio/utils/common/configUtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import types
import re
from functools import reduce
import numpy as np


class UnifyConfig(object):
Expand Down Expand Up @@ -136,14 +137,14 @@ def __repr__(self):
def dump_fmt(self):
return json.dumps(
self.__config__, indent=4, ensure_ascii=False,
default=lambda o: o.__config__ if isinstance(o, UnifyConfig) else str(o)
cls=NumpyEncoder
)

def dump_file(self, filepath: str, encoding: str = 'utf-8'):
with open(filepath, "w", encoding=encoding) as f:
json.dump(
self.__config__, f, indent=4, ensure_ascii=False,
default=lambda o: o.__config__ if isinstance(o, UnifyConfig) else str(o)
cls=NumpyEncoder
)

def __copy__(self):
Expand All @@ -153,3 +154,37 @@ def __copy__(self):
def __deepcopy__(self, memo: Any):
cls = self.__class__
return cls(dic=copy.deepcopy(self.__config__, memo=memo))


class NumpyEncoder(json.JSONEncoder):
"""
Custom encoder for numpy data types
Ref: https://github.com/hmallen/numpyencoder/blob/f8199a61ccde25f829444a9df4b21bcb2d1de8f2/numpyencoder/numpyencoder.py
"""

def default(self, obj):
if isinstance(obj, UnifyConfig):
return obj.__config__
elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
np.int16, np.int32, np.int64, np.uint8,
np.uint16, np.uint32, np.uint64)):
return int(obj)

elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
return float(obj)

elif isinstance(obj, (np.complex_, np.complex64, np.complex128)):
return {'real': obj.real, 'imag': obj.imag}

elif isinstance(obj, (np.ndarray,)):
return obj.tolist()

elif isinstance(obj, (np.bool_)):
return bool(obj)

elif isinstance(obj, (np.void)):
return None
try:
return json.JSONEncoder.default(self, obj)
except TypeError:
return str(obj)

0 comments on commit 4368f53

Please sign in to comment.