From 34c0552d5f18789200bb8db4ca26bcee3d8ea0c3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 13 Sep 2024 15:28:33 +0000 Subject: [PATCH] Use PyType_FromSpecWithBases to construct the scalar type objects. This is a simpler and more stable API for manufacturing a type. --- ml_dtypes/_src/custom_float.h | 120 ++++++++++++---------------------- ml_dtypes/_src/intn_numpy.h | 119 +++++++++++++-------------------- 2 files changed, 86 insertions(+), 153 deletions(-) diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index a6c94d97..2786fde0 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -61,7 +61,8 @@ struct CustomFloatType { // registered by another system into NumPy. static PyObject* type_ptr; - static PyNumberMethods number_methods; + static PyType_Spec type_spec; + static PyType_Slot type_slots[]; static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; @@ -242,47 +243,6 @@ PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) { return PyArray_Type.tp_as_number->nb_true_divide(a, b); } -// Python number methods for PyCustomFloat objects. -template -PyNumberMethods CustomFloatType::number_methods = { - PyCustomFloat_Add, // nb_add - PyCustomFloat_Subtract, // nb_subtract - PyCustomFloat_Multiply, // nb_multiply - nullptr, // nb_remainder - nullptr, // nb_divmod - nullptr, // nb_power - PyCustomFloat_Negative, // nb_negative - nullptr, // nb_positive - nullptr, // nb_absolute - nullptr, // nb_nonzero - nullptr, // nb_invert - nullptr, // nb_lshift - nullptr, // nb_rshift - nullptr, // nb_and - nullptr, // nb_xor - nullptr, // nb_or - PyCustomFloat_Int, // nb_int - nullptr, // reserved - PyCustomFloat_Float, // nb_float - - nullptr, // nb_inplace_add - nullptr, // nb_inplace_subtract - nullptr, // nb_inplace_multiply - nullptr, // nb_inplace_remainder - nullptr, // nb_inplace_power - nullptr, // nb_inplace_lshift - nullptr, // nb_inplace_rshift - nullptr, // nb_inplace_and - nullptr, // nb_inplace_xor - nullptr, // nb_inplace_or - - nullptr, // nb_floor_divide - PyCustomFloat_TrueDivide, // nb_true_divide - nullptr, // nb_inplace_floor_divide - nullptr, // nb_inplace_true_divide - nullptr, // nb_index -}; - // Constructs a new PyCustomFloat. template PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args, @@ -401,6 +361,34 @@ Py_hash_t PyCustomFloat_Hash(PyObject* self) { return HashImpl(&_Py_HashDouble, self, static_cast(x)); } +template +PyType_Slot CustomFloatType::type_slots[] = { + {Py_tp_new, reinterpret_cast(PyCustomFloat_New)}, + {Py_tp_repr, reinterpret_cast(PyCustomFloat_Repr)}, + {Py_tp_hash, reinterpret_cast(PyCustomFloat_Hash)}, + {Py_tp_str, reinterpret_cast(PyCustomFloat_Str)}, + {Py_tp_doc, + reinterpret_cast(const_cast(TypeDescriptor::kTpDoc))}, + {Py_tp_richcompare, reinterpret_cast(PyCustomFloat_RichCompare)}, + {Py_nb_add, reinterpret_cast(PyCustomFloat_Add)}, + {Py_nb_subtract, reinterpret_cast(PyCustomFloat_Subtract)}, + {Py_nb_multiply, reinterpret_cast(PyCustomFloat_Multiply)}, + {Py_nb_negative, reinterpret_cast(PyCustomFloat_Negative)}, + {Py_nb_int, reinterpret_cast(PyCustomFloat_Int)}, + {Py_nb_float, reinterpret_cast(PyCustomFloat_Float)}, + {Py_nb_true_divide, reinterpret_cast(PyCustomFloat_TrueDivide)}, + {0, nullptr}, +}; + +template +PyType_Spec CustomFloatType::type_spec = { + /*.name=*/TypeDescriptor::kQualifiedTypeName, + /*.basicsize=*/static_cast(sizeof(PyCustomFloat)), + /*.itemsize=*/0, + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + /*.slots=*/CustomFloatType::type_slots, +}; + // Numpy support template PyArray_ArrFuncs CustomFloatType::arr_funcs; @@ -874,46 +862,22 @@ bool RegisterFloatUFuncs(PyObject* numpy) { template bool RegisterFloatDtype(PyObject* numpy) { - // TODO(jakevdp): simplify this; we no longer need heap allocation. - Safe_PyObjectPtr name = - make_safe(PyUnicode_FromString(TypeDescriptor::kTypeName)); - Safe_PyObjectPtr qualname = - make_safe(PyUnicode_FromString(TypeDescriptor::kTypeName)); - - PyHeapTypeObject* heap_type = reinterpret_cast( - PyType_Type.tp_alloc(&PyType_Type, 0)); - if (!heap_type) { - return false; - } - // Caution: we must not call any functions that might invoke the GC until - // PyType_Ready() is called. Otherwise the GC might see a half-constructed - // type object. - heap_type->ht_name = name.release(); - heap_type->ht_qualname = qualname.release(); - PyTypeObject* type = &heap_type->ht_type; - type->tp_name = TypeDescriptor::kTypeName; - type->tp_basicsize = sizeof(PyCustomFloat); - type->tp_flags = - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; - type->tp_base = &PyGenericArrType_Type; - type->tp_new = PyCustomFloat_New; - type->tp_repr = PyCustomFloat_Repr; - type->tp_hash = PyCustomFloat_Hash; - type->tp_str = PyCustomFloat_Str; - type->tp_doc = const_cast(TypeDescriptor::kTpDoc); - type->tp_richcompare = PyCustomFloat_RichCompare; - type->tp_as_number = &CustomFloatType::number_methods; - if (PyType_Ready(type) < 0) { - return false; - } - TypeDescriptor::type_ptr = reinterpret_cast(type); + // bases must be a tuple for Python 3.9 and earlier. Change to just pass + // the base type directly when dropping Python 3.9 support. + Safe_PyObjectPtr bases( + PyTuple_Pack(1, reinterpret_cast(&PyGenericArrType_Type))); + PyObject* type = + PyType_FromSpecWithBases(&CustomFloatType::type_spec, bases.get()); + if (!type) { + return false; + } + TypeDescriptor::type_ptr = type; Safe_PyObjectPtr module = make_safe(PyUnicode_FromString("ml_dtypes")); if (!module) { return false; } - if (PyObject_SetAttrString(TypeDescriptor::type_ptr, "__module__", - module.get()) < 0) { + if (PyObject_SetAttrString(type, "__module__", module.get()) < 0) { return false; } @@ -940,7 +904,7 @@ bool RegisterFloatDtype(PyObject* numpy) { PyArray_DescrProto& descr_proto = CustomFloatType::npy_descr_proto; descr_proto = GetCustomFloatDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); - descr_proto.typeobj = type; + descr_proto.typeobj = reinterpret_cast(type); TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); if (TypeDescriptor::npy_type < 0) { diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index ee2e96e4..ccb4ed63 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -50,7 +50,9 @@ struct IntNTypeDescriptor { // registered by another system into NumPy. static PyObject* type_ptr; - static PyNumberMethods number_methods; + static PyType_Spec type_spec; + static PyType_Slot type_slots[]; + static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; @@ -310,47 +312,6 @@ PyObject* PyIntN_nb_floor_divide(PyObject* a, PyObject* b) { return PyArray_Type.tp_as_number->nb_floor_divide(a, b); } -// Python number methods for PyIntN objects. -template -PyNumberMethods IntNTypeDescriptor::number_methods = { - PyIntN_nb_add, // nb_add - PyIntN_nb_subtract, // nb_subtract - PyIntN_nb_multiply, // nb_multiply - PyIntN_nb_remainder, // nb_remainder - nullptr, // nb_divmod - nullptr, // nb_power - PyIntN_nb_negative, // nb_negative - PyIntN_nb_positive, // nb_positive - nullptr, // nb_absolute - nullptr, // nb_nonzero - nullptr, // nb_invert - nullptr, // nb_lshift - nullptr, // nb_rshift - nullptr, // nb_and - nullptr, // nb_xor - nullptr, // nb_or - PyIntN_nb_int, // nb_int - nullptr, // reserved - PyIntN_nb_float, // nb_float - - nullptr, // nb_inplace_add - nullptr, // nb_inplace_subtract - nullptr, // nb_inplace_multiply - nullptr, // nb_inplace_remainder - nullptr, // nb_inplace_power - nullptr, // nb_inplace_lshift - nullptr, // nb_inplace_rshift - nullptr, // nb_inplace_and - nullptr, // nb_inplace_xor - nullptr, // nb_inplace_or - - PyIntN_nb_floor_divide, // nb_floor_divide - nullptr, // nb_true_divide - nullptr, // nb_inplace_floor_divide - nullptr, // nb_inplace_true_divide - nullptr, // nb_index -}; - // Implementation of repr() for PyIntN. template PyObject* PyIntN_Repr(PyObject* self) { @@ -410,6 +371,36 @@ PyObject* PyIntN_RichCompare(PyObject* a, PyObject* b, int op) { PyArrayScalar_RETURN_BOOL_FROM_LONG(result); } +template +PyType_Slot IntNTypeDescriptor::type_slots[] = { + {Py_tp_new, reinterpret_cast(PyIntN_tp_new)}, + {Py_tp_repr, reinterpret_cast(PyIntN_Repr)}, + {Py_tp_hash, reinterpret_cast(PyIntN_Hash)}, + {Py_tp_str, reinterpret_cast(PyIntN_Str)}, + {Py_tp_doc, + reinterpret_cast(const_cast(TypeDescriptor::kTpDoc))}, + {Py_tp_richcompare, reinterpret_cast(PyIntN_RichCompare)}, + {Py_nb_add, reinterpret_cast(PyIntN_nb_add)}, + {Py_nb_subtract, reinterpret_cast(PyIntN_nb_subtract)}, + {Py_nb_multiply, reinterpret_cast(PyIntN_nb_multiply)}, + {Py_nb_remainder, reinterpret_cast(PyIntN_nb_remainder)}, + {Py_nb_negative, reinterpret_cast(PyIntN_nb_negative)}, + {Py_nb_positive, reinterpret_cast(PyIntN_nb_positive)}, + {Py_nb_int, reinterpret_cast(PyIntN_nb_int)}, + {Py_nb_float, reinterpret_cast(PyIntN_nb_float)}, + {Py_nb_floor_divide, reinterpret_cast(PyIntN_nb_floor_divide)}, + {0, nullptr}, +}; + +template +PyType_Spec IntNTypeDescriptor::type_spec = { + /*.name=*/TypeDescriptor::kQualifiedTypeName, + /*.basicsize=*/static_cast(sizeof(PyIntN)), + /*.itemsize=*/0, + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + /*.slots=*/IntNTypeDescriptor::type_slots, +}; + // Numpy support template PyArray_ArrFuncs IntNTypeDescriptor::arr_funcs; @@ -775,38 +766,16 @@ bool RegisterIntNUFuncs(PyObject* numpy) { template bool RegisterIntNDtype(PyObject* numpy) { - Safe_PyObjectPtr name = - make_safe(PyUnicode_FromString(TypeDescriptor::kTypeName)); - Safe_PyObjectPtr qualname = - make_safe(PyUnicode_FromString(TypeDescriptor::kTypeName)); - - PyHeapTypeObject* heap_type = reinterpret_cast( - PyType_Type.tp_alloc(&PyType_Type, 0)); - if (!heap_type) { - return false; - } - // Caution: we must not call any functions that might invoke the GC until - // PyType_Ready() is called. Otherwise the GC might see a half-constructed - // type object. - heap_type->ht_name = name.release(); - heap_type->ht_qualname = qualname.release(); - PyTypeObject* type = &heap_type->ht_type; - type->tp_name = TypeDescriptor::kTypeName; - type->tp_basicsize = sizeof(PyIntN); - type->tp_flags = - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; - type->tp_base = &PyGenericArrType_Type; - type->tp_new = PyIntN_tp_new; - type->tp_repr = PyIntN_Repr; - type->tp_hash = PyIntN_Hash; - type->tp_str = PyIntN_Str; - type->tp_doc = const_cast(TypeDescriptor::kTpDoc); - type->tp_richcompare = PyIntN_RichCompare; - type->tp_as_number = &IntNTypeDescriptor::number_methods; - if (PyType_Ready(type) < 0) { - return false; - } - TypeDescriptor::type_ptr = reinterpret_cast(type); + // bases must be a tuple for Python 3.9 and earlier. Change to just pass + // the base type directly when dropping Python 3.9 support. + Safe_PyObjectPtr bases( + PyTuple_Pack(1, reinterpret_cast(&PyGenericArrType_Type))); + PyObject* type = + PyType_FromSpecWithBases(&IntNTypeDescriptor::type_spec, bases.get()); + if (!type) { + return false; + } + TypeDescriptor::type_ptr = type; Safe_PyObjectPtr module = make_safe(PyUnicode_FromString("ml_dtypes")); if (!module) { @@ -840,7 +809,7 @@ bool RegisterIntNDtype(PyObject* numpy) { PyArray_DescrProto& descr_proto = IntNTypeDescriptor::npy_descr_proto; descr_proto = GetIntNDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); - descr_proto.typeobj = type; + descr_proto.typeobj = reinterpret_cast(type); TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); if (TypeDescriptor::npy_type < 0) {