Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Nov 22, 2021
1 parent ccc3a45 commit 7643ef1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
16 changes: 16 additions & 0 deletions cuvec/include/pycuvec.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,22 @@ template <class T> PyCuVec<T> *PyCuVec_deepcopy(PyCuVec<T> *other) {
template <class T> PyCuVec<T> *asPyCuVec(PyObject *o) {
if (!o || Py_None == o) return NULL;
if (PyObject_HasAttrString(o, "cuvec")) {
// NumPy type checking
PyObject *dtype = PyObject_GetAttrString(o, "dtype");
PyObject *c = PyObject_GetAttrString(dtype, "char");
Py_XDECREF(dtype);
if (PyUnicode_Check(c)) {
char *npchr = (char *)PyUnicode_1BYTE_DATA(c);
if (*npchr != *cuvec::PyType<T>::npchr()) {
PyErr_Format(PyExc_TypeError,
"cannot convert underlying dtype('%s') to requested dtype('%s')", npchr,
cuvec::PyType<T>::npchr());
Py_DECREF(c);
return NULL;
}
}
Py_XDECREF(c);
// return cuvec
o = PyObject_GetAttrString(o, "cuvec");
if (!o) return NULL;
Py_DECREF(o);
Expand Down
7 changes: 4 additions & 3 deletions tests/test_pycuvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def test_np_types():
d = cu.zeros((1337, 42), 'd')
cu.asarray(increment2d_f(f))
cu.asarray(increment2d_f(f, f))
with raises((TypeError, SystemError)):
cu.asarray(increment2d_f(f, d))
with raises((TypeError, SystemError)):
with raises(TypeError):
cu.asarray(increment2d_f(d))
with raises(SystemError):
# the TypeError is suppressed since a new output is generated
cu.asarray(increment2d_f(f, d))

0 comments on commit 7643ef1

Please sign in to comment.