Skip to content

Commit

Permalink
Refactor how structs are passed to allow cleaner passing of OSQPInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
imciner2 committed Nov 21, 2023
1 parent 7b7f38f commit 5e42763
Show file tree
Hide file tree
Showing 7 changed files with 317 additions and 282 deletions.
3 changes: 2 additions & 1 deletion c_sources/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ matlab_add_mex( NAME osqp_mex
SRC ${CMAKE_CURRENT_SOURCE_DIR}/osqp_mex.cpp
${CMAKE_CURRENT_SOURCE_DIR}/interrupt_matlab.c
${CMAKE_CURRENT_SOURCE_DIR}/memory_matlab.c
${CMAKE_CURRENT_SOURCE_DIR}/settings_matlab.cpp
${CMAKE_CURRENT_SOURCE_DIR}/osqp_struct_info.cpp
${CMAKE_CURRENT_SOURCE_DIR}/osqp_struct_settings.cpp
LINK_TO osqpstatic
${UT_LIBRARY}
# Force compilation in the traditional C API (equivalent to the -R2017b flag)
Expand Down
59 changes: 6 additions & 53 deletions c_sources/osqp_mex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

// Mex-specific functionality
#include "osqp_mex.hpp"
#include "osqp_struct.h"
#include "memory_matlab.h"
#include "settings_matlab.h"

//c_int is replaced with OSQPInt
//c_float is replaced with OSQPFloat
Expand All @@ -16,23 +16,6 @@
// enum linsys_solver_type { QDLDL_SOLVER, MKL_PARDISO_SOLVER };
#define QDLDL_SOLVER 0 //Based on the previous API

// all of the OSQP_INFO fieldnames as strings
const char* OSQP_INFO_FIELDS[] = {"status", //char*
"status_val", //OSQPInt
"status_polish", //OSQPInt
"obj_val", //OSQPFloat
"prim_res", //OSQPFloat
"dual_res", //OSQPFloat
"iter", //OSQPInt
"rho_updates", //OSQPInt
"rho_estimate", //OSQPFloat
"setup_time", //OSQPFloat
"solve_time", //OSQPFloat
"update_time", //OSQPFloat
"polish_time", //OSQPFloat
"run_time", //OSQPFloat
};

#define NEW_SETTINGS_TOL (1e-10)

// wrapper class for all osqp data and settings
Expand All @@ -52,7 +35,6 @@ void freeCscMatrix(OSQPCscMatrix* M);
OSQPInt* copyToOSQPIntVector(mwIndex * vecData, OSQPInt numel);
OSQPInt* copyDoubleToOSQPIntVector(double* vecData, OSQPInt numel);
OSQPFloat* copyToOSQPFloatVector(double * vecData, OSQPInt numel);
mxArray* copyInfoToMxStruct(OSQPInfo* info);


void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Expand Down Expand Up @@ -131,7 +113,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
}

OSQPSettingsWrapper settings(prhs[2]);
osqp_update_settings(osqpData->solver, settings.GetOSQPSettings());
osqp_update_settings(osqpData->solver, settings.GetOSQPStruct());
return;
}

Expand Down Expand Up @@ -214,7 +196,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])

// Setup workspace
//exitflag = osqp_setup(&(osqpData->work), data, settings);
exitflag = osqp_setup(&(osqpData->solver), dataP, dataQ, dataA, dataL, dataU, dataM, dataN, settings.GetOSQPSettings());
exitflag = osqp_setup(&(osqpData->solver), dataP, dataQ, dataA, dataL, dataU, dataM, dataN, settings.GetOSQPStruct());
//cleanup temporary structures
// Data
if (Px) c_free(Px);
Expand Down Expand Up @@ -466,7 +448,9 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
osqpData->solver->info->obj_val = mxGetNaN();
}

plhs[4] = copyInfoToMxStruct(osqpData->solver->info); // Info structure
// Populate the info structure
OSQPInfoWrapper info(osqpData->solver->info);
plhs[4] = info.GetMxStruct();

return;
}
Expand Down Expand Up @@ -614,34 +598,3 @@ void setToNaN(double* arr_out, OSQPInt len){
arr_out[i] = mxGetNaN();
}
}

mxArray* copyInfoToMxStruct(OSQPInfo* info){

//create mxArray with the right number of fields
int nfields = sizeof(OSQP_INFO_FIELDS) / sizeof(OSQP_INFO_FIELDS[0]);
mxArray* mxPtr = mxCreateStructMatrix(1,1,nfields,OSQP_INFO_FIELDS);

//map the OSQP_INFO fields one at a time into mxArrays
//matlab all numeric values as doubles
mxSetField(mxPtr, 0, "iter", mxCreateDoubleScalar(info->iter));
mxSetField(mxPtr, 0, "status", mxCreateString(info->status));
mxSetField(mxPtr, 0, "status_val", mxCreateDoubleScalar(info->status_val));
mxSetField(mxPtr, 0, "status_polish", mxCreateDoubleScalar(info->status_polish));
mxSetField(mxPtr, 0, "obj_val", mxCreateDoubleScalar(info->obj_val));
mxSetField(mxPtr, 0, "prim_res", mxCreateDoubleScalar(info->prim_res));
mxSetField(mxPtr, 0, "dual_res", mxCreateDoubleScalar(info->dual_res));

mxSetField(mxPtr, 0, "setup_time", mxCreateDoubleScalar(info->setup_time));
mxSetField(mxPtr, 0, "solve_time", mxCreateDoubleScalar(info->solve_time));
mxSetField(mxPtr, 0, "update_time", mxCreateDoubleScalar(info->update_time));
mxSetField(mxPtr, 0, "polish_time", mxCreateDoubleScalar(info->polish_time));
mxSetField(mxPtr, 0, "run_time", mxCreateDoubleScalar(info->run_time));


mxSetField(mxPtr, 0, "rho_updates", mxCreateDoubleScalar(info->rho_updates));
mxSetField(mxPtr, 0, "rho_estimate", mxCreateDoubleScalar(info->rho_estimate));


return mxPtr;

}
205 changes: 205 additions & 0 deletions c_sources/osqp_struct.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
#ifndef OSQP_STRUCT_H_
#define OSQP_STRUCT_H_

#include <cstring>
#include <functional>
#include <string>
#include <vector>

#include <mex.h>
#include <matrix.h>

#include "memory_matlab.h"
#include <osqp.h>

/**
* Base class used to store the field types for a struct.
*/
class OSQPStructFieldBase {
public:
OSQPStructFieldBase() {}

/**
* Set the field in the given Matlab struct to the value of this field
*/
virtual void ToMxStruct(mxArray* aStruct) = 0;

/**
* Set the field in the internal struct with the data from aStruct
*/
virtual void ToOSQPStruct(const mxArray* aStruct) = 0;
};

/**
* Class to hold a numeric struct field (e.g. float/double/int/enum, etc.).
*/
template<class T>
class OSQPStructField : public OSQPStructFieldBase {
public:
OSQPStructField(T* aStructPtr, std::string aName) :
m_structPtr(aStructPtr),
m_name(aName) {
}

void ToMxStruct(mxArray* aStruct) override {
mxAddField(aStruct, m_name.data());
mxSetField(aStruct, 0, m_name.data(), mxCreateDoubleScalar(*m_structPtr));
}

void ToOSQPStruct(const mxArray* aStruct) override {
*(m_structPtr) = static_cast<T>(mxGetScalar(mxGetField(aStruct, 0, m_name.data())));
}

private:
T* m_structPtr;
std::string m_name;
};

/**
* Class to hold a character array (actual array, not char* array) field in a struct.
*/
class OSQPStructFieldCharArray : public OSQPStructFieldBase {
public:
OSQPStructFieldCharArray(char* aStructPtr, size_t aLength, std::string aName) :
m_structPtr(aStructPtr),
m_name(aName),
m_length(aLength) {
}

void ToMxStruct(mxArray* aStruct) override {
mxAddField(aStruct, m_name.data());
mxSetField(aStruct, 0, m_name.data(), mxCreateString(m_structPtr));
}

void ToOSQPStruct(const mxArray* aStruct) override {
mxArray* tmp = mxGetField(aStruct, 0, m_name.data());
mxGetString(tmp, m_structPtr, m_length);
}

private:
char* m_structPtr;
std::string m_name;
size_t m_length;
};

/**
* Wrap a struct from OSQP to automatically transfer the data between OSQP and Matlab.
*/
template<class T>
class OSQPStructWrapper {
public:
/**
* Initialize the wrapper using the default values.
*/
OSQPStructWrapper() {
// Allocate the default struct and register field handlers
registerFields();
}

/**
* Initialize the wrapper using the values from the OSQP struct.
*/
OSQPStructWrapper(const T* aStruct) {
// Allocate the default struct and register field handlers
registerFields();
ParseOSQPStruct(aStruct);
}

/**
* Initialize the wrapper using the values from the Matlab struct
*/
OSQPStructWrapper(const mxArray* aStruct) {
// Allocate the default struct and register field handlers
registerFields();
ParseMxStruct(aStruct);
}

~OSQPStructWrapper() {
for(auto& s : m_structFields) {
delete s;
}

c_free(m_struct);
}

/**
* Return a Matlab struct populated with the values of the current struct
* contained in this wrapper.
*
* @return a Matlab struct with a copy of the struct (caller owns this copy and must free it)
*/
mxArray* GetMxStruct() {
// No fields are added right now, they are added in the for loop when they are set
mxArray* matStruct = mxCreateStructMatrix(1, 1, 0, NULL);

// Copy the current struct into the struct to return
for(const auto& s : m_structFields) {
s->ToMxStruct(matStruct);
}

return matStruct;
}

/**
* Read a Matlab struct and populate the wrapper with its values.
*/
void ParseMxStruct(const mxArray* aStruct) {
for(const auto& s : m_structFields) {
s->ToOSQPStruct(aStruct);
}
}

/**
* Get a copy of the struct contained inside this wrapper.
*
* @return a copy of the struct (caller owns this copy and must free it)
*/
T* GetOSQPStructCopy() {
// Allocate the default struct
T* ret = static_cast<T*>(c_calloc(1, sizeof(T)));

// Copy the current values for their return
std::memcpy(ret, m_struct, sizeof(T));
return ret;
}

/**
* Get the pointer to the internal struct object.
*/
T* GetOSQPStruct() {
return m_struct;
}

/*
* Read an existing OSQP struct object into this wrapper.
* The struct elements are copied, so no ownership of the aStruct pointer is transferred.
*/
void ParseOSQPStruct(const T* aStruct) {
std::memcpy(m_struct, aStruct, sizeof(T));
}

private:
/**
* Register all the fields for the wrapper.
* This function should be specialized for each struct type to map the fields appropriately.
*/
void registerFields();

// All struct fields
std::vector<OSQPStructFieldBase*> m_structFields;

// Base OSQP struct object. Owned by this wrapper.
T* m_struct;
};

/**
* Wrapper around the OSQPSettings struct
*/
typedef OSQPStructWrapper<OSQPSettings> OSQPSettingsWrapper;

/**
* Wrapper around the OSQPInfo struct
*/
typedef OSQPStructWrapper<OSQPInfo> OSQPInfoWrapper;

#endif
43 changes: 43 additions & 0 deletions c_sources/osqp_struct_info.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include <osqp.h>
#include "osqp_struct.h"


/*
* Specialization of the struct wrapper for the OSQPInfo struct.
*/
template<>
void OSQPStructWrapper<OSQPInfo>::registerFields() {
m_struct = static_cast<OSQPInfo*>(c_calloc(1, sizeof(OSQPInfo)));

if(!m_struct)
mexErrMsgTxt("Failed to allocate a OSQPInfo object.");

/*
* Register the mapping between struct field name and the info struct memory location
*/
// Solver status
m_structFields.push_back(new OSQPStructFieldCharArray(m_struct->status, 32, "status"));
m_structFields.push_back(new OSQPStructField<OSQPInt>(&m_struct->status_val, "status_val"));
m_structFields.push_back(new OSQPStructField<OSQPInt>(&m_struct->status_polish, "status_polish"));

// Solution quality
m_structFields.push_back(new OSQPStructField<OSQPFloat>(&m_struct->obj_val, "obj_val"));
m_structFields.push_back(new OSQPStructField<OSQPFloat>(&m_struct->prim_res, "prim_res"));
m_structFields.push_back(new OSQPStructField<OSQPFloat>(&m_struct->dual_res, "dual_res"));

// Algorithm information
m_structFields.push_back(new OSQPStructField<OSQPInt>(&m_struct->iter, "iter"));
m_structFields.push_back(new OSQPStructField<OSQPInt>(&m_struct->rho_updates, "rho_updates"));
m_structFields.push_back(new OSQPStructField<OSQPFloat>(&m_struct->rho_estimate, "rho_estimate"));

// Timing information
m_structFields.push_back(new OSQPStructField<OSQPFloat>(&m_struct->setup_time, "setup_time"));
m_structFields.push_back(new OSQPStructField<OSQPFloat>(&m_struct->solve_time, "solve_time"));
m_structFields.push_back(new OSQPStructField<OSQPFloat>(&m_struct->update_time, "update_time"));
m_structFields.push_back(new OSQPStructField<OSQPFloat>(&m_struct->polish_time, "polish_time"));
m_structFields.push_back(new OSQPStructField<OSQPFloat>(&m_struct->run_time, "run_time"));
}


// Instantiate the OSQPInfo wrapper class
template class OSQPStructWrapper<OSQPInfo>;
Loading

0 comments on commit 5e42763

Please sign in to comment.