Skip to content

Commit

Permalink
Merge pull request #192 from vfdev-5:fix-_finfo_cache-init
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675182623
  • Loading branch information
The ml_dtypes Authors committed Sep 16, 2024
2 parents 34f5c29 + 522e963 commit b65a1f6
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

"""Overload of numpy.finfo to handle dtypes defined in ml_dtypes."""

from typing import Dict

from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float4_e2m1fn
from ml_dtypes._ml_dtypes_ext import float6_e2m3fn
Expand Down Expand Up @@ -154,7 +152,6 @@ def __init__(self):

class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[type, np.finfo] = {} # pylint: disable=g-bare-generic

@staticmethod
def _bfloat16_finfo():
Expand Down Expand Up @@ -699,6 +696,9 @@ def float_to_str(f):
_float8_e8m0fnu_dtype: _float8_e8m0fnu_finfo,
}
_finfo_name_map = {t.name: t for t in _finfo_type_map}
_finfo_cache = {
t: init_fn.__func__() for t, init_fn in _finfo_type_map.items() # pytype: disable=attribute-error
}

def __new__(cls, dtype):
if isinstance(dtype, str):
Expand All @@ -710,9 +710,4 @@ def __new__(cls, dtype):
i = cls._finfo_cache.get(key)
if i is not None:
return i

init = cls._finfo_type_map.get(key)
if init is not None:
cls._finfo_cache[dtype] = init.__func__() # pytype: disable=attribute-error
return cls._finfo_cache[dtype]
return super().__new__(cls, dtype)

0 comments on commit b65a1f6

Please sign in to comment.