Skip to content

Commit

Permalink
Restructure settings handling to make it easier to maintain
Browse files Browse the repository at this point in the history
  • Loading branch information
imciner2 committed Nov 20, 2023
1 parent c3d24fc commit cc96cf2
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 199 deletions.
1 change: 1 addition & 0 deletions c_sources/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ matlab_add_mex( NAME osqp_mex
SRC ${CMAKE_CURRENT_SOURCE_DIR}/osqp_mex.cpp
${CMAKE_CURRENT_SOURCE_DIR}/interrupt_matlab.c
${CMAKE_CURRENT_SOURCE_DIR}/memory_matlab.c
${CMAKE_CURRENT_SOURCE_DIR}/settings_matlab.cpp
LINK_TO osqpstatic
${UT_LIBRARY}
# Force compilation in the traditional C API (equivalent to the -R2017b flag)
Expand Down
239 changes: 40 additions & 199 deletions c_sources/osqp_mex.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#include <map>

#include "mex.h"
#include "matrix.h"
#include "osqp_mex.hpp"
#include "osqp.h"
#include "memory_matlab.h"

#include <map>
// Mex-specific functionality
#include "osqp_mex.hpp"
#include "memory_matlab.h"
#include "settings_matlab.h"

//c_int is replaced with OSQPInt
//c_float is replaced with OSQPFloat
Expand All @@ -30,36 +33,6 @@ const char* OSQP_INFO_FIELDS[] = {"status", //char*
"run_time", //OSQPFloat
};

const char* OSQP_SETTINGS_FIELDS[] = {"device", //OSQPInt
"linsys_solver", //enum osqp_linsys_solver_type
"verbose", //OSQPInt
"warm_starting", //OSQPInt
"scaling", //OSQPInt
"polishing", //OSQPInt
"rho", //OSQPFloat
"rho_is_vec", //OSQPInt
"sigma", //OSQPFloat
"alpha", //OSQPFloat
"cg_max_iter", //OSQPInt
"cg_tol_reduction", //OSQPInt
"cg_tol_fraction", //OSQPFloat
"cg_precond", //osqp_precond_type
"adaptive_rho", //OSQPInt
"adaptive_rho_interval", //OSQPInt
"adaptive_rho_fraction", //OSQPFloat
"adaptive_rho_tolerance", //OSQPFloat
"max_iter", //OSQPInt
"eps_abs", //OSQPFloat
"eps_rel", //OSQPFloat
"eps_prim_inf", //OSQPFloat
"eps_dual_inf", //OSQPFloat
"scaled_termination", //OSQPInt
"check_termination", //OSQPInt
"time_limit", //OSQPFloat
"delta", //OSQPFloat
"polish_refine_iter", //OSQPInt
};

#define NEW_SETTINGS_TOL (1e-10)

// wrapper class for all osqp data and settings
Expand All @@ -74,15 +47,12 @@ class OsqpData
OSQPSolver* initializeOSQPSolver();
void castToDoubleArr(OSQPFloat *arr, double* arr_out, OSQPInt len);
void setToNaN(double* arr_out, OSQPInt len);
void copyMxStructToSettings(const mxArray*, OSQPSettings*);
void copyUpdatedSettingsToWork(const mxArray*, OSQPSolver*);
//void castCintToDoubleArr(OSQPInt *arr, double* arr_out, OSQPInt len); //DELETE HERE
void freeCscMatrix(OSQPCscMatrix* M);
OSQPInt* copyToOSQPIntVector(mwIndex * vecData, OSQPInt numel);
OSQPInt* copyDoubleToOSQPIntVector(double* vecData, OSQPInt numel);
OSQPFloat* copyToOSQPFloatVector(double * vecData, OSQPInt numel);
mxArray* copyInfoToMxStruct(OSQPInfo* info);
mxArray* copySettingsToMxStruct(OSQPSettings* settings);


void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Expand Down Expand Up @@ -136,28 +106,33 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])

// report the current settings
if (!strcmp("current_settings", cmd)) {
//throw an error if this is called before solver is configured
if(!osqpData->solver) mexErrMsgTxt("Solver is uninitialized. No settings have been configured.");
if(!osqpData->solver->settings){
mexErrMsgTxt("Solver settings is uninitialized. No settings have been configured.");
}
//report the current settings
plhs[0] = copySettingsToMxStruct(osqpData->solver->settings);
return;
// Throw an error if this is called before solver is configured
if(!osqpData->solver) {
mexErrMsgTxt("Solver is uninitialized. No settings have been configured.");
}
if(!osqpData->solver->settings) {
mexErrMsgTxt("Solver settings is uninitialized. No settings have been configured.");
}

// Report the current settings
OSQPSettingsWrapper settings(osqpData->solver->settings);
plhs[0] = settings.GetMxStruct();
return;
}

// write_settings
if (!strcmp("update_settings", cmd)) {
//overwrite the current settings. Mex function is responsible
//for disallowing overwrite of selected settings after initialization,
//and for all error checking
//throw an error if this is called before solver is configured
if(!osqpData->solver){
mexErrMsgTxt("Solver is uninitialized. No settings have been configured.");
}
// Overwrite the current settings. Mex function is responsible
// for disallowing overwrite of selected settings after initialization,
// and for all error checking
// throw an error if this is called before solver is configured
if(!osqpData->solver){
mexErrMsgTxt("Solver is uninitialized. No settings have been configured.");
}

copyUpdatedSettingsToWork(prhs[2],osqpData->solver);
return;
OSQPSettingsWrapper settings(prhs[2]);
osqp_update_settings(osqpData->solver, settings.GetOSQPSettings());
return;
}

// Update rho value
Expand All @@ -180,14 +155,12 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
mexWarnMsgTxt("Default settings: unexpected number of arguments.");


//Create a Settings structure in default form and report the results
//Useful for external solver packages (e.g. Yalmip) that want to
//know which solver settings are supported
OSQPSettings* defaults = (OSQPSettings *)mxCalloc(1,sizeof(OSQPSettings));
osqp_set_default_settings(defaults);
plhs[0] = copySettingsToMxStruct(defaults);
mxFree(defaults);
return;
// Create a Settings structure in default form and report the results
// Useful for external solver packages (e.g. Yalmip) that want to
// know which solver settings are supported
OSQPSettingsWrapper settings;
plhs[0] = settings.GetMxStruct();
return;
}

// setup
Expand All @@ -196,8 +169,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if(osqpData->solver){
mexErrMsgTxt("Solver is already initialized with problem data.");
}
//Create data and settings containers
OSQPSettings* settings = (OSQPSettings *)mxCalloc(1,sizeof(OSQPSettings));

// handle the problem data first. Matlab-side
// class wrapper is responsible for ensuring that
Expand Down Expand Up @@ -234,18 +205,16 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
csc_set_data(dataA, dataM, dataN, Ap[dataN], Ax, Ai, Ap);

// Create Settings
const mxArray* mxSettings = prhs[9];
if(mxIsEmpty(mxSettings)){
// use defaults
osqp_set_default_settings(settings);
} else {
//populate settings structure from mxArray input
copyMxStructToSettings(mxSettings, settings);
OSQPSettingsWrapper settings;

if(!mxIsEmpty(prhs[9])){
// Populate settings structure from mxArray input, otherwise the default settings are used
settings.ParseMxStruct(prhs[9]);
}

// Setup workspace
//exitflag = osqp_setup(&(osqpData->work), data, settings);
exitflag = osqp_setup(&(osqpData->solver), dataP, dataQ, dataA, dataL, dataU, dataM, dataN, settings);
exitflag = osqp_setup(&(osqpData->solver), dataP, dataQ, dataA, dataL, dataU, dataM, dataN, settings.GetOSQPSettings());
//cleanup temporary structures
// Data
if (Px) c_free(Px);
Expand All @@ -259,8 +228,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if (dataU) c_free(dataU);
if (dataP) c_free(dataP);
if (dataA) c_free(dataA);
// Settings
if (settings) c_free(settings);

// Report error (if any)
if(exitflag){
Expand Down Expand Up @@ -678,129 +645,3 @@ mxArray* copyInfoToMxStruct(OSQPInfo* info){
return mxPtr;

}

mxArray* copySettingsToMxStruct(OSQPSettings* settings){

int nfields = sizeof(OSQP_SETTINGS_FIELDS) / sizeof(OSQP_SETTINGS_FIELDS[0]);
mxArray* mxPtr = mxCreateStructMatrix(1,1,nfields,OSQP_SETTINGS_FIELDS);

//map the OSQP_SETTINGS fields one at a time into mxArrays
//matlab handles everything as a double
mxSetField(mxPtr, 0, "rho", mxCreateDoubleScalar(settings->rho));
mxSetField(mxPtr, 0, "sigma", mxCreateDoubleScalar(settings->sigma));
mxSetField(mxPtr, 0, "scaling", mxCreateDoubleScalar(settings->scaling));
mxSetField(mxPtr, 0, "adaptive_rho", mxCreateDoubleScalar(settings->adaptive_rho));
mxSetField(mxPtr, 0, "adaptive_rho_interval", mxCreateDoubleScalar(settings->adaptive_rho_interval));
mxSetField(mxPtr, 0, "adaptive_rho_tolerance", mxCreateDoubleScalar(settings->adaptive_rho_tolerance));
mxSetField(mxPtr, 0, "adaptive_rho_fraction", mxCreateDoubleScalar(settings->adaptive_rho_fraction));
mxSetField(mxPtr, 0, "max_iter", mxCreateDoubleScalar(settings->max_iter));
mxSetField(mxPtr, 0, "eps_abs", mxCreateDoubleScalar(settings->eps_abs));
mxSetField(mxPtr, 0, "eps_rel", mxCreateDoubleScalar(settings->eps_rel));
mxSetField(mxPtr, 0, "eps_prim_inf", mxCreateDoubleScalar(settings->eps_prim_inf));
mxSetField(mxPtr, 0, "eps_dual_inf", mxCreateDoubleScalar(settings->eps_dual_inf));
mxSetField(mxPtr, 0, "alpha", mxCreateDoubleScalar(settings->alpha));
mxSetField(mxPtr, 0, "linsys_solver", mxCreateDoubleScalar(settings->linsys_solver));
mxSetField(mxPtr, 0, "delta", mxCreateDoubleScalar(settings->delta));
mxSetField(mxPtr, 0, "polish_refine_iter", mxCreateDoubleScalar(settings->polish_refine_iter));
mxSetField(mxPtr, 0, "verbose", mxCreateDoubleScalar(settings->verbose));
mxSetField(mxPtr, 0, "scaled_termination", mxCreateDoubleScalar(settings->scaled_termination));
mxSetField(mxPtr, 0, "check_termination", mxCreateDoubleScalar(settings->check_termination));
mxSetField(mxPtr, 0, "warm_starting", mxCreateDoubleScalar(settings->warm_starting));
mxSetField(mxPtr, 0, "time_limit", mxCreateDoubleScalar(settings->time_limit));
mxSetField(mxPtr, 0, "device", mxCreateDoubleScalar(settings->device));
mxSetField(mxPtr, 0, "polishing", mxCreateDoubleScalar(settings->polishing));
mxSetField(mxPtr, 0, "rho_is_vec", mxCreateDoubleScalar(settings->rho_is_vec));
mxSetField(mxPtr, 0, "cg_max_iter", mxCreateDoubleScalar(settings->cg_max_iter));
mxSetField(mxPtr, 0, "cg_tol_reduction", mxCreateDoubleScalar(settings->cg_tol_reduction));
mxSetField(mxPtr, 0, "cg_tol_fraction", mxCreateDoubleScalar(settings->cg_tol_fraction));
mxSetField(mxPtr, 0, "time_limit", mxCreateDoubleScalar(settings->time_limit));
mxSetField(mxPtr, 0, "cg_precond", mxCreateDoubleScalar(settings->cg_precond));
return mxPtr;
}


// ======================================================================

void copyMxStructToSettings(const mxArray* mxPtr, OSQPSettings* settings){

//this function assumes that only a complete and validated structure
//will be passed. matlab mex-side function is responsible for checking
//structure validity

//map the OSQP_SETTINGS fields one at a time into mxArrays
//matlab handles everything as a double
settings->rho = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "rho"));
settings->sigma = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "sigma"));
settings->scaling = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaling"));
settings->adaptive_rho = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho"));
settings->adaptive_rho_interval = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_interval"));
settings->adaptive_rho_tolerance = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_tolerance"));
settings->adaptive_rho_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_fraction"));
settings->max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "max_iter"));
settings->eps_abs = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_abs"));
settings->eps_rel = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_rel"));
settings->eps_prim_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_dual_inf"));
settings->eps_dual_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_dual_inf"));
settings->alpha = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "alpha"));
settings->linsys_solver = (enum osqp_linsys_solver_type) (OSQPInt) mxGetScalar(mxGetField(mxPtr, 0, "linsys_solver"));
settings->delta = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "delta"));
settings->polish_refine_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polish_refine_iter"));
settings->verbose = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "verbose"));
settings->scaled_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaled_termination"));
settings->check_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "check_termination"));
settings->warm_starting = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "warm_starting"));
settings->time_limit = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "time_limit"));
settings->device = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "device"));
settings->polishing = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polishing"));
settings->rho_is_vec = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "rho_is_vec"));
settings->cg_max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_max_iter"));
settings->cg_tol_reduction = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_reduction"));
settings->cg_tol_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_fraction"));
settings->cg_precond = (osqp_precond_type) (OSQPInt) (mxGetField(mxPtr, 0, "cg_precond"));
}

void copyUpdatedSettingsToWork(const mxArray* mxPtr ,OSQPSolver* osqpSolver){

OSQPInt exitflag;
//TODO (Amit): Update this
OSQPSettings* update_template = (OSQPSettings *)mxCalloc(1,sizeof(OSQPSettings));
if (!update_template) mexErrMsgTxt("Failed to allocate a temporary OSQPSettings object.");

update_template->device = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "device"));
update_template->linsys_solver = (enum osqp_linsys_solver_type)mxGetScalar(mxGetField(mxPtr, 0, "linsys_solver"));
update_template->verbose = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "verbose"));
update_template->warm_starting = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "warm_starting"));
update_template->scaling = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaling"));
update_template->polishing = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polishing"));

update_template->rho = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "rho"));
update_template->rho_is_vec = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "rho_is_vec"));
update_template->sigma = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "sigma"));
update_template->alpha = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "alpha"));

update_template->cg_max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_max_iter"));
update_template->cg_tol_reduction = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_reduction"));
update_template->cg_tol_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_fraction"));
update_template->cg_precond = (osqp_precond_type)mxGetScalar(mxGetField(mxPtr, 0, "cg_precond"));

update_template->adaptive_rho = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho"));
update_template->adaptive_rho_interval = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_interval"));
update_template->adaptive_rho_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_fraction"));
update_template->adaptive_rho_tolerance = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_tolerance"));

update_template->max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "max_iter"));
update_template->eps_abs = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_abs"));
update_template->eps_rel = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_rel"));
update_template->eps_prim_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_prim_inf"));
update_template->eps_dual_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_dual_inf"));
update_template->scaled_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaled_termination"));
update_template->check_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "check_termination"));
update_template->time_limit = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "time_limit"));

update_template->delta = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "delta"));
update_template->polish_refine_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polish_refine_iter"));

osqp_update_settings(osqpSolver, update_template);

if (update_template) c_free(update_template);
}
Loading

0 comments on commit cc96cf2

Please sign in to comment.