Skip to content

Commit

Permalink
Merge pull request #19 from AMYPAD/numpy-typecheck
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl authored Nov 22, 2021
2 parents ef2a25b + 7643ef1 commit 44667b6
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 15 deletions.
69 changes: 54 additions & 15 deletions cuvec/include/pycuvec.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,30 @@
namespace cuvec {
template <typename T> struct PyType {
static const char *format() { return typeid(T).name(); }
static const char *npchr() { return ""; }
};
#define _PYCVEC_TPCHR(T, typestr) \
#define _PYCUVEC_TPCHR(T, typestr, npy_char) \
template <> struct PyType<T> { \
static const char *format() { return typestr; } \
static const char *npchr() { return npy_char; } \
}
_PYCVEC_TPCHR(char, "c");
_PYCVEC_TPCHR(signed char, "b");
_PYCVEC_TPCHR(unsigned char, "B");
_PYCUVEC_TPCHR(char, "c", "S");
_PYCUVEC_TPCHR(signed char, "b", "b");
_PYCUVEC_TPCHR(unsigned char, "B", "B");
#ifdef _Bool
_PYCVEC_TPCHR(_Bool, "?");
_PYCUVEC_TPCHR(_Bool, "?", "?");
#endif
_PYCVEC_TPCHR(short, "h");
_PYCVEC_TPCHR(unsigned short, "H");
_PYCVEC_TPCHR(int, "i");
_PYCVEC_TPCHR(unsigned int, "I");
_PYCVEC_TPCHR(long long, "q");
_PYCVEC_TPCHR(unsigned long long, "Q");
_PYCUVEC_TPCHR(short, "h", "h");
_PYCUVEC_TPCHR(unsigned short, "H", "H");
_PYCUVEC_TPCHR(int, "i", "i");
_PYCUVEC_TPCHR(unsigned int, "I", "I");
_PYCUVEC_TPCHR(long long, "q", "l");
_PYCUVEC_TPCHR(unsigned long long, "Q", "L");
#ifdef _CUVEC_HALF
_PYCVEC_TPCHR(_CUVEC_HALF, "e");
_PYCUVEC_TPCHR(_CUVEC_HALF, "e", "e");
#endif
_PYCVEC_TPCHR(float, "f");
_PYCVEC_TPCHR(double, "d");
_PYCUVEC_TPCHR(float, "f", "f");
_PYCUVEC_TPCHR(double, "d", "d");
} // namespace cuvec

/** classes */
Expand Down Expand Up @@ -233,6 +235,22 @@ template <class T> PyCuVec<T> *PyCuVec_deepcopy(PyCuVec<T> *other) {
template <class T> PyCuVec<T> *asPyCuVec(PyObject *o) {
if (!o || Py_None == o) return NULL;
if (PyObject_HasAttrString(o, "cuvec")) {
// NumPy type checking
PyObject *dtype = PyObject_GetAttrString(o, "dtype");
PyObject *c = PyObject_GetAttrString(dtype, "char");
Py_XDECREF(dtype);
if (PyUnicode_Check(c)) {
char *npchr = (char *)PyUnicode_1BYTE_DATA(c);
if (*npchr != *cuvec::PyType<T>::npchr()) {
PyErr_Format(PyExc_TypeError,
"cannot convert underlying dtype('%s') to requested dtype('%s')", npchr,
cuvec::PyType<T>::npchr());
Py_DECREF(c);
return NULL;
}
}
Py_XDECREF(c);
// return cuvec
o = PyObject_GetAttrString(o, "cuvec");
if (!o) return NULL;
Py_DECREF(o);
Expand All @@ -242,6 +260,22 @@ template <class T> PyCuVec<T> *asPyCuVec(PyObject *o) {
template <class T> PyCuVec<T> *asPyCuVec(PyCuVec<T> *o) {
if (!o || Py_None == (PyObject *)o) return NULL;
if (PyObject_HasAttrString((PyObject *)o, "cuvec")) {
// NumPy type checking
PyObject *dtype = PyObject_GetAttrString((PyObject *)o, "dtype");
PyObject *c = PyObject_GetAttrString(dtype, "char");
Py_XDECREF(dtype);
if (PyUnicode_Check(c)) {
char *npchr = (char *)PyUnicode_1BYTE_DATA(c);
if (*npchr != *cuvec::PyType<T>::npchr()) {
PyErr_Format(PyExc_TypeError,
"cannot convert underlying dtype('%s') to requested dtype('%s')", npchr,
cuvec::PyType<T>::npchr());
Py_DECREF(c);
return NULL;
}
}
Py_XDECREF(c);
// return cuvec
o = (PyCuVec<T> *)PyObject_GetAttrString((PyObject *)o, "cuvec");
if (!o) return NULL;
Py_DECREF((PyObject *)o);
Expand All @@ -251,12 +285,17 @@ template <class T> PyCuVec<T> *asPyCuVec(PyCuVec<T> *o) {
/// conversion functions for PyArg_Parse...(..., "O&", ...)
#define ASCUVEC(T, typechar) \
int asPyCuVec_##typechar(PyObject *object, void **address) { \
*address = (void *)asPyCuVec<T>(object); \
PyCuVec<T> *o = asPyCuVec<T>(object); \
if (!o) return 0; \
*address = (void *)o; \
return 1; \
}
ASCUVEC(signed char, b)
ASCUVEC(unsigned char, B)
ASCUVEC(char, c)
#ifdef _Bool
ASCUVEC(_Bool, "?", "?");
#endif
ASCUVEC(short, h)
ASCUVEC(unsigned short, H)
ASCUVEC(int, i)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_pycuvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,16 @@ def test_increment_return():
assert (a == 1).all()
del a
assert (res == 1).all()


def test_np_types():
from cuvec.example_mod import increment2d_f
f = cu.zeros((1337, 42), 'f')
d = cu.zeros((1337, 42), 'd')
cu.asarray(increment2d_f(f))
cu.asarray(increment2d_f(f, f))
with raises(TypeError):
cu.asarray(increment2d_f(d))
with raises(SystemError):
# the TypeError is suppressed since a new output is generated
cu.asarray(increment2d_f(f, d))

0 comments on commit 44667b6

Please sign in to comment.