diff --git a/stumpy/core.py b/stumpy/core.py index 4cdaea02a..fdbc9aefd 100644 --- a/stumpy/core.py +++ b/stumpy/core.py @@ -429,11 +429,11 @@ def check_dtype(a, dtype=np.float64): # pragma: no cover TypeError If the array type does not match `dtype` """ - if dtype == int: + if dtype is int: dtype = np.int64 - if dtype == float: + if dtype is float: dtype = np.float64 - if dtype == bool: + if dtype is bool: dtype = np.bool_ if not np.issubdtype(a.dtype, dtype): msg = f"{dtype} dtype expected but found {a.dtype} in input array\n"