Skip to content

Commit

Permalink
NumPy type checking
Browse files Browse the repository at this point in the history
- fixes #18
  • Loading branch information
casperdcl committed Nov 21, 2021
1 parent 30faaeb commit ccc3a45
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
23 changes: 22 additions & 1 deletion cuvec/include/pycuvec.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,22 @@ template <class T> PyCuVec<T> *asPyCuVec(PyObject *o) {
template <class T> PyCuVec<T> *asPyCuVec(PyCuVec<T> *o) {
if (!o || Py_None == (PyObject *)o) return NULL;
if (PyObject_HasAttrString((PyObject *)o, "cuvec")) {
// NumPy type checking
PyObject *dtype = PyObject_GetAttrString((PyObject *)o, "dtype");
PyObject *c = PyObject_GetAttrString(dtype, "char");
Py_XDECREF(dtype);
if (PyUnicode_Check(c)) {
char *npchr = (char *)PyUnicode_1BYTE_DATA(c);
if (*npchr != *cuvec::PyType<T>::npchr()) {
PyErr_Format(PyExc_TypeError,
"cannot convert underlying dtype('%s') to requested dtype('%s')", npchr,
cuvec::PyType<T>::npchr());
Py_DECREF(c);
return NULL;
}
}
Py_XDECREF(c);
// return cuvec
o = (PyCuVec<T> *)PyObject_GetAttrString((PyObject *)o, "cuvec");
if (!o) return NULL;
Py_DECREF((PyObject *)o);
Expand All @@ -253,12 +269,17 @@ template <class T> PyCuVec<T> *asPyCuVec(PyCuVec<T> *o) {
/// conversion functions for PyArg_Parse...(..., "O&", ...)
#define ASCUVEC(T, typechar) \
int asPyCuVec_##typechar(PyObject *object, void **address) { \
*address = (void *)asPyCuVec<T>(object); \
PyCuVec<T> *o = asPyCuVec<T>(object); \
if (!o) return 0; \
*address = (void *)o; \
return 1; \
}
ASCUVEC(signed char, b)
ASCUVEC(unsigned char, B)
ASCUVEC(char, c)
#ifdef _Bool
ASCUVEC(_Bool, "?", "?");
#endif
ASCUVEC(short, h)
ASCUVEC(unsigned short, H)
ASCUVEC(int, i)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_pycuvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,15 @@ def test_increment_return():
assert (a == 1).all()
del a
assert (res == 1).all()


def test_np_types():
from cuvec.example_mod import increment2d_f
f = cu.zeros((1337, 42), 'f')
d = cu.zeros((1337, 42), 'd')
cu.asarray(increment2d_f(f))
cu.asarray(increment2d_f(f, f))
with raises((TypeError, SystemError)):
cu.asarray(increment2d_f(f, d))
with raises((TypeError, SystemError)):
cu.asarray(increment2d_f(d))

0 comments on commit ccc3a45

Please sign in to comment.