diff --git a/cuvec/include/pycuvec.cuh b/cuvec/include/pycuvec.cuh index 86c39d9..63c4ce6 100644 --- a/cuvec/include/pycuvec.cuh +++ b/cuvec/include/pycuvec.cuh @@ -244,6 +244,22 @@ template PyCuVec *asPyCuVec(PyObject *o) { template PyCuVec *asPyCuVec(PyCuVec *o) { if (!o || Py_None == (PyObject *)o) return NULL; if (PyObject_HasAttrString((PyObject *)o, "cuvec")) { + // NumPy type checking + PyObject *dtype = PyObject_GetAttrString((PyObject *)o, "dtype"); + PyObject *c = PyObject_GetAttrString(dtype, "char"); + Py_XDECREF(dtype); + if (PyUnicode_Check(c)) { + char *npchr = (char *)PyUnicode_1BYTE_DATA(c); + if (*npchr != *cuvec::PyType::npchr()) { + PyErr_Format(PyExc_TypeError, + "cannot convert underlying dtype('%s') to requested dtype('%s')", npchr, + cuvec::PyType::npchr()); + Py_DECREF(c); + return NULL; + } + } + Py_XDECREF(c); + // return cuvec o = (PyCuVec *)PyObject_GetAttrString((PyObject *)o, "cuvec"); if (!o) return NULL; Py_DECREF((PyObject *)o); @@ -253,12 +269,17 @@ template PyCuVec *asPyCuVec(PyCuVec *o) { /// conversion functions for PyArg_Parse...(..., "O&", ...) #define ASCUVEC(T, typechar) \ int asPyCuVec_##typechar(PyObject *object, void **address) { \ - *address = (void *)asPyCuVec(object); \ + PyCuVec *o = asPyCuVec(object); \ + if (!o) return 0; \ + *address = (void *)o; \ return 1; \ } ASCUVEC(signed char, b) ASCUVEC(unsigned char, B) ASCUVEC(char, c) +#ifdef _Bool +ASCUVEC(_Bool, "?", "?"); +#endif ASCUVEC(short, h) ASCUVEC(unsigned short, H) ASCUVEC(int, i) diff --git a/tests/test_pycuvec.py b/tests/test_pycuvec.py index 2eaa26e..a821580 100644 --- a/tests/test_pycuvec.py +++ b/tests/test_pycuvec.py @@ -150,3 +150,15 @@ def test_increment_return(): assert (a == 1).all() del a assert (res == 1).all() + + +def test_np_types(): + from cuvec.example_mod import increment2d_f + f = cu.zeros((1337, 42), 'f') + d = cu.zeros((1337, 42), 'd') + cu.asarray(increment2d_f(f)) + cu.asarray(increment2d_f(f, f)) + with raises((TypeError, SystemError)): + cu.asarray(increment2d_f(f, d)) + with raises((TypeError, SystemError)): + cu.asarray(increment2d_f(d))