diff --git a/c_sources/CMakeLists.txt b/c_sources/CMakeLists.txt index 2e31ddd..86a1ec3 100644 --- a/c_sources/CMakeLists.txt +++ b/c_sources/CMakeLists.txt @@ -67,6 +67,7 @@ matlab_add_mex( NAME osqp_mex SRC ${CMAKE_CURRENT_SOURCE_DIR}/osqp_mex.cpp ${CMAKE_CURRENT_SOURCE_DIR}/interrupt_matlab.c ${CMAKE_CURRENT_SOURCE_DIR}/memory_matlab.c + ${CMAKE_CURRENT_SOURCE_DIR}/settings_matlab.cpp LINK_TO osqpstatic ${UT_LIBRARY} # Force compilation in the traditional C API (equivalent to the -R2017b flag) diff --git a/c_sources/osqp_mex.cpp b/c_sources/osqp_mex.cpp index 1fa3e5f..d97b5d4 100755 --- a/c_sources/osqp_mex.cpp +++ b/c_sources/osqp_mex.cpp @@ -1,10 +1,13 @@ +#include + #include "mex.h" #include "matrix.h" -#include "osqp_mex.hpp" #include "osqp.h" -#include "memory_matlab.h" -#include +// Mex-specific functionality +#include "osqp_mex.hpp" +#include "memory_matlab.h" +#include "settings_matlab.h" //c_int is replaced with OSQPInt //c_float is replaced with OSQPFloat @@ -30,36 +33,6 @@ const char* OSQP_INFO_FIELDS[] = {"status", //char* "run_time", //OSQPFloat }; -const char* OSQP_SETTINGS_FIELDS[] = {"device", //OSQPInt - "linsys_solver", //enum osqp_linsys_solver_type - "verbose", //OSQPInt - "warm_starting", //OSQPInt - "scaling", //OSQPInt - "polishing", //OSQPInt - "rho", //OSQPFloat - "rho_is_vec", //OSQPInt - "sigma", //OSQPFloat - "alpha", //OSQPFloat - "cg_max_iter", //OSQPInt - "cg_tol_reduction", //OSQPInt - "cg_tol_fraction", //OSQPFloat - "cg_precond", //osqp_precond_type - "adaptive_rho", //OSQPInt - "adaptive_rho_interval", //OSQPInt - "adaptive_rho_fraction", //OSQPFloat - "adaptive_rho_tolerance", //OSQPFloat - "max_iter", //OSQPInt - "eps_abs", //OSQPFloat - "eps_rel", //OSQPFloat - "eps_prim_inf", //OSQPFloat - "eps_dual_inf", //OSQPFloat - "scaled_termination", //OSQPInt - "check_termination", //OSQPInt - "time_limit", //OSQPFloat - "delta", //OSQPFloat - "polish_refine_iter", //OSQPInt - }; - #define NEW_SETTINGS_TOL (1e-10) // wrapper class for all osqp data and settings @@ -74,15 +47,12 @@ class OsqpData OSQPSolver* initializeOSQPSolver(); void castToDoubleArr(OSQPFloat *arr, double* arr_out, OSQPInt len); void setToNaN(double* arr_out, OSQPInt len); -void copyMxStructToSettings(const mxArray*, OSQPSettings*); -void copyUpdatedSettingsToWork(const mxArray*, OSQPSolver*); //void castCintToDoubleArr(OSQPInt *arr, double* arr_out, OSQPInt len); //DELETE HERE void freeCscMatrix(OSQPCscMatrix* M); OSQPInt* copyToOSQPIntVector(mwIndex * vecData, OSQPInt numel); OSQPInt* copyDoubleToOSQPIntVector(double* vecData, OSQPInt numel); OSQPFloat* copyToOSQPFloatVector(double * vecData, OSQPInt numel); mxArray* copyInfoToMxStruct(OSQPInfo* info); -mxArray* copySettingsToMxStruct(OSQPSettings* settings); void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) @@ -136,28 +106,33 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) // report the current settings if (!strcmp("current_settings", cmd)) { - //throw an error if this is called before solver is configured - if(!osqpData->solver) mexErrMsgTxt("Solver is uninitialized. No settings have been configured."); - if(!osqpData->solver->settings){ - mexErrMsgTxt("Solver settings is uninitialized. No settings have been configured."); - } - //report the current settings - plhs[0] = copySettingsToMxStruct(osqpData->solver->settings); - return; + // Throw an error if this is called before solver is configured + if(!osqpData->solver) { + mexErrMsgTxt("Solver is uninitialized. No settings have been configured."); + } + if(!osqpData->solver->settings) { + mexErrMsgTxt("Solver settings is uninitialized. No settings have been configured."); + } + + // Report the current settings + OSQPSettingsWrapper settings(osqpData->solver->settings); + plhs[0] = settings.GetMxStruct(); + return; } // write_settings if (!strcmp("update_settings", cmd)) { - //overwrite the current settings. Mex function is responsible - //for disallowing overwrite of selected settings after initialization, - //and for all error checking - //throw an error if this is called before solver is configured - if(!osqpData->solver){ - mexErrMsgTxt("Solver is uninitialized. No settings have been configured."); - } + // Overwrite the current settings. Mex function is responsible + // for disallowing overwrite of selected settings after initialization, + // and for all error checking + // throw an error if this is called before solver is configured + if(!osqpData->solver){ + mexErrMsgTxt("Solver is uninitialized. No settings have been configured."); + } - copyUpdatedSettingsToWork(prhs[2],osqpData->solver); - return; + OSQPSettingsWrapper settings(prhs[2]); + osqp_update_settings(osqpData->solver, settings.GetOSQPSettings()); + return; } // Update rho value @@ -180,14 +155,12 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) mexWarnMsgTxt("Default settings: unexpected number of arguments."); - //Create a Settings structure in default form and report the results - //Useful for external solver packages (e.g. Yalmip) that want to - //know which solver settings are supported - OSQPSettings* defaults = (OSQPSettings *)mxCalloc(1,sizeof(OSQPSettings)); - osqp_set_default_settings(defaults); - plhs[0] = copySettingsToMxStruct(defaults); - mxFree(defaults); - return; + // Create a Settings structure in default form and report the results + // Useful for external solver packages (e.g. Yalmip) that want to + // know which solver settings are supported + OSQPSettingsWrapper settings; + plhs[0] = settings.GetMxStruct(); + return; } // setup @@ -196,8 +169,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) if(osqpData->solver){ mexErrMsgTxt("Solver is already initialized with problem data."); } - //Create data and settings containers - OSQPSettings* settings = (OSQPSettings *)mxCalloc(1,sizeof(OSQPSettings)); // handle the problem data first. Matlab-side // class wrapper is responsible for ensuring that @@ -234,18 +205,16 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) csc_set_data(dataA, dataM, dataN, Ap[dataN], Ax, Ai, Ap); // Create Settings - const mxArray* mxSettings = prhs[9]; - if(mxIsEmpty(mxSettings)){ - // use defaults - osqp_set_default_settings(settings); - } else { - //populate settings structure from mxArray input - copyMxStructToSettings(mxSettings, settings); + OSQPSettingsWrapper settings; + + if(!mxIsEmpty(prhs[9])){ + // Populate settings structure from mxArray input, otherwise the default settings are used + settings.ParseMxStruct(prhs[9]); } // Setup workspace //exitflag = osqp_setup(&(osqpData->work), data, settings); - exitflag = osqp_setup(&(osqpData->solver), dataP, dataQ, dataA, dataL, dataU, dataM, dataN, settings); + exitflag = osqp_setup(&(osqpData->solver), dataP, dataQ, dataA, dataL, dataU, dataM, dataN, settings.GetOSQPSettings()); //cleanup temporary structures // Data if (Px) c_free(Px); @@ -259,8 +228,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) if (dataU) c_free(dataU); if (dataP) c_free(dataP); if (dataA) c_free(dataA); - // Settings - if (settings) c_free(settings); // Report error (if any) if(exitflag){ @@ -678,129 +645,3 @@ mxArray* copyInfoToMxStruct(OSQPInfo* info){ return mxPtr; } - -mxArray* copySettingsToMxStruct(OSQPSettings* settings){ - - int nfields = sizeof(OSQP_SETTINGS_FIELDS) / sizeof(OSQP_SETTINGS_FIELDS[0]); - mxArray* mxPtr = mxCreateStructMatrix(1,1,nfields,OSQP_SETTINGS_FIELDS); - - //map the OSQP_SETTINGS fields one at a time into mxArrays - //matlab handles everything as a double - mxSetField(mxPtr, 0, "rho", mxCreateDoubleScalar(settings->rho)); - mxSetField(mxPtr, 0, "sigma", mxCreateDoubleScalar(settings->sigma)); - mxSetField(mxPtr, 0, "scaling", mxCreateDoubleScalar(settings->scaling)); - mxSetField(mxPtr, 0, "adaptive_rho", mxCreateDoubleScalar(settings->adaptive_rho)); - mxSetField(mxPtr, 0, "adaptive_rho_interval", mxCreateDoubleScalar(settings->adaptive_rho_interval)); - mxSetField(mxPtr, 0, "adaptive_rho_tolerance", mxCreateDoubleScalar(settings->adaptive_rho_tolerance)); - mxSetField(mxPtr, 0, "adaptive_rho_fraction", mxCreateDoubleScalar(settings->adaptive_rho_fraction)); - mxSetField(mxPtr, 0, "max_iter", mxCreateDoubleScalar(settings->max_iter)); - mxSetField(mxPtr, 0, "eps_abs", mxCreateDoubleScalar(settings->eps_abs)); - mxSetField(mxPtr, 0, "eps_rel", mxCreateDoubleScalar(settings->eps_rel)); - mxSetField(mxPtr, 0, "eps_prim_inf", mxCreateDoubleScalar(settings->eps_prim_inf)); - mxSetField(mxPtr, 0, "eps_dual_inf", mxCreateDoubleScalar(settings->eps_dual_inf)); - mxSetField(mxPtr, 0, "alpha", mxCreateDoubleScalar(settings->alpha)); - mxSetField(mxPtr, 0, "linsys_solver", mxCreateDoubleScalar(settings->linsys_solver)); - mxSetField(mxPtr, 0, "delta", mxCreateDoubleScalar(settings->delta)); - mxSetField(mxPtr, 0, "polish_refine_iter", mxCreateDoubleScalar(settings->polish_refine_iter)); - mxSetField(mxPtr, 0, "verbose", mxCreateDoubleScalar(settings->verbose)); - mxSetField(mxPtr, 0, "scaled_termination", mxCreateDoubleScalar(settings->scaled_termination)); - mxSetField(mxPtr, 0, "check_termination", mxCreateDoubleScalar(settings->check_termination)); - mxSetField(mxPtr, 0, "warm_starting", mxCreateDoubleScalar(settings->warm_starting)); - mxSetField(mxPtr, 0, "time_limit", mxCreateDoubleScalar(settings->time_limit)); - mxSetField(mxPtr, 0, "device", mxCreateDoubleScalar(settings->device)); - mxSetField(mxPtr, 0, "polishing", mxCreateDoubleScalar(settings->polishing)); - mxSetField(mxPtr, 0, "rho_is_vec", mxCreateDoubleScalar(settings->rho_is_vec)); - mxSetField(mxPtr, 0, "cg_max_iter", mxCreateDoubleScalar(settings->cg_max_iter)); - mxSetField(mxPtr, 0, "cg_tol_reduction", mxCreateDoubleScalar(settings->cg_tol_reduction)); - mxSetField(mxPtr, 0, "cg_tol_fraction", mxCreateDoubleScalar(settings->cg_tol_fraction)); - mxSetField(mxPtr, 0, "time_limit", mxCreateDoubleScalar(settings->time_limit)); - mxSetField(mxPtr, 0, "cg_precond", mxCreateDoubleScalar(settings->cg_precond)); - return mxPtr; -} - - -// ====================================================================== - -void copyMxStructToSettings(const mxArray* mxPtr, OSQPSettings* settings){ - - //this function assumes that only a complete and validated structure - //will be passed. matlab mex-side function is responsible for checking - //structure validity - - //map the OSQP_SETTINGS fields one at a time into mxArrays - //matlab handles everything as a double - settings->rho = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "rho")); - settings->sigma = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "sigma")); - settings->scaling = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaling")); - settings->adaptive_rho = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho")); - settings->adaptive_rho_interval = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_interval")); - settings->adaptive_rho_tolerance = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_tolerance")); - settings->adaptive_rho_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_fraction")); - settings->max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "max_iter")); - settings->eps_abs = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_abs")); - settings->eps_rel = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_rel")); - settings->eps_prim_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_dual_inf")); - settings->eps_dual_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_dual_inf")); - settings->alpha = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "alpha")); - settings->linsys_solver = (enum osqp_linsys_solver_type) (OSQPInt) mxGetScalar(mxGetField(mxPtr, 0, "linsys_solver")); - settings->delta = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "delta")); - settings->polish_refine_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polish_refine_iter")); - settings->verbose = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "verbose")); - settings->scaled_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaled_termination")); - settings->check_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "check_termination")); - settings->warm_starting = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "warm_starting")); - settings->time_limit = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "time_limit")); - settings->device = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "device")); - settings->polishing = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polishing")); - settings->rho_is_vec = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "rho_is_vec")); - settings->cg_max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_max_iter")); - settings->cg_tol_reduction = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_reduction")); - settings->cg_tol_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_fraction")); - settings->cg_precond = (osqp_precond_type) (OSQPInt) (mxGetField(mxPtr, 0, "cg_precond")); -} - -void copyUpdatedSettingsToWork(const mxArray* mxPtr ,OSQPSolver* osqpSolver){ - - OSQPInt exitflag; - //TODO (Amit): Update this - OSQPSettings* update_template = (OSQPSettings *)mxCalloc(1,sizeof(OSQPSettings)); - if (!update_template) mexErrMsgTxt("Failed to allocate a temporary OSQPSettings object."); - - update_template->device = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "device")); - update_template->linsys_solver = (enum osqp_linsys_solver_type)mxGetScalar(mxGetField(mxPtr, 0, "linsys_solver")); - update_template->verbose = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "verbose")); - update_template->warm_starting = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "warm_starting")); - update_template->scaling = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaling")); - update_template->polishing = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polishing")); - - update_template->rho = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "rho")); - update_template->rho_is_vec = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "rho_is_vec")); - update_template->sigma = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "sigma")); - update_template->alpha = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "alpha")); - - update_template->cg_max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_max_iter")); - update_template->cg_tol_reduction = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_reduction")); - update_template->cg_tol_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_fraction")); - update_template->cg_precond = (osqp_precond_type)mxGetScalar(mxGetField(mxPtr, 0, "cg_precond")); - - update_template->adaptive_rho = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho")); - update_template->adaptive_rho_interval = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_interval")); - update_template->adaptive_rho_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_fraction")); - update_template->adaptive_rho_tolerance = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_tolerance")); - - update_template->max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "max_iter")); - update_template->eps_abs = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_abs")); - update_template->eps_rel = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_rel")); - update_template->eps_prim_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_prim_inf")); - update_template->eps_dual_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_dual_inf")); - update_template->scaled_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaled_termination")); - update_template->check_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "check_termination")); - update_template->time_limit = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "time_limit")); - - update_template->delta = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "delta")); - update_template->polish_refine_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polish_refine_iter")); - - osqp_update_settings(osqpSolver, update_template); - - if (update_template) c_free(update_template); -} \ No newline at end of file diff --git a/c_sources/settings_matlab.cpp b/c_sources/settings_matlab.cpp new file mode 100644 index 0000000..001375a --- /dev/null +++ b/c_sources/settings_matlab.cpp @@ -0,0 +1,103 @@ +#include + +#include "memory_matlab.h" +#include "settings_matlab.h" + +#include + + +void OSQPSettingsWrapper::registerFields() { + m_settings = static_cast(c_calloc(1, sizeof(OSQPSettings))); + + if(!m_settings) + mexErrMsgTxt("Failed to allocate a OSQPSettings object."); + + osqp_set_default_settings(m_settings); + + /* + * Register the mapping between struct field name and the settings memory location + */ + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->device, "device")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->linsys_solver, "linsys_solver")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->verbose, "verbose")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->warm_starting, "warm_starting")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->scaling, "scaling")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->polishing, "polishing")); + + // ADMM parameters + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->rho, "rho")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->rho_is_vec, "rho_is_vec")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->sigma, "sigma")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->alpha, "alpha")); + + // CG settings + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->cg_max_iter, "cg_max_iter")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->cg_tol_reduction, "cg_tol_reduction")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->cg_tol_fraction, "cg_tol_fraction")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->cg_precond, "cg_precond")); + + // adaptive rho logic + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->adaptive_rho, "adaptive_rho")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->adaptive_rho_interval, "adaptive_rho_interval")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->adaptive_rho_fraction, "adaptive_rho_fraction")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->adaptive_rho_tolerance, "adaptive_rho_tolerance")); + + // termination parameters + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->max_iter, "max_iter")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->eps_abs, "eps_abs")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->eps_rel, "eps_rel")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->eps_prim_inf, "eps_prim_inf")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->eps_dual_inf, "eps_dual_inf")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->scaled_termination, "scaled_termination")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->check_termination, "check_termination")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->time_limit, "time_limit")); + + // polishing parameters + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->delta, "delta")); + m_settingsFields.push_back(new OSQPSettingsField(&m_settings->polish_refine_iter, "polish_refine_iter")); +} + + +OSQPSettingsWrapper::~OSQPSettingsWrapper() { + for(auto& s : m_settingsFields) { + delete s; + } + + c_free(m_settings); +} + + +mxArray* OSQPSettingsWrapper::GetMxStruct() { + // No fields are added right now, they are added in the for loop when they are set + mxArray* mxSettings = mxCreateStructMatrix(1, 1, 0, NULL); + + // Copy the current settings into the struct to return + for(const auto& s : m_settingsFields) { + s->ToMxStruct(mxSettings); + } + + return mxSettings; +} + + +void OSQPSettingsWrapper::ParseMxStruct(const mxArray* aStruct) { + for(const auto& s : m_settingsFields) { + s->ToOSQPSettings(aStruct); + } +} + + +OSQPSettings* OSQPSettingsWrapper::GetOSQPSettingsCopy() { + // Allocate the default settings + OSQPSettings* ret = static_cast(c_calloc(1, sizeof(OSQPSettings))); + + // Copy the current settings for their return + std::memcpy(ret, m_settings, sizeof(ret)); + + return ret; +} + + +void OSQPSettingsWrapper::ParseOSQPSettings(const OSQPSettings* aSettings) { + std::memcpy(m_settings, aSettings, sizeof(m_settings)); +} diff --git a/c_sources/settings_matlab.h b/c_sources/settings_matlab.h new file mode 100644 index 0000000..28ee832 --- /dev/null +++ b/c_sources/settings_matlab.h @@ -0,0 +1,125 @@ +#ifndef SETTINGS_MATLAB_H_ +#define SETTINGS_MATLAB_H_ + +#include +#include +#include + +#include +#include +#include + +/* + * Base class used to store the templated settings field types. + */ +class OSQPSettingsFieldBase { +public: + OSQPSettingsFieldBase() {} + + virtual void ToMxStruct(mxArray* aStruct) = 0; + virtual void ToOSQPSettings(const mxArray* aStruct) = 0; +}; + +template +class OSQPSettingsField : public OSQPSettingsFieldBase { +public: + OSQPSettingsField(T* aSettingPtr, std::string aName) : + m_settingsPtr(aSettingPtr), + m_name(aName) { + } + + /* + * Set the field in the given Matlab struct to the value of this settings field + */ + void ToMxStruct(mxArray* aStruct) override { + mxAddField(aStruct, m_name.data()); + mxSetField(aStruct, 0, m_name.data(), mxCreateDoubleScalar(*m_settingsPtr)); + } + + /* + * Set the field in the internal OSQPSettings struct with the data from aStruct + */ + void ToOSQPSettings(const mxArray* aStruct) override { + *(m_settingsPtr) = static_cast(mxGetScalar(mxGetField(aStruct, 0, m_name.data()))); + } + +private: + T* m_settingsPtr; + std::string m_name; +}; + +class OSQPSettingsWrapper { +public: + /* + * Initialize the settings wrapper using the default settings. + */ + OSQPSettingsWrapper() { + // Allocate the default settings and register field handlers + registerFields(); + } + + /* + * Initialize the settings wrapper using the values from aSettings. + */ + OSQPSettingsWrapper(const OSQPSettings* aSettings) { + // Allocate the default settings and register field handlers + registerFields(); + ParseOSQPSettings(aSettings); + } + + /* + * Initialize the settings wrapper using the values from aStruct + */ + OSQPSettingsWrapper(const mxArray* aStruct) { + // Allocate the default settings and register field handlers + registerFields(); + ParseMxStruct(aStruct); + } + + ~OSQPSettingsWrapper(); + + /* + * Return a Matlab structu populated with the values of the current settings + * contained in this wrapper. + * + * @return a Matlab struct with a copy of the settings (caller owns this copy and must free it) + */ + mxArray* GetMxStruct(); + + /* + * Read a Matlab struct and populate the wrapper with its values. + */ + void ParseMxStruct(const mxArray* aStruct); + + /* + * Get a copy of the settings contained inside this wrapper. + * + * @return a copy of the settings (caller owns this copy and must free it) + */ + OSQPSettings* GetOSQPSettingsCopy(); + + /* + * Get the pointer to the internal settings object. + */ + OSQPSettings* GetOSQPSettings() { + return m_settings; + } + + /* + * Read an existing OSQPSettings object into this wrapper. + * The settings are copied, so no ownership of the aSettings pointer is transferred. + */ + void ParseOSQPSettings(const OSQPSettings* aSettings); + +private: + // Register all the fields + void registerFields(); + + // All settings fields + std::vector m_settingsFields; + + // Base OSQP settings object. Owned by this wrapper. + OSQPSettings* m_settings; +}; + +#endif \ No newline at end of file