Skip to content

Commit

Permalink
Add mpz.__array__() method to interact with numpy
Browse files Browse the repository at this point in the history
Closes #507
  • Loading branch information
skirpichev committed Sep 13, 2024
1 parent c7629dd commit 5099d35
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ content-type = 'text/x-rst'

[project.optional-dependencies]
docs = ['sphinx>=4', 'sphinx-rtd-theme>=1']
tests = ['pytest', 'hypothesis', 'cython', 'mpmath', 'setuptools']
tests = ['pytest', 'hypothesis', 'cython', 'mpmath', 'setuptools',
'numpy; python_version>="3.10" and platform_system=="Linux"']

[project.urls]
Homepage = 'https://github.com/aleaxit/gmpy'
Expand Down
1 change: 1 addition & 0 deletions src/gmpy2_mpz.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ static PyMethodDef GMPy_MPZ_methods[] = {
{ "__round__", (PyCFunction)GMPy_MPZ_Method_Round, METH_FASTCALL, GMPy_doc_mpz_method_round },
{ "__sizeof__", GMPy_MPZ_Method_SizeOf, METH_NOARGS, GMPy_doc_mpz_method_sizeof },
{ "__trunc__", GMPy_MPZ_Method_Trunc, METH_NOARGS, GMPy_doc_mpz_method_trunc },
{ "__array__", (PyCFunction)GMPy_MPZ_Method_Array, METH_FASTCALL | METH_KEYWORDS, GMPy_doc_mpz_method_array },
{ "bit_clear", GMPy_MPZ_bit_clear_method, METH_O, doc_bit_clear_method },
{ "bit_count", GMPy_MPZ_bit_count_method, METH_NOARGS, doc_bit_count_method },
{ "bit_flip", GMPy_MPZ_bit_flip_method, METH_O, doc_bit_flip_method },
Expand Down
85 changes: 85 additions & 0 deletions src/gmpy2_mpz_misc.c
Original file line number Diff line number Diff line change
Expand Up @@ -2167,3 +2167,88 @@ GMPy_MP_Method_Conjugate(PyObject *self, PyObject *args)
Py_INCREF((PyObject*)self);
return (PyObject*)self;
}

PyDoc_STRVAR(GMPy_doc_mpz_method_array,
"x.__array__(dtype=None, copy=None)\n");

static PyObject *
GMPy_MPZ_Method_Array(PyObject *self, PyObject *const *args,
Py_ssize_t nargs, PyObject *kwnames)
{
Py_ssize_t i, nkws = 0;
int argidx[2] = {-1, -1};
const char* kwname;
PyObject *dtype = Py_None, *copy = Py_None;

if (nargs > 2) {
TYPE_ERROR("__array__() takes at most 2 positional arguments");
return NULL;
}
if (nargs >= 1) {
argidx[0] = 0;
}
if (nargs == 2) {
argidx[1] = 1;
}

if (kwnames) {
nkws = PyTuple_GET_SIZE(kwnames);
}
if (nkws > 2) {
TYPE_ERROR("__array__() takes at most 2 keyword arguments");
return NULL;
}
for (i = 0; i < nkws; i++) {
kwname = PyUnicode_AsUTF8(PyTuple_GET_ITEM(kwnames, i));
if (strcmp(kwname, "dtype") == 0) {
if (nargs == 0) {
argidx[0] = (int)(nargs + i);
}
else {
TYPE_ERROR("argument for __array__() given by name ('dtype') and position (1)");
return NULL;
}
}
else if (strcmp(kwname, "copy") == 0) {
if (nargs <= 1) {
argidx[1] = (int)(nargs + i);
}
else {
TYPE_ERROR("argument for __array__() given by name ('copy') and position (2)");
return NULL;
}
}
else {
TYPE_ERROR("got an invalid keyword argument for __array__()");
return NULL;
}
}

if (argidx[0] >= 0) {
dtype = args[argidx[0]];
}
if (argidx[1] >= 0) {
copy = args[argidx[1]];
}

PyObject *mod = PyImport_ImportModule("numpy");

if (!mod) {
return NULL;
}

PyObject *tmp_long = GMPy_PyLong_From_MPZ((MPZ_Object *)self, NULL);

if (!tmp_long) {
Py_DECREF(mod);
return NULL;
}

PyObject *result = PyObject_CallMethod(mod, "array",
"OO", tmp_long, dtype);

Py_DECREF(mod);
Py_DECREF(tmp_long);

return result;
}
1 change: 1 addition & 0 deletions src/gmpy2_mpz_misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ static PyObject * GMPy_MPZ_Method_IsProbabPrime(PyObject *self, PyObject *const
static PyObject * GMPy_MPZ_Method_IsEven(PyObject *self, PyObject *other);
static PyObject * GMPy_MPZ_Method_IsOdd(PyObject *self, PyObject *other);
static PyObject * GMPy_MP_Method_Conjugate(PyObject *self, PyObject *args);
static PyObject * GMPy_MPZ_Method_Array(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames);

static PyObject * GMPy_MPZ_Function_NumDigits(PyObject *self, PyObject *const *args, Py_ssize_t nargs);
static PyObject * GMPy_MPZ_Function_Iroot(PyObject *self, PyObject *const *args, Py_ssize_t nargs);
Expand Down
15 changes: 15 additions & 0 deletions test/test_mpz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pickle
from fractions import Fraction

import pytest
from hypothesis import assume, example, given, settings
from hypothesis.strategies import booleans, integers, sampled_from
from pytest import mark, raises
Expand Down Expand Up @@ -1758,3 +1759,17 @@ def test_issue_312():
assert not is_prime(1 - 2**4423)
assert all(not is_prime(-a) for a in range(8))
assert next_prime(-8) == 2


def test_mpz_array():
numpy = pytest.importorskip('numpy')
i = 5579686107214117131790972086716881
m = gmpy2.mpz(i)
assert numpy.longdouble(m) == numpy.longdouble(i)
assert m.__array__(dtype=numpy.longdouble) == numpy.longdouble(i)

raises(TypeError, lambda: m.__array__(1, 2, 3))
raises(TypeError, lambda: m.__array__(dtype=None, copy=None, spam=123))
raises(TypeError, lambda: m.__array__(int, dtype=None))
raises(TypeError, lambda: m.__array__(int, None, copy=None))
raises(TypeError, lambda: m.__array__(spam=123))

0 comments on commit 5099d35

Please sign in to comment.