Skip to content

Commit

Permalink
Fix argmin/argmax for int4 types
Browse files Browse the repository at this point in the history
We used the wrong initial value, causing us to never update the candidate index.

While we are here, fix up some minor bugs involving scalar conversion from float16 and float128.

PiperOrigin-RevId: 578022918
  • Loading branch information
majnemer authored and The ml_dtypes Authors committed Oct 31, 2023
1 parent 171155f commit 979569c
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 57 deletions.
8 changes: 8 additions & 0 deletions ml_dtypes/_src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ struct TypeDescriptor<std::complex<long double>> {
static int Dtype() { return NPY_CLONGDOUBLE; }
};

// Type trait reporting whether `T` is a specialization of `std::complex`.
// The primary template matches every other type and yields `false`.
template <typename T>
struct is_complex : std::bool_constant<false> {};

// Partial specialization selected for `std::complex<U>`; yields `true`.
template <typename U>
struct is_complex<std::complex<U>> : std::bool_constant<true> {};

// Convenience variable template equivalent to `is_complex<T>::value`.
template <typename T>
inline constexpr bool is_complex_v = is_complex<T>::value;

} // namespace ml_dtypes

#endif // ML_DTYPES_COMMON_H_
11 changes: 5 additions & 6 deletions ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,12 +599,11 @@ int NPyCustomFloat_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,

template <typename T>
float CastToFloat(T value) {
return static_cast<float>(value);
}

template <typename T>
float CastToFloat(std::complex<T> value) {
return CastToFloat(value.real());
if constexpr (is_complex_v<T>) {
return CastToFloat(value.real());
} else {
return static_cast<float>(value);
}
}

// Performs a NumPy array cast from type 'From' to 'To'.
Expand Down
86 changes: 38 additions & 48 deletions ml_dtypes/_src/int4_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef ML_DTYPES_INT4_NUMPY_H_
#define ML_DTYPES_INT4_NUMPY_H_

#include <limits>
#include <type_traits>

// Must be included first
Expand Down Expand Up @@ -55,7 +56,7 @@ int Int4TypeDescriptor<T>::npy_type = NPY_NOTYPE;
template <typename T>
PyObject* Int4TypeDescriptor<T>::type_ptr = nullptr;

// Representation of a Python custom float object.
// Representation of a Python custom integer object.
template <typename T>
struct PyInt4 {
PyObject_HEAD; // Python object header
Expand Down Expand Up @@ -96,7 +97,7 @@ Safe_PyObjectPtr PyInt4_FromValue(T x) {
return ref;
}

// Converts a Python object to a reduced float value. Returns true on success,
// Converts a Python object to a reduced integer value. Returns true on success,
// returns false and reports a Python error on failure.
template <typename T>
bool CastToInt4(PyObject* arg, T* output) {
Expand Down Expand Up @@ -143,8 +144,8 @@ bool CastToInt4(PyObject* arg, T* output) {
*output = T(v);
return true;
}
if (PyArray_IsScalar(arg, Float)) {
float f;
auto floating_conversion = [&](auto type) -> bool {
decltype(type) f;
PyArray_ScalarAsCtype(arg, &f);
if (!(std::numeric_limits<T>::min() <= f &&
f <= std::numeric_limits<T>::max())) {
Expand All @@ -153,17 +154,18 @@ bool CastToInt4(PyObject* arg, T* output) {
}
*output = T(static_cast<::int8_t>(f));
return true;
};
if (PyArray_IsScalar(arg, Half)) {
return floating_conversion(Eigen::half{});
}
if (PyArray_IsScalar(arg, Float)) {
return floating_conversion(float{});
}
if (PyArray_IsScalar(arg, Double)) {
double d;
PyArray_ScalarAsCtype(arg, &d);
if (!(std::numeric_limits<T>::min() <= d &&
d <= std::numeric_limits<T>::max())) {
PyErr_SetString(PyExc_OverflowError, kOutOfRange);
return false;
}
*output = T(static_cast<::int8_t>(d));
return true;
return floating_conversion(double{});
}
if (PyArray_IsScalar(arg, LongDouble)) {
return floating_conversion((long double){});
}
return false;
}
Expand Down Expand Up @@ -216,14 +218,13 @@ PyObject* PyInt4_tp_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
template <typename T>
PyObject* PyInt4_nb_float(PyObject* self) {
T x = PyInt4_Value_Unchecked<T>(self);
return PyFloat_FromDouble(static_cast<double>(static_cast<float>(x)));
return PyFloat_FromDouble(static_cast<double>(x));
}

template <typename T>
PyObject* PyInt4_nb_int(PyObject* self) {
T x = PyInt4_Value_Unchecked<T>(self);
long y = static_cast<long>(static_cast<float>(x)); // NOLINT
return PyLong_FromLong(y);
return PyLong_FromLong(static_cast<long>(x)); // NOLINT
}

template <typename T>
Expand Down Expand Up @@ -538,12 +539,11 @@ int NPyInt4_CompareFunc(const void* v1, const void* v2, void* arr) {
template <typename T>
int NPyInt4_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind, void* arr) {
const T* bdata = reinterpret_cast<const T*>(data);
// Start with a max_val of NaN, this results in the first iteration preferring
// bdata[0].
int max_val = std::numeric_limits<int>::max();
// Start with a max_val of INT_MIN, this results in the first iteration
// preferring bdata[0].
int max_val = std::numeric_limits<int>::lowest();
for (npy_intp i = 0; i < n; ++i) {
// This condition is chosen so that NaNs are always considered "max".
if (!(static_cast<int>(bdata[i]) <= max_val)) {
if (static_cast<int>(bdata[i]) > max_val) {
max_val = static_cast<int>(bdata[i]);
*max_ind = i;
}
Expand All @@ -554,43 +554,33 @@ int NPyInt4_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind, void* arr) {
template <typename T>
int NPyInt4_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind, void* arr) {
const T* bdata = reinterpret_cast<const T*>(data);
int min_val = std::numeric_limits<int>::lowest();
// Start with a min_val of NaN, this results in the first iteration preferring
// bdata[0].
int min_val = std::numeric_limits<int>::max();
// Start with a min_val of INT_MAX, this results in the first iteration
// preferring bdata[0].
for (npy_intp i = 0; i < n; ++i) {
// This condition is chosen so that NaNs are always considered "min".
if (!(static_cast<int>(bdata[i]) >= min_val)) {
if (static_cast<int>(bdata[i]) < min_val) {
min_val = static_cast<int>(bdata[i]);
*min_ind = i;
}
}
return 0;
}

template <typename T, std::enable_if_t<(std::is_floating_point<T>::value ||
std::is_same<T, Eigen::half>::value),
bool> = true>
template <typename T>
int CastToInt(T value) {
if (std::isnan(value) || std::isinf(value) ||
value < std::numeric_limits<int>::lowest() ||
value > std::numeric_limits<int>::max()) {
return 0;
if constexpr (is_complex_v<T>) {
return CastToInt(value.real());
} else {
static_assert(std::numeric_limits<T>::is_specialized);
if constexpr (!std::numeric_limits<T>::is_integer) {
if (std::isnan(value) || std::isinf(value) ||
value < std::numeric_limits<int>::lowest() ||
value > std::numeric_limits<int>::max()) {
return 0;
}
}
return static_cast<int>(value);
}
return static_cast<int>(value);
}

template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
int CastToInt(T value) {
return static_cast<int>(value);
}

int CastToInt(int4 value) { return static_cast<int>(value); }

int CastToInt(uint4 value) { return static_cast<int>(value); }

template <typename T>
int CastToInt(std::complex<T> value) {
return CastToInt(value.real());
}

// Performs a NumPy array cast from type 'From' to 'To'.
Expand Down
13 changes: 10 additions & 3 deletions ml_dtypes/tests/int4_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ def testPickleable(self, scalar_type):
self.assertEqual(x_out.dtype, x.dtype)
np.testing.assert_array_equal(x_out.astype(int), x.astype(int))

@parameterized.product(scalar_type=INT4_TYPES, python_scalar=[int, float])
@parameterized.product(
scalar_type=INT4_TYPES,
python_scalar=[int, float, np.float16, np.longdouble],
)
def testRoundTripToPythonScalar(self, scalar_type, python_scalar):
for v in VALUES[scalar_type]:
self.assertEqual(v, scalar_type(v))
Expand Down Expand Up @@ -241,12 +244,16 @@ def testArray(self, scalar_type):

@parameterized.product(
scalar_type=INT4_TYPES,
ufunc=[np.nonzero, np.logical_not],
ufunc=[np.nonzero, np.logical_not, np.argmax, np.argmin],
)
def testUnaryPredicateUfunc(self, scalar_type, ufunc):
x = np.array(VALUES[scalar_type])
y = np.array(VALUES[scalar_type], dtype=scalar_type)
np.testing.assert_array_equal(ufunc(x), ufunc(y))
# Compute `ufunc(y)` first so we don't get lucky by reusing memory
# initialized by `ufunc(x)`.
y_result = ufunc(y)
x_result = ufunc(x)
np.testing.assert_array_equal(x_result, y_result)

@parameterized.product(
scalar_type=INT4_TYPES,
Expand Down

0 comments on commit 979569c

Please sign in to comment.