-
Notifications
You must be signed in to change notification settings - Fork 1
/
marslib.cc
70 lines (61 loc) · 2.16 KB
/
marslib.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include "marsalgo.h"

#include <exception>
#include <utility>

#include <omp.h>
#include <pybind11/eigen.h>
#include <pybind11/pybind11.h>
namespace py = pybind11;

/// Dynamic boolean matrix stored row-major, so each row is one contiguous
/// run of bools; `eval` relies on this when it passes `mask.row(i).data()`
/// to MarsAlgo::eval as a flat pointer.
using MatrixXbC = Matrix<bool, Dynamic, Dynamic, RowMajor>;
///////////////////////////////////////////////////////////////////////////////
/// Factory bound as `MarsAlgo.__init__`.
///
/// Checks that the design matrix, the response and the weights all have the
/// same number of rows, then constructs a MarsAlgo over their raw buffers.
/// The returned raw pointer is adopted by pybind11 (py::init with a factory
/// function), so the caller here does not delete it.
///
/// @param X         Design matrix. MatrixXf is Eigen's column-major layout, so
///                  X.outerStride() (forwarded as the last argument) is the
///                  distance, in floats, between consecutive columns.
/// @param y         Response vector; must have X.rows() entries.
/// @param w         Weight vector; must have X.rows() entries.
/// @param max_terms Forwarded unchanged to the MarsAlgo constructor.
/// @return Newly heap-allocated MarsAlgo.
/// @throws std::runtime_error("invalid dataset lengths") if the row counts
///         of X, y and w disagree.
///
/// NOTE(review): only raw data pointers are handed to MarsAlgo. If MarsAlgo
/// keeps them instead of copying, the NumPy arrays must outlive the object
/// (e.g. via py::keep_alive in the binding) — confirm in marsalgo.h.
MarsAlgo * new_algo(
    const Ref<const MatrixXf> &X,
    const Ref<const VectorXf> &y,
    const Ref<const VectorXf> &w,
    int max_terms)
{
    const bool same_length = (y.rows() == X.rows()) && (w.rows() == y.rows());
    if (!same_length)
        throw std::runtime_error("invalid dataset lengths");

    const float *x_ptr = X.data();
    const float *y_ptr = y.data();
    const float *w_ptr = w.data();
    return new MarsAlgo(x_ptr, y_ptr, w_ptr,
                        X.rows(), X.cols(), max_terms, X.outerStride());
}
///////////////////////////////////////////////////////////////////////////////
/// Bound as `MarsAlgo.eval`: runs MarsAlgo::eval once per row of `mask`
/// (rows processed in parallel with OpenMP) and collects the per-row outputs.
///
/// @param algo        Algorithm instance. The same object is used by every
///                    OpenMP thread at once (one call per row index i), so
///                    MarsAlgo::eval must be safe to call concurrently for
///                    distinct i.
/// @param mask        Row-major boolean matrix; row i is passed to
///                    algo.eval as a contiguous pointer together with index i.
///                    NOTE(review): algo.eval writes through output rows of
///                    length mask.cols(); presumably mask.cols() must match
///                    the feature count of X and mask.rows() must be a valid
///                    index range for algo — neither is checked here. Confirm.
/// @param endspan     Forwarded unchanged to algo.eval.
/// @param linear_only Forwarded unchanged to algo.eval.
/// @return Tuple (algo.dsse(), dsse1, dsse2, h_cut). The three arrays have the
///         shape of `mask`; dsse1/dsse2 are zero-initialised and h_cut is
///         NaN-initialised, so entries algo.eval leaves untouched keep those.
/// @throws Whatever algo.eval throws. The first exception raised by any row
///         is rethrown after the parallel loop completes.
///
/// NOTE(review): the GIL is held for the whole computation.
py::tuple eval(MarsAlgo &algo,
               const Ref<const MatrixXbC> &mask,
               int endspan,
               bool linear_only)
{
    typedef Array<double,Dynamic,Dynamic,RowMajor> ArrayXXdC;
    ArrayXXdC dsse1 = ArrayXXdC::Zero(mask.rows(), mask.cols());
    ArrayXXdC dsse2 = ArrayXXdC::Zero(mask.rows(), mask.cols());
    ArrayXXdC h_cut = ArrayXXdC::Constant(mask.rows(), mask.cols(), NAN);

    // The OpenMP specification does not allow an exception to leave a
    // parallel region (in practice the runtime calls std::terminate, which
    // would kill the Python interpreter). Capture the first one instead and
    // rethrow it on the calling thread, where pybind11 turns it into a
    // Python exception.
    std::exception_ptr first_error = nullptr;

    #pragma omp parallel for schedule(static)
    for (int i = 0; i < mask.rows(); ++i) {
        try {
            algo.eval(
                dsse1.row(i).data(), dsse2.row(i).data(), h_cut.row(i).data(),
                i, mask.row(i).data(), endspan, linear_only);
        } catch (...) {
            #pragma omp critical(marslib_eval_error)
            {
                if (!first_error) {
                    first_error = std::current_exception();
                }
            }
        }
    }

    if (first_error) {
        std::rethrow_exception(first_error);
    }

    // The arrays are not used again: passing them as rvalues lets pybind11's
    // Eigen caster take over their storage instead of copying each one.
    return py::make_tuple(algo.dsse(),
                          std::move(dsse1),
                          std::move(dsse2),
                          std::move(h_cut));
}
///////////////////////////////////////////////////////////////////////////////
/// Python extension module `marslib`: exposes the MarsAlgo class with a
/// validating constructor (new_algo), the parallel `eval` wrapper, and the
/// `nbasis`, `yvar` and `append` member functions bound directly.
PYBIND11_MODULE(marslib, m)
{
    // While this guard is alive, generated docstrings omit the C++-derived
    // function signatures; pybind11 restores the previous setting when the
    // guard is destroyed at the end of this scope.
    py::options opts;
    opts.disable_function_signatures();

    m.doc() = "Multivariate Adaptive Regression Splines";
    m.attr("__version__") = "dev";

    py::class_<MarsAlgo> algo_cls(m, "MarsAlgo");

    // noconvert(): the arrays must already have the exact dtype/layout the
    // Eigen Ref expects; pybind11 will not make a converted copy.
    algo_cls.def(py::init(&new_algo),
                 py::arg("X").noconvert(),
                 py::arg("y").noconvert(),
                 py::arg("w").noconvert(),
                 py::arg("max_terms"));

    algo_cls.def("eval", &eval,
                 py::arg("mask").noconvert(),
                 py::arg("endspan"),
                 py::arg("linear_only"));

    algo_cls.def("nbasis", &MarsAlgo::nbasis);
    algo_cls.def("yvar", &MarsAlgo::yvar);
    algo_cls.def("append", &MarsAlgo::append);
}