From a556d3553c2d69c56a49ad810e861b643a472328 Mon Sep 17 00:00:00 2001 From: Amit Solomon Date: Thu, 11 Apr 2024 16:06:39 -0400 Subject: [PATCH] Updated dy_u/dy_l->dy in bindings.cpp, bindings.cpp.in, interface.py, derivatice_test.py --- src/bindings.cpp | 3 ++- src/bindings.cpp.in | 26 +++++++++----------------- src/osqp/interface.py | 10 ++++------ src/osqp/tests/derivative_test.py | 8 +++----- 4 files changed, 18 insertions(+), 29 deletions(-) diff --git a/src/bindings.cpp b/src/bindings.cpp index c65b99cd..d6ec44f1 100644 --- a/src/bindings.cpp +++ b/src/bindings.cpp @@ -118,7 +118,7 @@ class PyOSQPSolver { OSQPInt update_data_mat(py::object, py::object, py::object, py::object); OSQPInt warm_start(py::object, py::object); OSQPInt solve(); - OSQPInt adjoint_derivative_compute(py::object, py::object, py::object); + OSQPInt adjoint_derivative_compute(py::object, py::object); OSQPInt adjoint_derivative_get_mat(CSC&, CSC&); OSQPInt adjoint_derivative_get_vec(py::object, py::object, py::object); @@ -291,6 +291,7 @@ OSQPInt PyOSQPSolver::adjoint_derivative_compute(const py::object dx, const py:: _dy = (OSQPFloat *)_dy_array.data(); } + return osqp_adjoint_derivative_compute(this->_solver, _dx, _dy); } diff --git a/src/bindings.cpp.in b/src/bindings.cpp.in index 066cbf2b..baca9fd4 100644 --- a/src/bindings.cpp.in +++ b/src/bindings.cpp.in @@ -118,7 +118,7 @@ class PyOSQPSolver { OSQPInt update_data_mat(py::object, py::object, py::object, py::object); OSQPInt warm_start(py::object, py::object); OSQPInt solve(); - OSQPInt adjoint_derivative_compute(py::object, py::object, py::object); + OSQPInt adjoint_derivative_compute(py::object, py::object); OSQPInt adjoint_derivative_get_mat(CSC&, CSC&); OSQPInt adjoint_derivative_get_vec(py::object, py::object, py::object); @@ -273,10 +273,9 @@ OSQPInt PyOSQPSolver::update_data_mat(py::object P_x, py::object P_i, py::object return osqp_update_data_mat(this->_solver, _P_x, _P_i, _P_n, _A_x, _A_i, _A_n); } -OSQPInt PyOSQPSolver::adjoint_derivative_compute(const py::object dx, const py::object dy_l, const py::object dy_u) { +OSQPInt PyOSQPSolver::adjoint_derivative_compute(const py::object dx, const py::object dy) { OSQPFloat* _dx; - OSQPFloat* _dy_l; - OSQPFloat* _dy_u; + OSQPFloat* _dy; if (dx.is_none()) { _dx = NULL; @@ -285,22 +284,15 @@ OSQPInt PyOSQPSolver::adjoint_derivative_compute(const py::object dx, const py:: _dx = (OSQPFloat *)_dx_array.data(); } - if (dy_l.is_none()) { - _dy_l = NULL; + if (dy.is_none()) { + _dy = NULL; } else { - auto _dy_l_array = py::array_t(dy_l); - _dy_l = (OSQPFloat *)_dy_l_array.data(); + auto _dy_array = py::array_t(dy); + _dy = (OSQPFloat *)_dy_array.data(); } - if (dy_u.is_none()) { - _dy_u = NULL; - } else { - auto _dy_u_array = py::array_t(dy_u); - _dy_u = (OSQPFloat *)_dy_u_array.data(); - } - - return osqp_adjoint_derivative_compute(this->_solver, _dx, _dy_l, _dy_u); + return osqp_adjoint_derivative_compute(this->_solver, _dx, _dy); } OSQPInt PyOSQPSolver::adjoint_derivative_get_mat(CSC& dP, CSC& dA) { @@ -489,7 +481,7 @@ PYBIND11_MODULE(@OSQP_EXT_MODULE_NAME@, m) { .def("update_rho", &PyOSQPSolver::update_rho) .def("get_settings", &PyOSQPSolver::get_settings, py::return_value_policy::reference) - .def("adjoint_derivative_compute", &PyOSQPSolver::adjoint_derivative_compute, "dx"_a.none(true), "dy_l"_a.none(true), "dy_u"_a.none(true)) + .def("adjoint_derivative_compute", &PyOSQPSolver::adjoint_derivative_compute, "dx"_a.none(true), "dy"_a.none(true)) .def("adjoint_derivative_get_mat", &PyOSQPSolver::adjoint_derivative_get_mat, "dP"_a, "dA"_a) .def("adjoint_derivative_get_vec", &PyOSQPSolver::adjoint_derivative_get_vec, "dq"_a, "dl"_a, "du"_a) diff --git a/src/osqp/interface.py b/src/osqp/interface.py index b598d5e7..be96f522 100644 --- a/src/osqp/interface.py +++ b/src/osqp/interface.py @@ -433,7 +433,7 @@ def codegen( return folder - def adjoint_derivative_compute(self, dx=None, dy_l=None, dy_u=None): + def adjoint_derivative_compute(self, dx=None, dy=None): """ Compute adjoint derivative after solve. """ @@ -450,12 +450,10 @@ def adjoint_derivative_compute(self, dx=None, dy_l=None, dy_u=None): if results.info.status != 'solved': raise ValueError('Problem has not been solved to optimality. ' 'You cannot take derivatives') - if dy_u is None: - dy_u = np.zeros(self.m) - if dy_l is None: - dy_l = np.zeros(self.m) + if dy is None: + dy = np.zeros(self.m) - self._solver.adjoint_derivative_compute(dx, dy_l, dy_u) + self._solver.adjoint_derivative_compute(dx, dy) def adjoint_derivative_get_mat(self, as_dense=True, dP_as_triu=True): """ diff --git a/src/osqp/tests/derivative_test.py b/src/osqp/tests/derivative_test.py index e953203e..fcccbad2 100644 --- a/src/osqp/tests/derivative_test.py +++ b/src/osqp/tests/derivative_test.py @@ -47,7 +47,7 @@ def get_prob(self, n=10, m=3, P_scale=1.0, A_scale=1.0): return [P, q, A, l, u, true_x, true_yl, true_yu] - def get_grads(self, P, q, A, l, u, true_x, true_yl=None, true_yu=None, mode='qdldl'): + def get_grads(self, P, q, A, l, u, true_x, true_y=None, mode='qdldl'): # Get gradients by solving with osqp m = osqp.OSQP(algebra='builtin') m.setup( @@ -66,12 +66,10 @@ def get_grads(self, P, q, A, l, u, true_x, true_yl=None, true_yu=None, mode='qdl raise ValueError('Problem not solved!') x = results.x y = results.y - yl = -np.minimum(y, 0) - yu = np.maximum(y, 0) - if true_yl is None and true_yu is None: + if true_y is None: m.adjoint_derivative_compute(dx=x - true_x) else: - m.adjoint_derivative_compute(dx=x - true_x, dy_l=yl - true_yl, dy_u=yu - true_yu) + m.adjoint_derivative_compute(dx=x - true_x, dy=y - true_y) dP, dA = m.adjoint_derivative_get_mat() dq, dl, du = m.adjoint_derivative_get_vec()