-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpython.cc
138 lines (130 loc) · 4.75 KB
/
python.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include <Python.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/arrayobject.h>
#include <climits>
#include <functional>
#include <memory>
#include "lap.h"
// Help text shown by help(fastlapjv) and help(fastlapjv.fastlapjv).
static const char module_docstring[] =
    "This module wraps fastlapjv - Jonker-Volgenant linear sum assignment algorithm.";
static const char fastlapjv_docstring[] =
    "Solves the linear sum assignment problem.";

// Implemented further down; declared here so the method table can refer to it.
static PyObject *py_fastlapjv(PyObject *self, PyObject *args, PyObject *kwargs);

// Method table for the module. py_fastlapjv accepts keyword arguments, so it
// has the three-argument signature and must be cast to PyCFunction, with
// METH_KEYWORDS telling CPython how to call it.
static PyMethodDef module_functions[] = {
    {
        "fastlapjv",
        reinterpret_cast<PyCFunction>(py_fastlapjv),
        METH_VARARGS | METH_KEYWORDS,
        fastlapjv_docstring,
    },
    {nullptr, nullptr, 0, nullptr}  // end-of-table sentinel
};
extern "C" {
// Module initialisation entry point, called by the CPython import machinery
// for `import fastlapjv`. Returns a new reference to the module object, or
// NULL with a Python exception set.
PyMODINIT_FUNC PyInit_fastlapjv(void) {
  static struct PyModuleDef moduledef = {
      PyModuleDef_HEAD_INIT,
      "fastlapjv",       /* m_name */
      module_docstring,  /* m_doc */
      -1,                /* m_size */
      module_functions,  /* m_methods */
      NULL,              /* m_slots (historically m_reload) */
      NULL,              /* m_traverse */
      NULL,              /* m_clear */
      NULL,              /* m_free */
  };
  // Load the numpy C API before allocating anything. On failure the
  // import_array() macro itself executes `return NULL;` from this function,
  // so calling it after PyModule_Create() would leak the module object.
  import_array();
  PyObject *m = PyModule_Create(&moduledef);
  if (m == NULL) {
    // PyModule_Create() normally sets its own, more specific exception
    // (e.g. MemoryError); only fall back to a generic error if it did not.
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_RuntimeError, "PyModule_Create() failed");
    }
    return NULL;
  }
  return m;
}
}
// Owning smart-pointer base for Python objects: a unique_ptr whose deleter
// drops one reference on the pointee.
template <typename O>
using pyobj_parent = std::unique_ptr<O, std::function<void(O*)>>;

// Holds one strong reference to a Python object and releases it with
// Py_DECREF when destroyed or reset. O is the C struct type that get()
// exposes (PyObject, PyArrayObject, ...). The constructor and reset() take a
// PyObject* so that results of Python / numpy C-API calls can be adopted
// directly; the pointer passed in should be an owned (new) reference or NULL.
template <typename O>
class _pyobj : public pyobj_parent<O> {
 public:
  _pyobj() : _pyobj(nullptr) {}

  explicit _pyobj(PyObject *ptr)
      : pyobj_parent<O>(reinterpret_cast<O *>(ptr), &_pyobj::drop_reference) {}

  // Adopt `p` (may be NULL), releasing the previously held reference.
  void reset(PyObject *p) noexcept {
    O *const typed = reinterpret_cast<O *>(p);
    pyobj_parent<O>::reset(typed);
  }

 private:
  // Deleter: NULL-safe decrement of the reference count.
  static void drop_reference(O *p) {
    Py_XDECREF(p);
  }
};

using pyobj = _pyobj<PyObject>;
using pyarray = _pyobj<PyArrayObject>;
/// fastlapjv(cost_matrix, verbose=False, force_doubles=False, k_value=1000)
///
/// Solves the square linear sum assignment problem for `cost_matrix` by
/// calling lap() from lap.h with the GIL released.
///
/// @param cost_matrix    object convertible to a non-empty square 2D numpy
///                       array of float32 or float64 dtype.
/// @param verbose        truthy flag forwarded to lap().
/// @param force_doubles  when truthy, the input is NOT force-cast to float32;
///                       inputs that cannot be cast to float32 safely (e.g.
///                       float64) take the float64 path instead.
/// @param k_value        int forwarded unchanged as lap()'s last argument.
/// @return (row_ind, col_ind, (cost, u, v)) where row_ind / col_ind are int
///         arrays of length n filled by lap(), cost is lap()'s return value,
///         and u / v are float32 arrays, or float64 on the float64 path.
/// @throws ValueError for inputs that cannot be converted, non-2D or
///         non-square shapes, an empty matrix, or a side length > INT_MAX.
///
/// NOTE(review): every lap() call in this file passes float buffers, so the
/// solver always runs in float32 here. For float64 input the matrix is
/// down-cast for the solver and u / v are widened back to float64 so the
/// returned dtypes still follow the input. (Previously the float64 buffers
/// were reinterpret_cast to float*, which produced garbage.) If lap.h offers a
/// double overload, call it directly on the float64 path instead.
static PyObject *py_fastlapjv(PyObject *self, PyObject *args, PyObject *kwargs) {
  PyObject *cost_matrix_obj;
  int verbose = 0;
  int force_doubles = 0;
  int k_value = 1000;
  static const char *kwlist[] = {
      "cost_matrix", "verbose", "force_doubles", "k_value", NULL};
  // Both flags use "p", which stores 0/1 into an int. "b" would store a single
  // unsigned char into the int `force_doubles`, leaving its other bytes
  // untouched (only correct by accident on little-endian targets).
  if (!PyArg_ParseTupleAndKeywords(
      args, kwargs, "O|ppi", const_cast<char**>(kwlist),
      &cost_matrix_obj, &verbose, &force_doubles, &k_value)) {
    return NULL;
  }

  // Prefer float32. Without force_doubles any numeric input is force-cast;
  // with it only safe casts are allowed, so e.g. float64 falls through.
  pyarray cost_matrix_array;
  bool float32 = true;
  cost_matrix_array.reset(PyArray_FROM_OTF(
      cost_matrix_obj, NPY_FLOAT32,
      NPY_ARRAY_IN_ARRAY | (force_doubles? 0 : NPY_ARRAY_FORCECAST)));
  if (!cost_matrix_array) {
    PyErr_Clear();
    float32 = false;
    cost_matrix_array.reset(PyArray_FROM_OTF(
        cost_matrix_obj, NPY_FLOAT64, NPY_ARRAY_IN_ARRAY));
    if (!cost_matrix_array) {
      PyErr_SetString(PyExc_ValueError, "\"cost_matrix\" must be a numpy array "
                      "of float32 or float64 dtype");
      return NULL;
    }
  }

  if (PyArray_NDIM(cost_matrix_array.get()) != 2) {
    PyErr_SetString(PyExc_ValueError,
                    "\"cost_matrix\" must be a square 2D numpy array");
    return NULL;
  }
  auto dims = PyArray_DIMS(cost_matrix_array.get());
  if (dims[0] != dims[1]) {
    PyErr_SetString(PyExc_ValueError,
                    "\"cost_matrix\" must be a square 2D numpy array");
    return NULL;
  }
  // lap() takes the side length as int: reject sizes the narrowing conversion
  // would truncate instead of silently solving a smaller (wrong) problem.
  if (dims[0] <= 0 || dims[0] > INT_MAX) {
    PyErr_SetString(PyExc_ValueError,
                    "\"cost_matrix\"'s shape is invalid or too large");
    return NULL;
  }
  const int dim = static_cast<int>(dims[0]);

  // The solver reads float32 data; on the float64 path make a down-cast copy.
  pyarray cost32_copy;
  PyArrayObject *cost32_array = cost_matrix_array.get();
  if (!float32) {
    cost32_copy.reset(PyArray_FROM_OTF(
        reinterpret_cast<PyObject*>(cost_matrix_array.get()), NPY_FLOAT32,
        NPY_ARRAY_IN_ARRAY | NPY_ARRAY_FORCECAST));
    if (!cost32_copy) {
      return NULL;
    }
    cost32_array = cost32_copy.get();
  }
  auto cost_matrix = reinterpret_cast<float*>(PyArray_DATA(cost32_array));

  npy_intp ret_dims[] = {dim, 0};
  pyarray row_ind_array(PyArray_SimpleNew(1, ret_dims, NPY_INT));
  pyarray col_ind_array(PyArray_SimpleNew(1, ret_dims, NPY_INT));
  pyarray u_array(PyArray_SimpleNew(1, ret_dims, NPY_FLOAT32));
  pyarray v_array(PyArray_SimpleNew(1, ret_dims, NPY_FLOAT32));
  if (!row_ind_array || !col_ind_array || !u_array || !v_array) {
    return NULL;  // allocation failed; PyArray_SimpleNew set the exception
  }
  auto row_ind = reinterpret_cast<int*>(PyArray_DATA(row_ind_array.get()));
  auto col_ind = reinterpret_cast<int*>(PyArray_DATA(col_ind_array.get()));
  auto u = reinterpret_cast<float*>(PyArray_DATA(u_array.get()));
  auto v = reinterpret_cast<float*>(PyArray_DATA(v_array.get()));

  // Held as double: it is passed to Py_BuildValue with the "d" format.
  double lapcost;
  Py_BEGIN_ALLOW_THREADS
  lapcost = lap(dim, cost_matrix, verbose, row_ind, col_ind, u, v, k_value);
  Py_END_ALLOW_THREADS

  // Return the dual variables in the precision of the accepted input.
  if (!float32) {
    u_array.reset(PyArray_FROM_OTF(
        reinterpret_cast<PyObject*>(u_array.get()), NPY_FLOAT64,
        NPY_ARRAY_DEFAULT));
    v_array.reset(PyArray_FROM_OTF(
        reinterpret_cast<PyObject*>(v_array.get()), NPY_FLOAT64,
        NPY_ARRAY_DEFAULT));
    if (!u_array || !v_array) {
      return NULL;
    }
  }
  return Py_BuildValue("(OO(dOO))",
                       row_ind_array.get(), col_ind_array.get(), lapcost,
                       u_array.get(), v_array.get());
}