From e0e0d3652f0bfb01b6e49028443c8e72830dc9e8 Mon Sep 17 00:00:00 2001 From: Luke Marshall <52978038+mathgeekcoder@users.noreply.github.com> Date: Mon, 9 Sep 2024 18:42:16 -0700 Subject: [PATCH] * Fixed highspy deadlocking issues * Added GIL acquire for python callbacks * Added support for user_callback_data --- src/highs_bindings.cpp | 33 ++++++++++++++++++++------------- src/highspy/highs.py | 18 +----------------- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/src/highs_bindings.cpp b/src/highs_bindings.cpp index 57e21fc72b..413549ad5b 100644 --- a/src/highs_bindings.cpp +++ b/src/highs_bindings.cpp @@ -582,16 +582,26 @@ std::tuple highs_getRowByName(Highs* h, return std::make_tuple(status, row); } - -HighsStatus highs_run(Highs* h) -{ - py::gil_scoped_release release; - HighsStatus status = h->run(); - py::gil_scoped_acquire(); - return status; +// Wrap the setCallback function. Pass a lambda wrapper around the python +// function that acquires the GIL and appropriately handle user data passed to +// the callback +HighsStatus highs_setCallback( + Highs* h, + std::function + fn, + py::handle data) { + return h->setCallback( + [fn, data](int callbackType, const std::string& msg, + const HighsCallbackDataOut* dataOut, + HighsCallbackDataIn* dataIn, void* d) { + py::gil_scoped_acquire acquire; + return fn(callbackType, msg, dataOut, dataIn, + py::handle(reinterpret_cast(d))); + }, + data.ptr()); } - PYBIND11_MODULE(_core, m) { // enum classes py::enum_(m, "ObjSense") @@ -889,7 +899,7 @@ PYBIND11_MODULE(_core, m) { .def("writeBasis", &Highs::writeBasis) .def("postsolve", &highs_postsolve) .def("postsolve", &highs_mipPostsolve) - .def("run", &highs_run) + .def("run", &Highs::run, py::call_guard()) .def("feasibilityRelaxation", [](Highs& self, double global_lower_penalty, double global_upper_penalty, double global_rhs_penalty, py::object local_lower_penalty, py::object local_upper_penalty, py::object local_rhs_penalty) { @@ -1021,10 +1031,7 @@ PYBIND11_MODULE(_core, m) { .def("solutionStatusToString", &Highs::solutionStatusToString) .def("basisStatusToString", &Highs::basisStatusToString) .def("basisValidityToString", &Highs::basisValidityToString) - .def( - "setCallback", - static_cast( - &Highs::setCallback)) + .def("setCallback", &highs_setCallback) .def("startCallback", static_cast( &Highs::startCallback)) diff --git a/src/highspy/highs.py b/src/highspy/highs.py index aa925f5261..fb43afbbe9 100644 --- a/src/highspy/highs.py +++ b/src/highspy/highs.py @@ -38,11 +38,6 @@ from threading import Thread, local import numpy as np -class _ThreadingResult: - def __init__(self): - self.out = None - - class Highs(_Highs): """ HiGHS solver interface @@ -56,13 +51,6 @@ def silent(self, turn_off_output=True): Disables solver output to the console. """ super().setOptionValue("output_flag", not turn_off_output) - - def _run(self, res): - res.out = super().run() - - def run(self): - return self.solve() - # solve def solve(self): @@ -71,11 +59,7 @@ def solve(self): Returns: A HighsStatus object containing the solve status. """ - res = _ThreadingResult() - t = Thread(target=self._run, args=(res,)) - t.start() - t.join() - return res.out + return super().run() def optimize(self): """