diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 128f361..317b966 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,7 +32,7 @@ jobs: - name: Build OSQP interface uses: matlab-actions/run-command@v1 with: - command: make_osqp + command: osqp.build('osqp_mex') - name: Run tests uses: matlab-actions/run-tests@v1 diff --git a/@osqp/build.m b/@osqp/build.m new file mode 100644 index 0000000..0d13215 --- /dev/null +++ b/@osqp/build.m @@ -0,0 +1,164 @@ +function build(varargin) +% Matlab MEX makefile for OSQP. +% +% MAKE_OSQP(VARARGIN) is a make file for OSQP solver. It +% builds OSQP and its components from source. +% +% WHAT is the last element of VARARGIN and cell array of strings, +% with the following options: +% +% {}, '' (empty string) or 'all': build all components and link. +% +% 'osqp_mex': builds the OSQP mex interface and the OSQP library +% +% Additional commands: +% +% 'clean': Delete all compiled files +% 'purge': Delete all compiled files and copied code generation files + + if( nargin == 0 ) + what = {'all'}; + verbose = false; + elseif ( nargin == 1 && ismember('-verbose', varargin) ) + what = {'all'}; + verbose = true; + else + what = varargin{nargin}; + if(isempty(strfind(what, 'all')) && ... + isempty(strfind(what, 'osqp_mex')) && ... + isempty(strfind(what, 'clean')) && ... + isempty(strfind(what, 'purge'))) + fprintf('No rule to make target "%s", exiting.\n', what); + end + + verbose = ismember('-verbose', varargin); + end + + %% Determine where the various files are all located + % Various parts of the build system + [osqp_classpath,~,~] = fileparts( mfilename( 'fullpath' ) ); + osqp_mex_src_dir = fullfile( osqp_classpath, '..', 'c_sources' ); + osqp_mex_build_dir = fullfile( osqp_mex_src_dir, 'build' ); + osqp_cg_src_dir = fullfile( osqp_mex_build_dir, 'codegen_src' ); + osqp_cg_dest_dir = fullfile( osqp_classpath, '..', 'codegen', 'sources' ); + + % Determine where CMake should look for MATLAB + Matlab_ROOT = strrep( matlabroot, '\', '/' ); + + %% Try to unlock any pre-existing version of osqp_mex + % this prevents compile errors if a user builds, runs osqp + % and then tries to recompile + if(mislocked('osqp_mex')) + munlock('osqp_mex'); + end + + %% Configure, build and install the OSQP mex interface + if( any(strcmpi(what,'osqp_mex')) || any(strcmpi(what,'all')) ) + fprintf('Compiling OSQP solver mex interface...\n'); + + % Create build for the mex file and go inside + if exist( osqp_mex_build_dir, 'dir' ) + rmdir( osqp_mex_build_dir, 's' ); + end + mkdir( osqp_mex_build_dir ); + % cd( osqp_mex_build_dir ); + + % Extend path for CMake mac (via Homebrew) + PATH = getenv('PATH'); + if( (ismac) && (isempty(strfind(PATH, '/usr/local/bin'))) ) + setenv('PATH', [PATH ':/usr/local/bin']); + end + + + + %% Configure CMake for the mex interface + fprintf(' Configuring...' ) + [status, output] = system( sprintf( 'cmake -B %s -S %s -DCMAKE_BUILD_TYPE=RelWithDebInfo -DMatlab_ROOT_DIR=\"%s\"', osqp_mex_build_dir, osqp_mex_src_dir, Matlab_ROOT ), 'LD_LIBRARY_PATH', '' ); + if( status ) + fprintf( '\n' ); + disp( output ); + error( 'Error configuring CMake environment' ); + elseif( verbose ) + fprintf( '\n' ); + disp( output ); + else + fprintf( '\t\t\t\t\t[done]\n' ); + end + + %% Build the mex interface + fprintf( ' Building...') + [status, output] = system( sprintf( 'cmake --build %s --config Release', osqp_mex_build_dir ), 'LD_LIBRARY_PATH', '' ); + if( status ) + fprintf( '\n' ); + disp( output ); + error( 'Error compiling OSQP mex interface' ); + elseif( verbose ) + fprintf( '\n' ); + disp( output ); + else + fprintf( '\t\t\t\t\t\t[done]\n' ); + end + + + %% Install various files + fprintf( ' Installing...' ) + + % Copy mex file to root directory for use + if( ispc ) + [err, errmsg, ~] = copyfile( [osqp_mex_build_dir, filesep, 'Release', filesep, 'osqp_mex.mex*'], [osqp_classpath, filesep, 'private'] ); + else + [err, errmsg, ~] = copyfile( [osqp_mex_build_dir, filesep, 'osqp_mex.mex*'], [osqp_classpath, filesep, 'private'] ); + end + if( ~err ) + fprintf( '\n' ) + disp( errmsg ) + error( ' Error copying mex file' ) + end + + % Copy the code generation source files + % Create build for the mex file and go inside + if exist( osqp_cg_dest_dir, 'dir' ) + rmdir( osqp_cg_dest_dir, 's' ); + end + mkdir( osqp_cg_dest_dir ); + + [err, errmsg, ~] = copyfile( [osqp_cg_src_dir, filesep, '*'], osqp_cg_dest_dir ); + if( ~err ) + fprintf( '\n' ) + disp( errmsg ) + error( ' Error copying code generation source files' ) + end + + fprintf( '\t\t\t\t\t\t[done]\n' ); + end + + %% Clean and purge + if( any(strcmpi(what,'clean')) || any(strcmpi(what,'purge')) ) + fprintf('Cleaning OSQP mex files and build directory...'); + + % Delete mex file + mexfiles = dir(['*.', mexext]); + for i = 1 : length(mexfiles) + delete(mexfiles(i).name); + end + + % Delete OSQP build directory + if exist(osqp_mex_build_dir, 'dir') + rmdir(osqp_mex_build_dir, 's'); + end + + fprintf('\t\t[done]\n'); + + %% Purge only + if( any(strcmpi(what,'purge')) ) + fprintf('Cleaning OSQP codegen directories...'); + + % Delete codegen files + if exist(osqp_cg_dest_dir, 'dir') + rmdir(osqp_cg_dest_dir, 's'); + end + + fprintf('\t\t\t[done]\n'); + end + end +end diff --git a/@osqp/codegen.m b/@osqp/codegen.m new file mode 100644 index 0000000..d753762 --- /dev/null +++ b/@osqp/codegen.m @@ -0,0 +1,208 @@ +%% +function codegen(this, target_dir, varargin) + % CODEGEN generate C code for the parametric problem + % + % codegen(target_dir,options) + + % Parse input arguments + p = inputParser; + defaultProject = ''; + expectedProject = {'', 'Makefile', 'MinGW Makefiles', 'Unix Makefiles', 'CodeBlocks', 'Xcode'}; + defaultParams = 'vectors'; + expectedParams = {'vectors', 'matrices'}; + defaultMexname = 'emosqp'; + defaultFloat = false; + defaultLong = true; + defaultFW = false; + + addRequired(p, 'target_dir', @isstr); + addParameter(p, 'project_type', defaultProject, ... + @(x) ischar(validatestring(x, expectedProject))); + addParameter(p, 'parameters', defaultParams, ... + @(x) ischar(validatestring(x, expectedParams))); + addParameter(p, 'mexname', defaultMexname, @isstr); + addParameter(p, 'FLOAT', defaultFloat, @islogical); + addParameter(p, 'LONG', defaultLong, @islogical); + addParameter(p, 'force_rewrite', defaultFW, @islogical); + + parse(p, target_dir, varargin{:}); + + % Set internal variables + if strcmp(p.Results.parameters, 'vectors') + embedded = 1; + else + embedded = 2; + end + if p.Results.FLOAT + float_flag = 'ON'; + else + float_flag = 'OFF'; + end + if p.Results.LONG + long_flag = 'ON'; + else + long_flag = 'OFF'; + end + if strcmp(p.Results.project_type, 'Makefile') + if (ispc) + project_type = 'MinGW Makefiles'; % Windows + elseif (ismac || isunix) + project_type = 'Unix Makefiles'; % Unix + end + else + project_type = p.Results.project_type; + end + + % Check whether the specified directory already exists + if exist(target_dir, 'dir') + if p.Results.force_rewrite + rmdir(target_dir, 's'); + else + while(1) + prompt = sprintf('Directory "%s" already exists. Do you want to replace it? y/n [y]: ', target_dir); + str = input(prompt, 's'); + + if any(strcmpi(str, {'','y'})) + rmdir(target_dir, 's'); + break; + elseif strcmpi(str, 'n') + return; + end + end + end + end + + % Import OSQP path + [osqp_path,~,~] = fileparts(which('osqp.m')); + + % Add codegen directory to path + addpath(fullfile(osqp_path, 'codegen')); + + % Path of osqp module + cg_dir = fullfile(osqp_path, 'codegen'); + files_to_generate_path = fullfile(cg_dir, 'files_to_generate'); + + % Get workspace structure + work = osqp_mex('get_workspace', this.objectHandle); + + % Make target directory + fprintf('Creating target directories...\t\t\t\t\t'); + target_configure_dir = fullfile(target_dir, 'configure'); + target_include_dir = fullfile(target_dir, 'include'); + target_src_dir = fullfile(target_dir, 'src'); + + if ~exist(target_dir, 'dir') + mkdir(target_dir); + end + if ~exist(target_configure_dir, 'dir') + mkdir(target_configure_dir); + end + if ~exist(target_include_dir, 'dir') + mkdir(target_include_dir); + end + if ~exist(target_src_dir, 'dir') + mkdir(fullfile(target_src_dir, 'osqp')); + end + fprintf('[done]\n'); + + % Copy source files to target directory + fprintf('Copying OSQP source files...\t\t\t\t\t'); + cdir = fullfile(cg_dir, 'sources', 'src'); + cfiles = dir(fullfile(cdir, '*.c')); + for i = 1 : length(cfiles) + if embedded == 1 + % Do not copy kkt.c if embedded is 1 + if ~strcmp(cfiles(i).name, 'kkt.c') + copyfile(fullfile(cdir, cfiles(i).name), ... + fullfile(target_src_dir, 'osqp', cfiles(i).name)); + end + else + copyfile(fullfile(cdir, cfiles(i).name), ... + fullfile(target_src_dir, 'osqp', cfiles(i).name)); + end + end + configure_dir = fullfile(cg_dir, 'sources', 'configure'); + configure_files = dir(fullfile(configure_dir, '*.h.in')); + for i = 1 : length(configure_files) + copyfile(fullfile(configure_dir, configure_files(i).name), ... + fullfile(target_configure_dir, configure_files(i).name)); + end + hdir = fullfile(cg_dir, 'sources', 'include'); + hfiles = dir(fullfile(hdir, '*.h')); + for i = 1 : length(hfiles) + if embedded == 1 + % Do not copy kkt.h if embedded is 1 + if ~strcmp(hfiles(i).name, 'kkt.h') + copyfile(fullfile(hdir, hfiles(i).name), ... + fullfile(target_include_dir, hfiles(i).name)); + end + else + copyfile(fullfile(hdir, hfiles(i).name), ... + fullfile(target_include_dir, hfiles(i).name)); + end + end + + % Copy cmake files + copyfile(fullfile(cdir, 'CMakeLists.txt'), ... + fullfile(target_src_dir, 'osqp', 'CMakeLists.txt')); + copyfile(fullfile(hdir, 'CMakeLists.txt'), ... + fullfile(target_include_dir, 'CMakeLists.txt')); + fprintf('[done]\n'); + + % Copy example.c + copyfile(fullfile(files_to_generate_path, 'example.c'), target_src_dir); + + % Render CMakeLists.txt + fidi = fopen(fullfile(files_to_generate_path, 'CMakeLists.txt'),'r'); + fido = fopen(fullfile(target_dir, 'CMakeLists.txt'),'w'); + while ~feof(fidi) + l = fgetl(fidi); % read line + % Replace EMBEDDED_FLAG in CMakeLists.txt by a numerical value + newl = strrep(l, 'EMBEDDED_FLAG', num2str(embedded)); + fprintf(fido, '%s\n', newl); + end + fclose(fidi); + fclose(fido); + + % Render workspace.h and workspace.c + work_hfile = fullfile(target_include_dir, 'workspace.h'); + work_cfile = fullfile(target_src_dir, 'osqp', 'workspace.c'); + fprintf('Generating workspace.h/.c...\t\t\t\t\t\t'); + render_workspace(work, work_hfile, work_cfile, embedded); + fprintf('[done]\n'); + + % Create project + if ~isempty(project_type) + + % Extend path for CMake mac (via Homebrew) + PATH = getenv('PATH'); + if ((ismac) && (isempty(strfind(PATH, '/usr/local/bin')))) + setenv('PATH', [PATH ':/usr/local/bin']); + end + + fprintf('Creating project...\t\t\t\t\t\t\t\t'); + orig_dir = pwd; + cd(target_dir); + mkdir('build') + cd('build'); + cmd = sprintf('cmake -G "%s" ..', project_type); + [status, output] = system(cmd); + if(status) + fprintf('\n'); + fprintf(output); + error('Error configuring CMake environment'); + else + fprintf('[done]\n'); + end + cd(orig_dir); + end + + % Make mex interface to the generated code + mex_cfile = fullfile(files_to_generate_path, 'emosqp_mex.c'); + make_emosqp(target_dir, mex_cfile, embedded, float_flag, long_flag); + + % Rename the mex file + old_mexfile = ['emosqp_mex.', mexext]; + new_mexfile = [p.Results.mexname, '.', mexext]; + movefile(old_mexfile, new_mexfile); +end \ No newline at end of file diff --git a/@osqp/osqp.m b/@osqp/osqp.m new file mode 100644 index 0000000..cc58bcd --- /dev/null +++ b/@osqp/osqp.m @@ -0,0 +1,90 @@ +classdef osqp < handle + % osqp interface class for OSQP solver + % This class provides a complete interface to the C implementation + % of the OSQP solver. + % + % osqp Properties: + % objectHandle - pointer to the C structure of OSQP solver + % + % osqp Methods: + % + % setup - configure solver with problem data + % solve - solve the QP + % update - modify problem vectors + % warm_start - set warm starting variables x and y + % + % default_settings - create default settings structure + % current_settings - get the current solver settings structure + % update_settings - update the current solver settings structure + % + % get_dimensions - get the number of variables and constraints + % version - return OSQP version + % constant - return a OSQP internal constant + % + % codegen - generate embeddable C code for the problem + + + properties(SetAccess = private, Hidden = true) + objectHandle % Handle to underlying C instance + end + + methods(Static) + output = build(varargin) + + %% + function out = default_settings() + % DEFAULT_SETTINGS get the default solver settings structure + out = osqp_mex('default_settings'); + + % Convert linsys solver to string + out.linsys_solver = linsys_solver_to_string(out.linsys_solver); + end + + %% + function out = constant(constant_name) + % CONSTANT Return solver constant + % C = CONSTANT(CONSTANT_NAME) return constant called CONSTANT_NAME + out = osqp_mex('constant', constant_name); + end + + %% + function out = version() + % Return OSQP version + out = osqp_mex('version'); + end + end + + methods(Access = private) + currentSettings = validate_settings(this, isInitialization, varargin) + end + + methods + %% Constructor - Create a new solver instance + function this = osqp(varargin) + % Construct OSQP solver class + this.objectHandle = osqp_mex('new', varargin{:}); + end + + %% Destructor - destroy the solver instance + function delete(this) + % Destroy OSQP solver class + osqp_mex('delete', this.objectHandle); + end + + %% + function out = current_settings(this) + % CURRENT_SETTINGS get the current solver settings structure + out = osqp_mex('current_settings', this.objectHandle); + + % Convert linsys solver to string + out.linsys_solver = linsys_solver_to_string(out.linsys_solver); + end + + %% + function [n,m] = get_dimensions(this) + % GET_DIMENSIONS get the number of variables and constraints + + [n,m] = osqp_mex('get_dimensions', this.objectHandle); + end + end +end \ No newline at end of file diff --git a/@osqp/private/linsys_solver_to_string.m b/@osqp/private/linsys_solver_to_string.m new file mode 100644 index 0000000..bac8047 --- /dev/null +++ b/@osqp/private/linsys_solver_to_string.m @@ -0,0 +1,13 @@ +% Convert linear systme solver integer to string +function [linsys_solver_string] = linsys_solver_to_string(linsys_solver) + switch linsys_solver + case osqp.constant('OSQP_UNKNOWN_SOLVER') + linsys_solver_string = 'unknown solver'; + case osqp.constant('OSQP_DIRECT_SOLVER') + linsys_solver_string = 'direct solver'; + case osqp.constant('OSQP_INDIRECT_SOLVER') + linsys_solver_string = 'indirect solver'; + otherwise + error('Unrecognized linear system solver.'); + end +end diff --git a/@osqp/private/string_to_linsys_solver.m b/@osqp/private/string_to_linsys_solver.m new file mode 100644 index 0000000..51e5c99 --- /dev/null +++ b/@osqp/private/string_to_linsys_solver.m @@ -0,0 +1,18 @@ +function [linsys_solver] = string_to_linsys_solver(linsys_solver_string) + linsys_solver_string = lower(linsys_solver_string); + switch linsys_solver_string + case 'unknown solver' + linsys_solver = osqp.constant('OSQP_UNKNOWN_SOLVER'); + case 'direct solver' + linsys_solver = osqp.constant('OSQP_DIRECT_SOLVER'); + case 'indirect solver' + linsys_solver = osqp.constant('OSQP_INDIRECT_SOLVER'); + % Default solver: QDLDL + case '' + linsys_solver = osqp.constant('OSQP_DIRECT_SOLVER'); + otherwise + warning('Linear system solver not recognized. Using default solver OSQP_DIRECT_SOLVER.') + linsys_solver = osqp.constant('OSQP_DIRECT_SOLVER'); + end +end + \ No newline at end of file diff --git a/@osqp/setup.m b/@osqp/setup.m new file mode 100644 index 0000000..5317e2d --- /dev/null +++ b/@osqp/setup.m @@ -0,0 +1,101 @@ +%% +function varargout = setup(this, varargin) + % SETUP configure solver with problem data + % + % setup(P,q,A,l,u,options) + + nargin = length(varargin); + + %dimension checks on user data. Mex function does not + %perform any checks on inputs, so check everything here + assert(nargin >= 5, 'incorrect number of inputs'); + [P,q,A,l,u] = deal(varargin{1:5}); + + % + % Get problem dimensions + % + + % Get number of variables n + if (~isempty(P)) + n = size(P, 1); + elseif (~isempty(q)) + n = length(q); + elseif (~isempty(A)) + n = size(A, 2); + else + error('The problem does not have any variables'); + end + + % Get number of constraints m + if (isempty(A)) + m = 0; + else + m = size(A, 1); + assert(size(A, 2) == n, 'Incorrect dimension of A'); + end + + % + % Create sparse matrices and full vectors if they are empty + % + + if (isempty(P)) + P = sparse(n, n); + else + P = sparse(P); + end + if (~istriu(P)) + P = triu(P); + end + if (isempty(q)) + q = zeros(n, 1); + else + q = full(q(:)); + end + + % Create proper constraints if they are not passed + if (isempty(A) && (~isempty(l) || ~isempty(u))) || ... + (~isempty(A) && (isempty(l) && isempty(u))) + error('A must be supplied together with at least one bound l or u'); + end + + if (~isempty(A) && isempty(l)) + l = -Inf(m, 1); + end + + if (~isempty(A) && isempty(u)) + u = Inf(m, 1); + end + + if (isempty(A)) + A = sparse(m, n); + l = -Inf(m, 1); + u = Inf(m, 1); + else + l = full(l(:)); + u = full(u(:)); + A = sparse(A); + end + + + % + % Check vector dimensions (not checked from the C solver) + % + assert(length(q) == n, 'Incorrect dimension of q'); + assert(length(l) == m, 'Incorrect dimension of l'); + assert(length(u) == m, 'Incorrect dimension of u'); + + % + % Convert infinity values to OSQP_INFINITY + % + u = min(u, osqp.constant('OSQP_INFTY')); + l = max(l, -osqp.constant('OSQP_INFTY')); + + + %make a settings structure from the remainder of the arguments. + %'true' means that this is a settings initialization, so all + %parameter/values are allowed. No extra inputs will result + %in default settings being passed back + theSettings = validate_settings(this,true,varargin{6:end}); + + [varargout{1:nargout}] = osqp_mex('setup', this.objectHandle, n,m,P,q,A,l,u,theSettings); +end \ No newline at end of file diff --git a/@osqp/solve.m b/@osqp/solve.m new file mode 100644 index 0000000..2ce96cd --- /dev/null +++ b/@osqp/solve.m @@ -0,0 +1,11 @@ +%% +function varargout = solve(this, varargin) + % SOLVE solve the QP + + nargoutchk(0,1); %either return nothing (but still solve), or a single output structure + [out.x, out.y, out.prim_inf_cert, out.dual_inf_cert, out.info] = osqp_mex('solve', this.objectHandle); + if(nargout) + varargout{1} = out; + end + return; +end \ No newline at end of file diff --git a/@osqp/update.m b/@osqp/update.m new file mode 100644 index 0000000..4f27b6e --- /dev/null +++ b/@osqp/update.m @@ -0,0 +1,69 @@ +%% +function update(this,varargin) + % UPDATE modify the linear cost term and/or lower and upper bounds + + %second input 'false' means that this is *not* a settings + %initialization, so some parameter/values will be disallowed + allowedFields = {'q','l','u','Px','Px_idx','Ax','Ax_idx'}; + + if(isempty(varargin)) + return; + elseif(length(varargin) == 1) + if(~isstruct(varargin{1})) + error('Single input should be a structure with new problem data'); + else + newData = varargin{1}; + end + else % param / value style assumed + newData = struct(varargin{:}); + end + + %check for unknown fields + newFields = fieldnames(newData); + badFieldsIdx = find(~ismember(newFields,allowedFields)); + if(~isempty(badFieldsIdx)) + error('Unrecognized input field ''%s'' detected',newFields{badFieldsIdx(1)}); + end + + %get all of the terms. Nonexistent fields will be passed + %as empty mxArrays + try q = double(full(newData.q(:))); catch q = []; end + try l = double(full(newData.l(:))); catch l = []; end + try u = double(full(newData.u(:))); catch u = []; end + try Px = double(full(newData.Px(:))); catch Px = []; end + try Px_idx = double(full(newData.Px_idx(:))); catch Px_idx = []; end + try Ax = double(full(newData.Ax(:))); catch Ax = []; end + try Ax_idx = double(full(newData.Ax_idx(:))); catch Ax_idx = []; end + + [n,m] = get_dimensions(this); + + assert(isempty(q) || length(q) == n, 'input ''q'' is the wrong size'); + assert(isempty(l) || length(l) == m, 'input ''u'' is the wrong size'); + assert(isempty(u) || length(u) == m, 'input ''l'' is the wrong size'); + assert(isempty(Px) || isempty(Px_idx) || length(Px) == length(Px_idx), ... + 'inputs ''Px'' and ''Px_idx'' must be the same size'); + assert(isempty(Ax) || isempty(Ax_idx) || length(Ax) == length(Ax_idx), ... + 'inputs ''Ax'' and ''Ax_idx'' must be the same size'); + + % Adjust index of Px_idx and Ax_idx to match 0-based indexing + % in C + if (~isempty(Px_idx)) + Px_idx = Px_idx - 1; + end + if (~isempty(Ax_idx)) + Ax_idx = Ax_idx - 1; + end + + % Convert infinity values to OSQP_INFTY + if (~isempty(u)) + u = min(u, osqp.constant('OSQP_INFTY')); + end + if (~isempty(l)) + l = max(l, -osqp.constant('OSQP_INFTY')); + end + + %write the new problem data. C-mex does not protect + %against unknown fields, but will handle empty values + osqp_mex('update', this.objectHandle, ... + q, l, u, Px, Px_idx, length(Px), Ax, Ax_idx, length(Ax)); +end \ No newline at end of file diff --git a/@osqp/update_settings.m b/@osqp/update_settings.m new file mode 100644 index 0000000..233710d --- /dev/null +++ b/@osqp/update_settings.m @@ -0,0 +1,25 @@ +function update_settings(this, varargin) + % UPDATE_SETTINGS update the current solver settings structure + + % Check for structure style input + if(isstruct(varargin{1})) + newSettings = varargin{1}; + assert(length(varargin) == 1, 'too many input arguments'); + else + newSettings = struct(varargin{:}); + end + + % Rho update must be handled special + if( isfield(newSettings, 'rho') ) + osqp_mex('update_rho', this.objectHandle, newSettings.rho); + newSettings = rmfield(newSettings, 'rho'); + end + + % Second input 'false' means that this is *not* a settings + % initialization, so some parameter/values will be disallowed + newSettings = validate_settings(this, false, varargin{:}); + + % Write the solver settings. C-mex does not check input + % data or protect against disallowed parameter modifications + osqp_mex('update_settings', this.objectHandle, newSettings); +end \ No newline at end of file diff --git a/@osqp/validate_settings.m b/@osqp/validate_settings.m new file mode 100644 index 0000000..3ce7d6d --- /dev/null +++ b/@osqp/validate_settings.m @@ -0,0 +1,71 @@ +function currentSettings = validate_settings(this, isInitialization, varargin) + % Don't allow these fields to be changed + unmodifiableFields = {'scaling', 'linsys_solver'}; + + % Get the current settings + if(isInitialization) + currentSettings = osqp_mex('default_settings', this.objectHandle); + else + currentSettings = osqp_mex('current_settings', this.objectHandle); + end + + % No settings passed -> return defaults + if(isempty(varargin)) + return; + end + + % Check for structure style input + if(isstruct(varargin{1})) + newSettings = varargin{1}; + assert(length(varargin) == 1, 'too many input arguments'); + else + newSettings = struct(varargin{:}); + end + + % Get the osqp settings fields + currentFields = fieldnames(currentSettings); + + % Get the requested fields in the update + newFields = fieldnames(newSettings); + + % Check for unknown parameters + badFieldsIdx = find(~ismember(newFields,currentFields)); + if(~isempty(badFieldsIdx)) + error('Unrecognized solver setting ''%s'' detected',newFields{badFieldsIdx(1)}); + end + + % Convert linsys_solver string to integer + if ismember('linsys_solver',newFields) + if ~ischar(newSettings.linsys_solver) + error('Setting linsys_solver is required to be a string.'); + end + % Convert linsys_solver to number + newSettings.linsys_solver = string_to_linsys_solver(newSettings.linsys_solver); + end + + + % Check for disallowed fields if this in not an initialization call + if(~isInitialization) + badFieldsIdx = find(ismember(newFields,unmodifiableFields)); + for i = badFieldsIdx(:)' + if(~isequal(newSettings.(newFields{i}),currentSettings.(newFields{i}))) + error('Solver setting ''%s'' can only be changed at solver initialization.', newFields{i}); + end + end + end + + + % Check that everything is a nonnegative scalar (this check is already + % performed in C) + % for i = 1:length(newFields) + % val = double(newSettings.(newFields{i})); + % assert(isscalar(val) & isnumeric(val) & val >= 0, ... + % 'Solver setting ''%s'' not specified as nonnegative scalar', newFields{i}); + % end + + % Everything checks out - merge the newSettings into the current ones + for i = 1:length(newFields) + currentSettings.(newFields{i}) = double(newSettings.(newFields{i})); + end +end + \ No newline at end of file diff --git a/@osqp/warm_start.m b/@osqp/warm_start.m new file mode 100644 index 0000000..81d56ae --- /dev/null +++ b/@osqp/warm_start.m @@ -0,0 +1,50 @@ +function warm_start(this, varargin) + % WARM_START warm start primal and/or dual variables + % + % warm_start('x', x, 'y', y) + % + % or warm_start('x', x) + % or warm_start('y', y) + + + % Get problem dimensions + [n, m] = get_dimensions(this); + + % Get data + allowedFields = {'x','y'}; + + if(isempty(varargin)) + return; + elseif(length(varargin) == 1) + if(~isstruct(varargin{1})) + error('Single input should be a structure with new problem data'); + else + newData = varargin{1}; + end + else % param / value style assumed + newData = struct(varargin{:}); + end + + %check for unknown fields + newFields = fieldnames(newData); + badFieldsIdx = find(~ismember(newFields,allowedFields)); + if(~isempty(badFieldsIdx)) + error('Unrecognized input field ''%s'' detected',newFields{badFieldsIdx(1)}); + end + + %get all of the terms. Nonexistent fields will be passed + %as empty mxArrays + try x = double(full(newData.x(:))); catch x = []; end + try y = double(full(newData.y(:))); catch y = []; end + + % Check dimensions + assert(isempty(x) || length(x) == n, 'input ''x'' is the wrong size'); + assert(isempty(y) || length(y) == m, 'input ''y'' is the wrong size'); + + % Only call when there is a vector to update + if (~isempty(x) || ~isempty(y)) + osqp_mex('warm_start', this.objectHandle, x, y); + else + error('Unrecognized fields'); + end +end \ No newline at end of file diff --git a/c_sources/CMakeLists.txt b/c_sources/CMakeLists.txt index f35ceaa..5c42782 100644 --- a/c_sources/CMakeLists.txt +++ b/c_sources/CMakeLists.txt @@ -10,6 +10,10 @@ message( STATUS "Matlab root is " ${Matlab_ROOT_DIR} ) include_directories( ${Matlab_INCLUDE_DIRS} ) +# The mex interface uses C++11 +set( CMAKE_CXX_STANDARD 11 ) +set( CMAKE_CXX_STANDARD_REQUIRED ON ) + if( CMAKE_COMPILER_IS_GNUCXX ) # Add debug symbols and optimizations to the build set( CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -g -O2" ) @@ -62,6 +66,9 @@ endif() matlab_add_mex( NAME osqp_mex SRC ${CMAKE_CURRENT_SOURCE_DIR}/osqp_mex.cpp ${CMAKE_CURRENT_SOURCE_DIR}/interrupt_matlab.c + ${CMAKE_CURRENT_SOURCE_DIR}/memory_matlab.c + ${CMAKE_CURRENT_SOURCE_DIR}/osqp_struct_info.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/osqp_struct_settings.cpp LINK_TO osqpstatic ${UT_LIBRARY} # Force compilation in the traditional C API (equivalent to the -R2017b flag) diff --git a/c_sources/arrays_matlab.h b/c_sources/arrays_matlab.h new file mode 100644 index 0000000..0d5408b --- /dev/null +++ b/c_sources/arrays_matlab.h @@ -0,0 +1,43 @@ +#ifndef ARRAYS_MATLAB_H_ +#define ARRAYS_MATLAB_H_ + +#include + +/** + * Copy the data from one array to another provided array. + */ +template +void copyVector(outArr* out, inArr* in, OSQPInt numel) { + // Don't bother doing anything if there is no input data + if(!in || !out || (numel == 0)) + return; + + // Copy the data + for(OSQPInt i=0; i < numel; i++){ + out[i] = static_cast(in[i]); + } +} + + +/** + * Copy the data from one array to another newly allocated array. + * The caller gains ownership of the returned array. + */ +template +outArr* cloneVector(inArr* in, OSQPInt numel) { + // Don't bother doing anything if there is no input data + if(!in || (numel == 0)) + return NULL; + + // Allocate new array + outArr* out = static_cast(c_malloc(numel * sizeof(outArr))); + + if(!out) + mexErrMsgTxt("Failed to allocate a vector object."); + + // Copy the data + copyVector(out, in, numel); + return out; +} + +#endif \ No newline at end of file diff --git a/c_sources/memory_matlab.c b/c_sources/memory_matlab.c new file mode 100644 index 0000000..c3b06a8 --- /dev/null +++ b/c_sources/memory_matlab.c @@ -0,0 +1,19 @@ +#include + +void* c_calloc(size_t num, size_t size) { + void *m = mxCalloc(num, size); + mexMakeMemoryPersistent(m); + return m; +} + +void* c_malloc(size_t size) { + void *m = mxMalloc(size); + mexMakeMemoryPersistent(m); + return m; +} + +void* c_realloc(void *ptr, size_t size) { + void *m = mxRealloc(ptr, size); + mexMakeMemoryPersistent(m); + return m; +} \ No newline at end of file diff --git a/c_sources/memory_matlab.h b/c_sources/memory_matlab.h index c8276b8..abd48b6 100644 --- a/c_sources/memory_matlab.h +++ b/c_sources/memory_matlab.h @@ -1,22 +1,17 @@ /* Memory managment for MATLAB */ #include "mex.h" -static void* c_calloc(size_t num, size_t size) { - void *m = mxCalloc(num, size); - mexMakeMemoryPersistent(m); - return m; -} -static void* c_malloc(size_t size) { - void *m = mxMalloc(size); - mexMakeMemoryPersistent(m); - return m; -} +#ifdef __cplusplus +extern "C" { +#endif + + void* c_calloc(size_t num, size_t size); + void* c_malloc(size_t size); + void* c_realloc(void *ptr, size_t size); -static void* c_realloc(void *ptr, size_t size) { - void *m = mxRealloc(ptr, size); - mexMakeMemoryPersistent(m); - return m; +#ifdef __cplusplus } +#endif #define c_free mxFree \ No newline at end of file diff --git a/c_sources/osqp_mex.cpp b/c_sources/osqp_mex.cpp index 5b0b93b..05f82b3 100755 --- a/c_sources/osqp_mex.cpp +++ b/c_sources/osqp_mex.cpp @@ -1,101 +1,57 @@ +#include + #include "mex.h" #include "matrix.h" -#include "osqp_mex.hpp" #include "osqp.h" -#include "memory_matlab.h" -//c_int is replaced with OSQPInt -//c_float is replaced with OSQPFloat +// Mex-specific functionality +#include "osqp_mex.hpp" +#include "osqp_struct.h" +#include "arrays_matlab.h" +#include "memory_matlab.h" //TODO: Check if this definition is required, and maybe replace it with: // enum linsys_solver_type { QDLDL_SOLVER, MKL_PARDISO_SOLVER }; #define QDLDL_SOLVER 0 //Based on the previous API -// all of the OSQP_INFO fieldnames as strings -const char* OSQP_INFO_FIELDS[] = {"status", //char* - "status_val", //OSQPInt - "status_polish", //OSQPInt - "obj_val", //OSQPFloat - "prim_res", //OSQPFloat - "dual_res", //OSQPFloat - "iter", //OSQPInt - "rho_updates", //OSQPInt - "rho_estimate", //OSQPFloat - "setup_time", //OSQPFloat - "solve_time", //OSQPFloat - "update_time", //OSQPFloat - "polish_time", //OSQPFloat - "run_time", //OSQPFloat - }; - -const char* OSQP_SETTINGS_FIELDS[] = {"device", //OSQPInt - "linsys_solver", //enum osqp_linsys_solver_type - "verbose", //OSQPInt - "warm_starting", //OSQPInt - "scaling", //OSQPInt - "polishing", //OSQPInt - "rho", //OSQPFloat - "rho_is_vec", //OSQPInt - "sigma", //OSQPFloat - "alpha", //OSQPFloat - "cg_max_iter", //OSQPInt - "cg_tol_reduction", //OSQPInt - "cg_tol_fraction", //OSQPFloat - "cg_precond", //osqp_precond_type - "adaptive_rho", //OSQPInt - "adaptive_rho_interval", //OSQPInt - "adaptive_rho_fraction", //OSQPFloat - "adaptive_rho_tolerance", //OSQPFloat - "max_iter", //OSQPInt - "eps_abs", //OSQPFloat - "eps_rel", //OSQPFloat - "eps_prim_inf", //OSQPFloat - "eps_dual_inf", //OSQPFloat - "scaled_termination", //OSQPInt - "check_termination", //OSQPInt - "time_limit", //OSQPFloat - "delta", //OSQPFloat - "polish_refine_iter", //OSQPInt - }; - -#define NEW_SETTINGS_TOL (1e-10) - -// wrapper class for all osqp data and settings +// Wrapper class to pass the OSQP solver back and forth with Matlab class OsqpData { public: - OsqpData() : solver(NULL){} - OSQPSolver * solver; + OsqpData() : + solver(NULL) + {} + OSQPSolver* solver; }; -// internal utility functions -OSQPSolver* initializeOSQPSolver(); -void castToDoubleArr(OSQPFloat *arr, double* arr_out, OSQPInt len); -void setToNaN(double* arr_out, OSQPInt len); -void copyMxStructToSettings(const mxArray*, OSQPSettings*); -void copyUpdatedSettingsToWork(const mxArray*, OSQPSolver*); -//void castCintToDoubleArr(OSQPInt *arr, double* arr_out, OSQPInt len); //DELETE HERE -void freeCscMatrix(OSQPCscMatrix* M); -OSQPInt* copyToOSQPIntVector(mwIndex * vecData, OSQPInt numel); -OSQPInt* copyDoubleToOSQPIntVector(double* vecData, OSQPInt numel); -OSQPFloat* copyToOSQPFloatVector(double * vecData, OSQPInt numel); -mxArray* copyInfoToMxStruct(OSQPInfo* info); -mxArray* copySettingsToMxStruct(OSQPSettings* settings); + +// Internal utility function +static void setToNaN(double* arr_out, OSQPInt len){ + OSQPInt i; + for (i = 0; i < len; i++) { + arr_out[i] = mxGetNaN(); + } +} +// Main mex function void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) -{ +{ + // OSQP solver wrapper OsqpData* osqpData; - //OSQPSolver* osqpSolver = NULL; + // Exitflag OSQPInt exitflag = 0; - // Static string for static methods - char stat_string[64]; + // Get the command string char cmd[64]; - if (nrhs < 1 || mxGetString(prhs[0], cmd, sizeof(cmd))) + + if (nrhs < 1 || mxGetString(prhs[0], cmd, sizeof(cmd))) mexErrMsgTxt("First input should be a command string less than 64 characters long."); - // new object + + /* + * First check to see if a new object was requested + */ if (!strcmp("new", cmd)) { // Check parameters if (nlhs != 1){ @@ -103,27 +59,104 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) } // Return a handle to a new C++ wrapper instance osqpData = new OsqpData; - //osqpData->solver = initializeOSQPSolver(); - osqpData->solver = NULL; plhs[0] = convertPtr2Mat(osqpData); return; } - // Check for a second input, which should be the class instance handle or string 'static' - if (nrhs < 2) - mexErrMsgTxt("Second input should be a class instance handle or the string 'static'."); - - if(mxGetString(prhs[1], stat_string, sizeof(stat_string))){ - // If we are dealing with non-static methods, get the class instance pointer from the second input - osqpData = convertMat2Ptr(prhs[1]); - } else { - if (strcmp("static", stat_string)){ - mexErrMsgTxt("Second argument for static functions is string 'static'"); + /* + * Next check to see if any of the static methods were called + */ + // Report the version + if (!strcmp("version", cmd)) { + plhs[0] = mxCreateString(osqp_version()); + return; + } + + // Report the default settings + if (!strcmp("default_settings", cmd)) { + // Warn if other commands were ignored + if (nrhs > 2) + mexWarnMsgTxt("Default settings: unexpected number of arguments."); + + // Create a Settings structure in default form and report the results + // Useful for external solver packages (e.g. Yalmip) that want to + // know which solver settings are supported + OSQPSettingsWrapper settings; + plhs[0] = settings.GetMxStruct(); + return; + } + + // Return solver constants + if (!strcmp("constant", cmd)) { + static std::map floatConstants{ + // Numerical constants + {"OSQP_INFTY", OSQP_INFTY} + }; + + static std::map intConstants{ + // Return codes + {"OSQP_SOLVED", OSQP_SOLVED}, + {"OSQP_SOLVED_INACCURATE", OSQP_SOLVED_INACCURATE}, + {"OSQP_UNSOLVED", OSQP_UNSOLVED}, + {"OSQP_PRIMAL_INFEASIBLE", OSQP_PRIMAL_INFEASIBLE}, + {"OSQP_PRIMAL_INFEASIBLE_INACCURATE", OSQP_PRIMAL_INFEASIBLE_INACCURATE}, + {"OSQP_DUAL_INFEASIBLE", OSQP_DUAL_INFEASIBLE}, + {"OSQP_DUAL_INFEASIBLE_INACCURATE", OSQP_DUAL_INFEASIBLE_INACCURATE}, + {"OSQP_MAX_ITER_REACHED", OSQP_MAX_ITER_REACHED}, + {"OSQP_NON_CVX", OSQP_NON_CVX}, + {"OSQP_TIME_LIMIT_REACHED", OSQP_TIME_LIMIT_REACHED}, + + // Linear system solvers + {"QDLDL_SOLVER", QDLDL_SOLVER}, + {"OSQP_UNKNOWN_SOLVER", OSQP_UNKNOWN_SOLVER}, + {"OSQP_DIRECT_SOLVER", OSQP_DIRECT_SOLVER}, + {"OSQP_INDIRECT_SOLVER", OSQP_INDIRECT_SOLVER} + }; + + char constant[64]; + int constantLength = mxGetN(prhs[1]) + 1; + mxGetString(prhs[1], constant, constantLength); + + auto ci = intConstants.find(constant); + + if(ci != intConstants.end()) { + plhs[0] = mxCreateDoubleScalar(ci->second); + return; + } + + auto cf = floatConstants.find(constant); + + if(cf != floatConstants.end()) { + plhs[0] = mxCreateDoubleScalar(cf->second); + return; + } + + // NaN is special because we need the Matlab version + if (!strcmp("OSQP_NAN", constant)){ + plhs[0] = mxCreateDoubleScalar(mxGetNaN()); + return; } + + mexErrMsgTxt("Constant not recognized."); + + return; } + + /* + * Finally, check to see if this is a function operating on a solver instance + */ + + // Check for a second input, which should be the class instance handle + if (nrhs < 2) + mexErrMsgTxt("Second input should be a class instance handle."); + + + // Get the class instance pointer from the second input + osqpData = convertMat2Ptr(prhs[1]); + // delete the object and its data if (!strcmp("delete", cmd)) { - + osqp_cleanup(osqpData->solver); destroyObject(prhs[1]); // Warn if other commands were ignored @@ -134,45 +167,46 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) // report the current settings if (!strcmp("current_settings", cmd)) { - //throw an error if this is called before solver is configured - if(!osqpData->solver) mexErrMsgTxt("Solver is uninitialized. No settings have been configured."); - if(!osqpData->solver->settings){ - mexErrMsgTxt("Solver settings is uninitialized. No settings have been configured."); - } - //report the current settings - plhs[0] = copySettingsToMxStruct(osqpData->solver->settings); - return; + // Throw an error if this is called before solver is configured + if(!osqpData->solver) { + mexErrMsgTxt("Solver is uninitialized. No settings have been configured."); + } + if(!osqpData->solver->settings) { + mexErrMsgTxt("Solver settings is uninitialized. No settings have been configured."); + } + + // Report the current settings + OSQPSettingsWrapper settings(osqpData->solver->settings); + plhs[0] = settings.GetMxStruct(); + return; } // write_settings if (!strcmp("update_settings", cmd)) { - //overwrite the current settings. Mex function is responsible - //for disallowing overwrite of selected settings after initialization, - //and for all error checking - //throw an error if this is called before solver is configured - if(!osqpData->solver){ - mexErrMsgTxt("Solver is uninitialized. No settings have been configured."); - } + // Overwrite the current settings. Mex function is responsible + // for disallowing overwrite of selected settings after initialization, + // and for all error checking + // throw an error if this is called before solver is configured + if(!osqpData->solver){ + mexErrMsgTxt("Solver is uninitialized. No settings have been configured."); + } - copyUpdatedSettingsToWork(prhs[2],osqpData->solver); - return; + OSQPSettingsWrapper settings(prhs[2]); + osqp_update_settings(osqpData->solver, settings.GetOSQPStruct()); + return; } - // report the default settings - if (!strcmp("default_settings", cmd)) { - // Warn if other commands were ignored - if (nrhs > 2) - mexWarnMsgTxt("Default settings: unexpected number of arguments."); + // Update rho value + if (!strcmp("update_rho", cmd)) { + //throw an error if this is called before solver is configured + if(!osqpData->solver){ + mexErrMsgTxt("Solver is uninitialized. No settings have been configured."); + } + OSQPFloat rho = (OSQPFloat)mxGetScalar(prhs[2]); - //Create a Settings structure in default form and report the results - //Useful for external solver packages (e.g. Yalmip) that want to - //know which solver settings are supported - OSQPSettings* defaults = (OSQPSettings *)mxCalloc(1,sizeof(OSQPSettings)); - osqp_set_default_settings(defaults); - plhs[0] = copySettingsToMxStruct(defaults); - mxFree(defaults); - return; + osqp_update_rho(osqpData->solver, rho); + return; } // setup @@ -181,8 +215,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) if(osqpData->solver){ mexErrMsgTxt("Solver is already initialized with problem data."); } - //Create data and settings containers - OSQPSettings* settings = (OSQPSettings *)mxCalloc(1,sizeof(OSQPSettings)); // handle the problem data first. Matlab-side // class wrapper is responsible for ensuring that @@ -200,37 +232,35 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) OSQPInt dataN = (OSQPInt)mxGetScalar(prhs[2]); OSQPInt dataM = (OSQPInt)mxGetScalar(prhs[3]); - OSQPFloat* dataQ = copyToOSQPFloatVector(mxGetPr(q), dataN); - OSQPFloat* dataL = copyToOSQPFloatVector(mxGetPr(l), dataM); - OSQPFloat* dataU = copyToOSQPFloatVector(mxGetPr(u), dataM); + OSQPFloat* dataQ = cloneVector(mxGetPr(q), dataN); + OSQPFloat* dataL = cloneVector(mxGetPr(l), dataM); + OSQPFloat* dataU = cloneVector(mxGetPr(u), dataM); // Matrix P: nnz = P->p[n] - OSQPInt * Pp = (OSQPInt*)copyToOSQPIntVector(mxGetJc(P), dataN + 1); - OSQPInt * Pi = (OSQPInt*)copyToOSQPIntVector(mxGetIr(P), Pp[dataN]); - OSQPFloat * Px = copyToOSQPFloatVector(mxGetPr(P), Pp[dataN]); + OSQPInt * Pp = cloneVector(mxGetJc(P), dataN + 1); + OSQPInt * Pi = cloneVector(mxGetIr(P), Pp[dataN]); + OSQPFloat * Px = cloneVector(mxGetPr(P), Pp[dataN]); OSQPCscMatrix* dataP = (OSQPCscMatrix*)c_calloc(1,sizeof(OSQPCscMatrix)); csc_set_data(dataP, dataN, dataN, Pp[dataN], Px, Pi, Pp); // Matrix A: nnz = A->p[n] - OSQPInt* Ap = (OSQPInt*)copyToOSQPIntVector(mxGetJc(A), dataN + 1); - OSQPInt* Ai = (OSQPInt*)copyToOSQPIntVector(mxGetIr(A), Ap[dataN]); - OSQPFloat * Ax = copyToOSQPFloatVector(mxGetPr(A), Ap[dataN]); + OSQPInt* Ap = cloneVector(mxGetJc(A), dataN + 1); + OSQPInt* Ai = cloneVector(mxGetIr(A), Ap[dataN]); + OSQPFloat * Ax = cloneVector(mxGetPr(A), Ap[dataN]); OSQPCscMatrix* dataA = (OSQPCscMatrix*)c_calloc(1,sizeof(OSQPCscMatrix)); csc_set_data(dataA, dataM, dataN, Ap[dataN], Ax, Ai, Ap); // Create Settings - const mxArray* mxSettings = prhs[9]; - if(mxIsEmpty(mxSettings)){ - // use defaults - osqp_set_default_settings(settings); - } else { - //populate settings structure from mxArray input - copyMxStructToSettings(mxSettings, settings); + OSQPSettingsWrapper settings; + + if(!mxIsEmpty(prhs[9])){ + // Populate settings structure from mxArray input, otherwise the default settings are used + settings.ParseMxStruct(prhs[9]); } // Setup workspace //exitflag = osqp_setup(&(osqpData->work), data, settings); - exitflag = osqp_setup(&(osqpData->solver), dataP, dataQ, dataA, dataL, dataU, dataM, dataN, settings); + exitflag = osqp_setup(&(osqpData->solver), dataP, dataQ, dataA, dataL, dataU, dataM, dataN, settings.GetOSQPStruct()); //cleanup temporary structures // Data if (Px) c_free(Px); @@ -244,8 +274,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) if (dataU) c_free(dataU); if (dataP) c_free(dataP); if (dataA) c_free(dataA); - // Settings - if (settings) c_free(settings); // Report error (if any) if(exitflag){ @@ -271,14 +299,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) return; } - // report the version - if (!strcmp("version", cmd)) { - - plhs[0] = mxCreateString(osqp_version()); - - return; - } - // update linear cost and bounds if (!strcmp("update", cmd)) { @@ -312,37 +332,37 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) OSQPInt n, m; osqp_get_dimensions(osqpData->solver, &m, &n); if(!mxIsEmpty(q)){ - q_vec = copyToOSQPFloatVector(mxGetPr(q), n); + q_vec = cloneVector(mxGetPr(q), n); } if(!mxIsEmpty(l)){ - l_vec = copyToOSQPFloatVector(mxGetPr(l), m); + l_vec = cloneVector(mxGetPr(l), m); } if(!mxIsEmpty(u)){ - u_vec = copyToOSQPFloatVector(mxGetPr(u), m); + u_vec = cloneVector(mxGetPr(u), m); } if(!mxIsEmpty(Px)){ - Px_vec = copyToOSQPFloatVector(mxGetPr(Px), Px_n); + Px_vec = cloneVector(mxGetPr(Px), Px_n); } if(!mxIsEmpty(Ax)){ - Ax_vec = copyToOSQPFloatVector(mxGetPr(Ax), Ax_n); + Ax_vec = cloneVector(mxGetPr(Ax), Ax_n); } if(!mxIsEmpty(Px_idx)){ - Px_idx_vec = copyDoubleToOSQPIntVector(mxGetPr(Px_idx), Px_n); + Px_idx_vec = cloneVector(mxGetPr(Px_idx), Px_n); } if(!mxIsEmpty(Ax_idx)){ - Ax_idx_vec = copyDoubleToOSQPIntVector(mxGetPr(Ax_idx), Ax_n); + Ax_idx_vec = cloneVector(mxGetPr(Ax_idx), Ax_n); } if (!exitflag && (!mxIsEmpty(q) || !mxIsEmpty(l) || !mxIsEmpty(u))) { exitflag = osqp_update_data_vec(osqpData->solver, q_vec, l_vec, u_vec); - if (exitflag) exitflag=1; + if (exitflag) exitflag=1; } - + if (!exitflag && (!mxIsEmpty(Px) || !mxIsEmpty(Ax))) { exitflag = osqp_update_data_mat(osqpData->solver, Px_vec, Px_idx_vec, Px_n, Ax_vec, Ax_idx_vec, Ax_n); if (exitflag) exitflag=2; } - + // Free vectors if(!mxIsEmpty(q)) c_free(q_vec); @@ -364,29 +384,16 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) return; } - if (!strcmp("warm_start", cmd) || !strcmp("warm_start_x", cmd) || !strcmp("warm_start_y", cmd)) { - + if (!strcmp("warm_start", cmd)) { + //throw an error if this is called before solver is configured if(!osqpData->solver){ mexErrMsgTxt("Solver has not been initialized."); - } - - // Fill x and y - const mxArray *x = NULL; - const mxArray *y = NULL; - if (!strcmp("warm_start", cmd)) { - x = prhs[2]; - y = prhs[3]; - } - else if (!strcmp("warm_start_x", cmd)) { - x = prhs[2]; - y = NULL; } - else if (!strcmp("warm_start_y", cmd)) { - x = NULL; - y = prhs[2]; - } + // Fill x and y + const mxArray *x = prhs[2]; + const mxArray *y = prhs[3]; // Copy vectors to ensure they are cast as OSQPFloat OSQPFloat *x_vec = NULL; @@ -394,11 +401,12 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) OSQPInt n, m; osqp_get_dimensions(osqpData->solver, &m, &n); + if(!mxIsEmpty(x)){ - x_vec = copyToOSQPFloatVector(mxGetPr(x),n); + x_vec = cloneVector(mxGetPr(x),n); } if(!mxIsEmpty(y)){ - y_vec = copyToOSQPFloatVector(mxGetPr(y),m); + y_vec = cloneVector(mxGetPr(y),m); } // Warm start x and y @@ -441,8 +449,8 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) (osqpData->solver->info->status_val != OSQP_DUAL_INFEASIBLE)){ //primal and dual solutions - castToDoubleArr(osqpData->solver->solution->x, mxGetPr(plhs[0]), n); - castToDoubleArr(osqpData->solver->solution->y, mxGetPr(plhs[1]), m); + copyVector(mxGetPr(plhs[0]), osqpData->solver->solution->x, n); + copyVector(mxGetPr(plhs[1]), osqpData->solver->solution->y, m); //infeasibility certificates -> NaN values setToNaN(mxGetPr(plhs[2]), m); @@ -456,7 +464,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) setToNaN(mxGetPr(plhs[1]), m); //primal infeasibility certificates - castToDoubleArr(osqpData->solver->solution->prim_inf_cert, mxGetPr(plhs[2]), m); + copyVector(mxGetPr(plhs[2]), osqpData->solver->solution->prim_inf_cert, m); //dual infeasibility certificates -> NaN values setToNaN(mxGetPr(plhs[3]), n); @@ -474,7 +482,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) setToNaN(mxGetPr(plhs[2]), m); //dual infeasibility certificates - castToDoubleArr(osqpData->solver->solution->dual_inf_cert, mxGetPr(plhs[3]), n); + copyVector(mxGetPr(plhs[3]), osqpData->solver->solution->dual_inf_cert, n); // Set objective value to -infinity osqpData->solver->info->obj_val = -mxGetInf(); @@ -484,98 +492,9 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) osqpData->solver->info->obj_val = mxGetNaN(); } - plhs[4] = copyInfoToMxStruct(osqpData->solver->info); // Info structure - - return; - } - - if (!strcmp("constant", cmd)) { // Return solver constants - - char constant[32]; - mxGetString(prhs[2], constant, sizeof(constant)); - - if (!strcmp("OSQP_INFTY", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_INFTY); - return; - } - if (!strcmp("OSQP_NAN", constant)){ - plhs[0] = mxCreateDoubleScalar(mxGetNaN()); - return; - } - - if (!strcmp("OSQP_SOLVED", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_SOLVED); - return; - } - - if (!strcmp("OSQP_SOLVED_INACCURATE", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_SOLVED_INACCURATE); - return; - } - - if (!strcmp("OSQP_UNSOLVED", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_UNSOLVED); - return; - } - - if (!strcmp("OSQP_PRIMAL_INFEASIBLE", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_PRIMAL_INFEASIBLE); - return; - } - - if (!strcmp("OSQP_PRIMAL_INFEASIBLE_INACCURATE", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_PRIMAL_INFEASIBLE_INACCURATE); - return; - } - - if (!strcmp("OSQP_DUAL_INFEASIBLE", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_DUAL_INFEASIBLE); - return; - } - - if (!strcmp("OSQP_DUAL_INFEASIBLE_INACCURATE", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_DUAL_INFEASIBLE_INACCURATE); - return; - } - - if (!strcmp("OSQP_MAX_ITER_REACHED", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_MAX_ITER_REACHED); - return; - } - - if (!strcmp("OSQP_NON_CVX", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_NON_CVX); - return; - } - - if (!strcmp("OSQP_TIME_LIMIT_REACHED", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_TIME_LIMIT_REACHED); - return; - } - - // Linear system solvers - if (!strcmp("QDLDL_SOLVER", constant)){ - plhs[0] = mxCreateDoubleScalar(QDLDL_SOLVER); - return; - } - - if (!strcmp("OSQP_UNKNOWN_SOLVER", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_UNKNOWN_SOLVER); - return; - } - - if (!strcmp("OSQP_DIRECT_SOLVER", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_DIRECT_SOLVER); - return; - } - - if (!strcmp("OSQP_INDIRECT_SOLVER", constant)){ - plhs[0] = mxCreateDoubleScalar(OSQP_INDIRECT_SOLVER); - return; - } - - - mexErrMsgTxt("Constant not recognized."); + // Populate the info structure + OSQPInfoWrapper info(osqpData->solver->info); + plhs[4] = info.GetMxStruct(); return; } @@ -583,248 +502,3 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) // Got here, so command not recognized mexErrMsgTxt("Command not recognized."); } - -/** - * This function dynamically allocates OSQPSovler and sets all the properties of OSQPSolver to NULL. - * WARNING: The memory allocated here (OSQPSolver*) needs to be freed. - * WARNING: Any dynamically allocated pointers must be freed before calling this function. -*/ -OSQPSolver* initializeOSQPSolver() { - OSQPSolver* osqpSolver = new OSQPSolver; - osqpSolver->info = NULL; - osqpSolver->settings = NULL; - osqpSolver->solution = NULL; - osqpSolver->work = NULL; - //osqp_set_default_settings(osqpSolver->settings); - return osqpSolver; -} - -//Dynamically creates a OSQPFloat vector copy of the input. -//Returns an empty pointer if vecData is NULL -OSQPFloat* copyToOSQPFloatVector(double * vecData, OSQPInt numel){ - if (!vecData) return NULL; - - //This needs to be freed! - OSQPFloat* out = (OSQPFloat*)c_malloc(numel * sizeof(OSQPFloat)); - - //copy data - for(OSQPInt i=0; i < numel; i++){ - out[i] = (OSQPFloat)vecData[i]; - } - return out; -} - -//Dynamically creates a OSQPInt vector copy of the input. -OSQPInt* copyToOSQPIntVector(mwIndex* vecData, OSQPInt numel){ - // This memory needs to be freed! - OSQPInt* out = (OSQPInt*)c_malloc(numel * sizeof(OSQPInt)); - - //copy data - for(OSQPInt i=0; i < numel; i++){ - out[i] = (OSQPInt)vecData[i]; - } - return out; - -} - -//Dynamically copies a double vector to OSQPInt. -OSQPInt* copyDoubleToOSQPIntVector(double* vecData, OSQPInt numel){ - // This memory needs to be freed! - OSQPInt* out = (OSQPInt*)c_malloc(numel * sizeof(OSQPInt)); - - //copy data - for(OSQPInt i=0; i < numel; i++){ - out[i] = (OSQPInt)vecData[i]; - } - return out; - -} - -/* DELETE HERE -void castCintToDoubleArr(OSQPInt *arr, double* arr_out, OSQPInt len) { - for (OSQPInt i = 0; i < len; i++) { - arr_out[i] = (double)arr[i]; - } -}*/ - -//This function frees the memory allocated in an OSQPCscMatrix M -void freeCscMatrix(OSQPCscMatrix* M) { - if (!M) return; - if (M->p) c_free(M->p); - if (M->i) c_free(M->i); - if (M->x) c_free(M->x); - c_free(M); -} - -void castToDoubleArr(OSQPFloat *arr, double* arr_out, OSQPInt len) { - for (OSQPInt i = 0; i < len; i++) { - arr_out[i] = (double)arr[i]; - } -} - -void setToNaN(double* arr_out, OSQPInt len){ - OSQPInt i; - for (i = 0; i < len; i++) { - arr_out[i] = mxGetNaN(); - } -} - -mxArray* copyInfoToMxStruct(OSQPInfo* info){ - - //create mxArray with the right number of fields - int nfields = sizeof(OSQP_INFO_FIELDS) / sizeof(OSQP_INFO_FIELDS[0]); - mxArray* mxPtr = mxCreateStructMatrix(1,1,nfields,OSQP_INFO_FIELDS); - - //map the OSQP_INFO fields one at a time into mxArrays - //matlab all numeric values as doubles - mxSetField(mxPtr, 0, "iter", mxCreateDoubleScalar(info->iter)); - mxSetField(mxPtr, 0, "status", mxCreateString(info->status)); - mxSetField(mxPtr, 0, "status_val", mxCreateDoubleScalar(info->status_val)); - mxSetField(mxPtr, 0, "status_polish", mxCreateDoubleScalar(info->status_polish)); - mxSetField(mxPtr, 0, "obj_val", mxCreateDoubleScalar(info->obj_val)); - mxSetField(mxPtr, 0, "prim_res", mxCreateDoubleScalar(info->prim_res)); - mxSetField(mxPtr, 0, "dual_res", mxCreateDoubleScalar(info->dual_res)); - - mxSetField(mxPtr, 0, "setup_time", mxCreateDoubleScalar(info->setup_time)); - mxSetField(mxPtr, 0, "solve_time", mxCreateDoubleScalar(info->solve_time)); - mxSetField(mxPtr, 0, "update_time", mxCreateDoubleScalar(info->update_time)); - mxSetField(mxPtr, 0, "polish_time", mxCreateDoubleScalar(info->polish_time)); - mxSetField(mxPtr, 0, "run_time", mxCreateDoubleScalar(info->run_time)); - - - mxSetField(mxPtr, 0, "rho_updates", mxCreateDoubleScalar(info->rho_updates)); - mxSetField(mxPtr, 0, "rho_estimate", mxCreateDoubleScalar(info->rho_estimate)); - - - return mxPtr; - -} - -mxArray* copySettingsToMxStruct(OSQPSettings* settings){ - - int nfields = sizeof(OSQP_SETTINGS_FIELDS) / sizeof(OSQP_SETTINGS_FIELDS[0]); - mxArray* mxPtr = mxCreateStructMatrix(1,1,nfields,OSQP_SETTINGS_FIELDS); - - //map the OSQP_SETTINGS fields one at a time into mxArrays - //matlab handles everything as a double - mxSetField(mxPtr, 0, "rho", mxCreateDoubleScalar(settings->rho)); - mxSetField(mxPtr, 0, "sigma", mxCreateDoubleScalar(settings->sigma)); - mxSetField(mxPtr, 0, "scaling", mxCreateDoubleScalar(settings->scaling)); - mxSetField(mxPtr, 0, "adaptive_rho", mxCreateDoubleScalar(settings->adaptive_rho)); - mxSetField(mxPtr, 0, "adaptive_rho_interval", mxCreateDoubleScalar(settings->adaptive_rho_interval)); - mxSetField(mxPtr, 0, "adaptive_rho_tolerance", mxCreateDoubleScalar(settings->adaptive_rho_tolerance)); - mxSetField(mxPtr, 0, "adaptive_rho_fraction", mxCreateDoubleScalar(settings->adaptive_rho_fraction)); - mxSetField(mxPtr, 0, "max_iter", mxCreateDoubleScalar(settings->max_iter)); - mxSetField(mxPtr, 0, "eps_abs", mxCreateDoubleScalar(settings->eps_abs)); - mxSetField(mxPtr, 0, "eps_rel", mxCreateDoubleScalar(settings->eps_rel)); - mxSetField(mxPtr, 0, "eps_prim_inf", mxCreateDoubleScalar(settings->eps_prim_inf)); - mxSetField(mxPtr, 0, "eps_dual_inf", mxCreateDoubleScalar(settings->eps_dual_inf)); - mxSetField(mxPtr, 0, "alpha", mxCreateDoubleScalar(settings->alpha)); - mxSetField(mxPtr, 0, "linsys_solver", mxCreateDoubleScalar(settings->linsys_solver)); - mxSetField(mxPtr, 0, "delta", mxCreateDoubleScalar(settings->delta)); - mxSetField(mxPtr, 0, "polish_refine_iter", mxCreateDoubleScalar(settings->polish_refine_iter)); - mxSetField(mxPtr, 0, "verbose", mxCreateDoubleScalar(settings->verbose)); - mxSetField(mxPtr, 0, "scaled_termination", mxCreateDoubleScalar(settings->scaled_termination)); - mxSetField(mxPtr, 0, "check_termination", mxCreateDoubleScalar(settings->check_termination)); - mxSetField(mxPtr, 0, "warm_starting", mxCreateDoubleScalar(settings->warm_starting)); - mxSetField(mxPtr, 0, "time_limit", mxCreateDoubleScalar(settings->time_limit)); - mxSetField(mxPtr, 0, "device", mxCreateDoubleScalar(settings->device)); - mxSetField(mxPtr, 0, "polishing", mxCreateDoubleScalar(settings->polishing)); - mxSetField(mxPtr, 0, "rho_is_vec", mxCreateDoubleScalar(settings->rho_is_vec)); - mxSetField(mxPtr, 0, "cg_max_iter", mxCreateDoubleScalar(settings->cg_max_iter)); - mxSetField(mxPtr, 0, "cg_tol_reduction", mxCreateDoubleScalar(settings->cg_tol_reduction)); - mxSetField(mxPtr, 0, "cg_tol_fraction", mxCreateDoubleScalar(settings->cg_tol_fraction)); - mxSetField(mxPtr, 0, "time_limit", mxCreateDoubleScalar(settings->time_limit)); - mxSetField(mxPtr, 0, "cg_precond", mxCreateDoubleScalar(settings->cg_precond)); - return mxPtr; -} - - -// ====================================================================== - -void copyMxStructToSettings(const mxArray* mxPtr, OSQPSettings* settings){ - - //this function assumes that only a complete and validated structure - //will be passed. matlab mex-side function is responsible for checking - //structure validity - - //map the OSQP_SETTINGS fields one at a time into mxArrays - //matlab handles everything as a double - settings->rho = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "rho")); - settings->sigma = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "sigma")); - settings->scaling = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaling")); - settings->adaptive_rho = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho")); - settings->adaptive_rho_interval = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_interval")); - settings->adaptive_rho_tolerance = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_tolerance")); - settings->adaptive_rho_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_fraction")); - settings->max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "max_iter")); - settings->eps_abs = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_abs")); - settings->eps_rel = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_rel")); - settings->eps_prim_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_dual_inf")); - settings->eps_dual_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_dual_inf")); - settings->alpha = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "alpha")); - settings->linsys_solver = (enum osqp_linsys_solver_type) (OSQPInt) mxGetScalar(mxGetField(mxPtr, 0, "linsys_solver")); - settings->delta = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "delta")); - settings->polish_refine_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polish_refine_iter")); - settings->verbose = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "verbose")); - settings->scaled_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaled_termination")); - settings->check_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "check_termination")); - settings->warm_starting = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "warm_starting")); - settings->time_limit = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "time_limit")); - settings->device = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "device")); - settings->polishing = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polishing")); - settings->rho_is_vec = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "rho_is_vec")); - settings->cg_max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_max_iter")); - settings->cg_tol_reduction = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_reduction")); - settings->cg_tol_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_fraction")); - settings->cg_precond = (osqp_precond_type) (OSQPInt) (mxGetField(mxPtr, 0, "cg_precond")); -} - -void copyUpdatedSettingsToWork(const mxArray* mxPtr ,OSQPSolver* osqpSolver){ - - OSQPInt exitflag; - //TODO (Amit): Update this - OSQPSettings* update_template = (OSQPSettings *)mxCalloc(1,sizeof(OSQPSettings)); - if (!update_template) mexErrMsgTxt("Failed to allocate a temporary OSQPSettings object."); - - update_template->device = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "device")); - update_template->linsys_solver = (enum osqp_linsys_solver_type)mxGetScalar(mxGetField(mxPtr, 0, "linsys_solver")); - update_template->verbose = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "verbose")); - update_template->warm_starting = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "warm_starting")); - update_template->scaling = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaling")); - update_template->polishing = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polishing")); - - update_template->rho = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "rho")); - update_template->rho_is_vec = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "rho_is_vec")); - update_template->sigma = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "sigma")); - update_template->alpha = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "alpha")); - - update_template->cg_max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_max_iter")); - update_template->cg_tol_reduction = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_reduction")); - update_template->cg_tol_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "cg_tol_fraction")); - update_template->cg_precond = (osqp_precond_type)mxGetScalar(mxGetField(mxPtr, 0, "cg_precond")); - - update_template->adaptive_rho = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho")); - update_template->adaptive_rho_interval = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_interval")); - update_template->adaptive_rho_fraction = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_fraction")); - update_template->adaptive_rho_tolerance = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "adaptive_rho_tolerance")); - - update_template->max_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "max_iter")); - update_template->eps_abs = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_abs")); - update_template->eps_rel = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_rel")); - update_template->eps_prim_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_prim_inf")); - update_template->eps_dual_inf = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "eps_dual_inf")); - update_template->scaled_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "scaled_termination")); - update_template->check_termination = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "check_termination")); - update_template->time_limit = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "time_limit")); - - update_template->delta = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "delta")); - update_template->polish_refine_iter = (OSQPInt)mxGetScalar(mxGetField(mxPtr, 0, "polish_refine_iter")); - - osqp_update_settings(osqpSolver, update_template); - //rho needs to be updated separetly, it is not updated in osqp_update_settings - OSQPFloat rho_new = (OSQPFloat)mxGetScalar(mxGetField(mxPtr, 0, "rho")); - if (rho_new != osqpSolver->settings->rho) osqp_update_rho(osqpSolver, rho_new); - - if (update_template) c_free(update_template); -} \ No newline at end of file diff --git a/c_sources/osqp_struct.h b/c_sources/osqp_struct.h new file mode 100644 index 0000000..94a432f --- /dev/null +++ b/c_sources/osqp_struct.h @@ -0,0 +1,205 @@ +#ifndef OSQP_STRUCT_H_ +#define OSQP_STRUCT_H_ + +#include +#include +#include +#include + +#include +#include + +#include "memory_matlab.h" +#include + +/** + * Base class used to store the field types for a struct. + */ +class OSQPStructFieldBase { +public: + OSQPStructFieldBase() {} + + /** + * Set the field in the given Matlab struct to the value of this field + */ + virtual void ToMxStruct(mxArray* aStruct) = 0; + + /** + * Set the field in the internal struct with the data from aStruct + */ + virtual void ToOSQPStruct(const mxArray* aStruct) = 0; +}; + +/** + * Class to hold a numeric struct field (e.g. float/double/int/enum, etc.). + */ +template +class OSQPStructField : public OSQPStructFieldBase { +public: + OSQPStructField(T* aStructPtr, std::string aName) : + m_structPtr(aStructPtr), + m_name(aName) { + } + + void ToMxStruct(mxArray* aStruct) override { + mxAddField(aStruct, m_name.data()); + mxSetField(aStruct, 0, m_name.data(), mxCreateDoubleScalar(*m_structPtr)); + } + + void ToOSQPStruct(const mxArray* aStruct) override { + *(m_structPtr) = static_cast(mxGetScalar(mxGetField(aStruct, 0, m_name.data()))); + } + +private: + T* m_structPtr; + std::string m_name; +}; + +/** + * Class to hold a character array (actual array, not char* array) field in a struct. + */ +class OSQPStructFieldCharArray : public OSQPStructFieldBase { +public: + OSQPStructFieldCharArray(char* aStructPtr, size_t aLength, std::string aName) : + m_structPtr(aStructPtr), + m_name(aName), + m_length(aLength) { + } + + void ToMxStruct(mxArray* aStruct) override { + mxAddField(aStruct, m_name.data()); + mxSetField(aStruct, 0, m_name.data(), mxCreateString(m_structPtr)); + } + + void ToOSQPStruct(const mxArray* aStruct) override { + mxArray* tmp = mxGetField(aStruct, 0, m_name.data()); + mxGetString(tmp, m_structPtr, m_length); + } + +private: + char* m_structPtr; + std::string m_name; + size_t m_length; +}; + +/** + * Wrap a struct from OSQP to automatically transfer the data between OSQP and Matlab. + */ +template +class OSQPStructWrapper { +public: + /** + * Initialize the wrapper using the default values. + */ + OSQPStructWrapper() { + // Allocate the default struct and register field handlers + registerFields(); + } + + /** + * Initialize the wrapper using the values from the OSQP struct. + */ + OSQPStructWrapper(const T* aStruct) { + // Allocate the default struct and register field handlers + registerFields(); + ParseOSQPStruct(aStruct); + } + + /** + * Initialize the wrapper using the values from the Matlab struct + */ + OSQPStructWrapper(const mxArray* aStruct) { + // Allocate the default struct and register field handlers + registerFields(); + ParseMxStruct(aStruct); + } + + ~OSQPStructWrapper() { + for(auto& s : m_structFields) { + delete s; + } + + c_free(m_struct); + } + + /** + * Return a Matlab struct populated with the values of the current struct + * contained in this wrapper. + * + * @return a Matlab struct with a copy of the struct (caller owns this copy and must free it) + */ + mxArray* GetMxStruct() { + // No fields are added right now, they are added in the for loop when they are set + mxArray* matStruct = mxCreateStructMatrix(1, 1, 0, NULL); + + // Copy the current struct into the struct to return + for(const auto& s : m_structFields) { + s->ToMxStruct(matStruct); + } + + return matStruct; + } + + /** + * Read a Matlab struct and populate the wrapper with its values. + */ + void ParseMxStruct(const mxArray* aStruct) { + for(const auto& s : m_structFields) { + s->ToOSQPStruct(aStruct); + } + } + + /** + * Get a copy of the struct contained inside this wrapper. + * + * @return a copy of the struct (caller owns this copy and must free it) + */ + T* GetOSQPStructCopy() { + // Allocate the default struct + T* ret = static_cast(c_calloc(1, sizeof(T))); + + // Copy the current values for their return + std::memcpy(ret, m_struct, sizeof(T)); + return ret; + } + + /** + * Get the pointer to the internal struct object. + */ + T* GetOSQPStruct() { + return m_struct; + } + + /* + * Read an existing OSQP struct object into this wrapper. + * The struct elements are copied, so no ownership of the aStruct pointer is transferred. + */ + void ParseOSQPStruct(const T* aStruct) { + std::memcpy(m_struct, aStruct, sizeof(T)); + } + +private: + /** + * Register all the fields for the wrapper. + * This function should be specialized for each struct type to map the fields appropriately. + */ + void registerFields(); + + // All struct fields + std::vector m_structFields; + + // Base OSQP struct object. Owned by this wrapper. + T* m_struct; +}; + +/** + * Wrapper around the OSQPSettings struct + */ +typedef OSQPStructWrapper OSQPSettingsWrapper; + +/** + * Wrapper around the OSQPInfo struct + */ +typedef OSQPStructWrapper OSQPInfoWrapper; + +#endif \ No newline at end of file diff --git a/c_sources/osqp_struct_info.cpp b/c_sources/osqp_struct_info.cpp new file mode 100644 index 0000000..b454a1c --- /dev/null +++ b/c_sources/osqp_struct_info.cpp @@ -0,0 +1,43 @@ +#include +#include "osqp_struct.h" + + +/* + * Specialization of the struct wrapper for the OSQPInfo struct. + */ +template<> +void OSQPStructWrapper::registerFields() { + m_struct = static_cast(c_calloc(1, sizeof(OSQPInfo))); + + if(!m_struct) + mexErrMsgTxt("Failed to allocate a OSQPInfo object."); + + /* + * Register the mapping between struct field name and the info struct memory location + */ + // Solver status + m_structFields.push_back(new OSQPStructFieldCharArray(m_struct->status, 32, "status")); + m_structFields.push_back(new OSQPStructField(&m_struct->status_val, "status_val")); + m_structFields.push_back(new OSQPStructField(&m_struct->status_polish, "status_polish")); + + // Solution quality + m_structFields.push_back(new OSQPStructField(&m_struct->obj_val, "obj_val")); + m_structFields.push_back(new OSQPStructField(&m_struct->prim_res, "prim_res")); + m_structFields.push_back(new OSQPStructField(&m_struct->dual_res, "dual_res")); + + // Algorithm information + m_structFields.push_back(new OSQPStructField(&m_struct->iter, "iter")); + m_structFields.push_back(new OSQPStructField(&m_struct->rho_updates, "rho_updates")); + m_structFields.push_back(new OSQPStructField(&m_struct->rho_estimate, "rho_estimate")); + + // Timing information + m_structFields.push_back(new OSQPStructField(&m_struct->setup_time, "setup_time")); + m_structFields.push_back(new OSQPStructField(&m_struct->solve_time, "solve_time")); + m_structFields.push_back(new OSQPStructField(&m_struct->update_time, "update_time")); + m_structFields.push_back(new OSQPStructField(&m_struct->polish_time, "polish_time")); + m_structFields.push_back(new OSQPStructField(&m_struct->run_time, "run_time")); +} + + +// Instantiate the OSQPInfo wrapper class +template class OSQPStructWrapper; \ No newline at end of file diff --git a/c_sources/osqp_struct_settings.cpp b/c_sources/osqp_struct_settings.cpp new file mode 100644 index 0000000..97fc357 --- /dev/null +++ b/c_sources/osqp_struct_settings.cpp @@ -0,0 +1,61 @@ +#include +#include "osqp_struct.h" + +/* + * Specialization for the settings struct + */ +template<> +void OSQPStructWrapper::registerFields() { + m_struct = static_cast(c_calloc(1, sizeof(OSQPSettings))); + + if(!m_struct) + mexErrMsgTxt("Failed to allocate a OSQPSettings object."); + + osqp_set_default_settings(m_struct); + + /* + * Register the mapping between struct field name and the settings memory location + */ + m_structFields.push_back(new OSQPStructField(&m_struct->device, "device")); + m_structFields.push_back(new OSQPStructField(&m_struct->linsys_solver, "linsys_solver")); + m_structFields.push_back(new OSQPStructField(&m_struct->verbose, "verbose")); + m_structFields.push_back(new OSQPStructField(&m_struct->warm_starting, "warm_starting")); + m_structFields.push_back(new OSQPStructField(&m_struct->scaling, "scaling")); + m_structFields.push_back(new OSQPStructField(&m_struct->polishing, "polishing")); + + // ADMM parameters + m_structFields.push_back(new OSQPStructField(&m_struct->rho, "rho")); + m_structFields.push_back(new OSQPStructField(&m_struct->rho_is_vec, "rho_is_vec")); + m_structFields.push_back(new OSQPStructField(&m_struct->sigma, "sigma")); + m_structFields.push_back(new OSQPStructField(&m_struct->alpha, "alpha")); + + // CG settings + m_structFields.push_back(new OSQPStructField(&m_struct->cg_max_iter, "cg_max_iter")); + m_structFields.push_back(new OSQPStructField(&m_struct->cg_tol_reduction, "cg_tol_reduction")); + m_structFields.push_back(new OSQPStructField(&m_struct->cg_tol_fraction, "cg_tol_fraction")); + m_structFields.push_back(new OSQPStructField(&m_struct->cg_precond, "cg_precond")); + + // adaptive rho logic + m_structFields.push_back(new OSQPStructField(&m_struct->adaptive_rho, "adaptive_rho")); + m_structFields.push_back(new OSQPStructField(&m_struct->adaptive_rho_interval, "adaptive_rho_interval")); + m_structFields.push_back(new OSQPStructField(&m_struct->adaptive_rho_fraction, "adaptive_rho_fraction")); + m_structFields.push_back(new OSQPStructField(&m_struct->adaptive_rho_tolerance, "adaptive_rho_tolerance")); + + // termination parameters + m_structFields.push_back(new OSQPStructField(&m_struct->max_iter, "max_iter")); + m_structFields.push_back(new OSQPStructField(&m_struct->eps_abs, "eps_abs")); + m_structFields.push_back(new OSQPStructField(&m_struct->eps_rel, "eps_rel")); + m_structFields.push_back(new OSQPStructField(&m_struct->eps_prim_inf, "eps_prim_inf")); + m_structFields.push_back(new OSQPStructField(&m_struct->eps_dual_inf, "eps_dual_inf")); + m_structFields.push_back(new OSQPStructField(&m_struct->scaled_termination, "scaled_termination")); + m_structFields.push_back(new OSQPStructField(&m_struct->check_termination, "check_termination")); + m_structFields.push_back(new OSQPStructField(&m_struct->time_limit, "time_limit")); + + // polishing parameters + m_structFields.push_back(new OSQPStructField(&m_struct->delta, "delta")); + m_structFields.push_back(new OSQPStructField(&m_struct->polish_refine_iter, "polish_refine_iter")); +} + + +// Instantiate the OSQPSettings wrapper class +template class OSQPStructWrapper; \ No newline at end of file diff --git a/make_osqp.m b/make_osqp.m deleted file mode 100644 index 79058b7..0000000 --- a/make_osqp.m +++ /dev/null @@ -1,167 +0,0 @@ -function make_osqp(varargin) -% Matlab MEX makefile for OSQP. -% -% MAKE_OSQP(VARARGIN) is a make file for OSQP solver. It -% builds OSQP and its components from source. -% -% WHAT is the last element of VARARGIN and cell array of strings, -% with the following options: -% -% {}, '' (empty string) or 'all': build all components and link. -% -% 'osqp_mex': builds the OSQP mex interface and the OSQP library -% -% Additional commands: -% -% 'clean': Delete all compiled files -% 'purge': Delete all compiled files and copied code generation files - - -if( nargin == 0 ) - what = {'all'}; - verbose = false; -elseif ( nargin == 1 && ismember('-verbose', varargin) ) - what = {'all'}; - verbose = true; -else - what = varargin{nargin}; - if(isempty(strfind(what, 'all')) && ... - isempty(strfind(what, 'osqp_mex')) && ... - isempty(strfind(what, 'clean')) && ... - isempty(strfind(what, 'purge'))) - fprintf('No rule to make target "%s", exiting.\n', what); - end - - verbose = ismember('-verbose', varargin); -end - -%% Determine where the various files are all located -% Various parts of the build system -[makefile_path,~,~] = fileparts( which( 'make_osqp.m' ) ); -osqp_mex_src_dir = fullfile( makefile_path, 'c_sources' ); -osqp_mex_build_dir = fullfile( osqp_mex_src_dir, 'build' ); -osqp_cg_src_dir = fullfile( osqp_mex_build_dir, 'codegen_src' ); -osqp_cg_dest_dir = fullfile( makefile_path, 'codegen', 'sources' ); - -% Determine where CMake should look for MATLAB -Matlab_ROOT = strrep( matlabroot, '\', '/' ); - -%% Try to unlock any pre-existing version of osqp_mex -% this prevents compile errors if a user builds, runs osqp -% and then tries to recompile -if(mislocked('osqp_mex')) - munlock('osqp_mex'); -end - -%% Configure, build and install the OSQP mex interface -if( any(strcmpi(what,'osqp_mex')) || any(strcmpi(what,'all')) ) - fprintf('Compiling OSQP solver mex interface...\n'); - - % Create build for the mex file and go inside - if exist( osqp_mex_build_dir, 'dir' ) - rmdir( osqp_mex_build_dir, 's' ); - end - mkdir( osqp_mex_build_dir ); -% cd( osqp_mex_build_dir ); - - % Extend path for CMake mac (via Homebrew) - PATH = getenv('PATH'); - if( (ismac) && (isempty(strfind(PATH, '/usr/local/bin'))) ) - setenv('PATH', [PATH ':/usr/local/bin']); - end - - - - %% Configure CMake for the mex interface - fprintf(' Configuring...' ) - [status, output] = system( sprintf( 'cmake -B %s -S %s -DCMAKE_BUILD_TYPE=RelWithDebInfo -DMatlab_ROOT_DIR=\"%s\"', osqp_mex_build_dir, osqp_mex_src_dir, Matlab_ROOT ), 'LD_LIBRARY_PATH', '' ); - if( status ) - fprintf( '\n' ); - disp( output ); - error( 'Error configuring CMake environment' ); - elseif( verbose ) - fprintf( '\n' ); - disp( output ); - else - fprintf( '\t\t\t\t\t[done]\n' ); - end - - %% Build the mex interface - fprintf( ' Building...') - [status, output] = system( sprintf( 'cmake --build %s --config Release', osqp_mex_build_dir ), 'LD_LIBRARY_PATH', '' ); - if( status ) - fprintf( '\n' ); - disp( output ); - error( 'Error compiling OSQP mex interface' ); - elseif( verbose ) - fprintf( '\n' ); - disp( output ); - else - fprintf( '\t\t\t\t\t\t[done]\n' ); - end - - - %% Install various files - fprintf( ' Installing...' ) - - % Copy mex file to root directory for use - if( ispc ) - [err, errmsg, ~] = copyfile( [osqp_mex_build_dir, filesep, 'Release', filesep, 'osqp_mex.mex*'], makefile_path ); - else - [err, errmsg, ~] = copyfile( [osqp_mex_build_dir, filesep, 'osqp_mex.mex*'], makefile_path ); - end - if( ~err ) - fprintf( '\n' ) - disp( errmsg ) - error( ' Error copying mex file' ) - end - - % Copy the code generation source files - % Create build for the mex file and go inside - if exist( osqp_cg_dest_dir, 'dir' ) - rmdir( osqp_cg_dest_dir, 's' ); - end - mkdir( osqp_cg_dest_dir ); - - [err, errmsg, ~] = copyfile( [osqp_cg_src_dir, filesep, '*'], osqp_cg_dest_dir ); - if( ~err ) - fprintf( '\n' ) - disp( errmsg ) - error( ' Error copying code generation source files' ) - end - - fprintf( '\t\t\t\t\t\t[done]\n' ); -end - -%% Clean and purge -if( any(strcmpi(what,'clean')) || any(strcmpi(what,'purge')) ) - fprintf('Cleaning OSQP mex files and build directory...'); - - % Delete mex file - mexfiles = dir(['*.', mexext]); - for i = 1 : length(mexfiles) - delete(mexfiles(i).name); - end - - % Delete OSQP build directory - if exist(osqp_mex_build_dir, 'dir') - rmdir(osqp_mex_build_dir, 's'); - end - - fprintf('\t\t[done]\n'); - - %% Purge only - if( any(strcmpi(what,'purge')) ) - fprintf('Cleaning OSQP codegen directories...'); - - % Delete codegen files - if exist(osqp_cg_dest_dir, 'dir') - rmdir(osqp_cg_dest_dir, 's'); - end - - fprintf('\t\t\t[done]\n'); - end - -end - -end diff --git a/osqp.m b/osqp.m deleted file mode 100755 index c87cbcc..0000000 --- a/osqp.m +++ /dev/null @@ -1,681 +0,0 @@ -classdef osqp < handle - % osqp interface class for OSQP solver - % This class provides a complete interface to the C implementation - % of the OSQP solver. - % - % osqp Properties: - % objectHandle - pointer to the C structure of OSQP solver - % - % osqp Methods: - % - % setup - configure solver with problem data - % solve - solve the QP - % update - modify problem vectors - % warm_start - set warm starting variables x and y - % - % default_settings - create default settings structure - % current_settings - get the current solver settings structure - % update_settings - update the current solver settings structure - % - % get_dimensions - get the number of variables and constraints - % version - return OSQP version - % constant - return a OSQP internal constant - % - % codegen - generate embeddable C code for the problem - - - properties (SetAccess = private, Hidden = true) - objectHandle % Handle to underlying C instance - end - methods(Static) - %% - function out = default_settings() - % DEFAULT_SETTINGS get the default solver settings structure - out = osqp_mex('default_settings', 'static'); - - % Convert linsys solver to string - out.linsys_solver = linsys_solver_to_string(out.linsys_solver); - - end - - %% - function out = constant(constant_name) - % CONSTANT Return solver constant - % C = CONSTANT(CONSTANT_NAME) return constant called CONSTANT_NAME - out = osqp_mex('constant', 'static', constant_name); - end - - %% - function out = version() - % Return OSQP version - out = osqp_mex('version', 'static'); - end - - end - methods - %% Constructor - Create a new solver instance - function this = osqp(varargin) - % Construct OSQP solver class - this.objectHandle = osqp_mex('new', varargin{:}); - end - - %% Destructor - destroy the solver instance - function delete(this) - % Destroy OSQP solver class - osqp_mex('delete', this.objectHandle); - end - - %% - function out = current_settings(this) - % CURRENT_SETTINGS get the current solver settings structure - out = osqp_mex('current_settings', this.objectHandle); - - % Convert linsys solver to string - out.linsys_solver = linsys_solver_to_string(out.linsys_solver); - - end - - %% - function update_settings(this,varargin) - % UPDATE_SETTINGS update the current solver settings structure - - %second input 'false' means that this is *not* a settings - %initialization, so some parameter/values will be disallowed - newSettings = validateSettings(this,false,varargin{:}); - - %write the solver settings. C-mex does not check input - %data or protect against disallowed parameter modifications - osqp_mex('update_settings', this.objectHandle, newSettings); - - end - - %% - function [n,m] = get_dimensions(this) - % GET_DIMENSIONS get the number of variables and constraints - - [n,m] = osqp_mex('get_dimensions', this.objectHandle); - - end - - %% - function update(this,varargin) - % UPDATE modify the linear cost term and/or lower and upper bounds - - %second input 'false' means that this is *not* a settings - %initialization, so some parameter/values will be disallowed - allowedFields = {'q','l','u','Px','Px_idx','Ax','Ax_idx'}; - - if(isempty(varargin)) - return; - elseif(length(varargin) == 1) - if(~isstruct(varargin{1})) - error('Single input should be a structure with new problem data'); - else - newData = varargin{1}; - end - else % param / value style assumed - newData = struct(varargin{:}); - end - - %check for unknown fields - newFields = fieldnames(newData); - badFieldsIdx = find(~ismember(newFields,allowedFields)); - if(~isempty(badFieldsIdx)) - error('Unrecognized input field ''%s'' detected',newFields{badFieldsIdx(1)}); - end - - %get all of the terms. Nonexistent fields will be passed - %as empty mxArrays - try q = double(full(newData.q(:))); catch q = []; end - try l = double(full(newData.l(:))); catch l = []; end - try u = double(full(newData.u(:))); catch u = []; end - try Px = double(full(newData.Px(:))); catch Px = []; end - try Px_idx = double(full(newData.Px_idx(:))); catch Px_idx = []; end - try Ax = double(full(newData.Ax(:))); catch Ax = []; end - try Ax_idx = double(full(newData.Ax_idx(:))); catch Ax_idx = []; end - - [n,m] = get_dimensions(this); - - assert(isempty(q) || length(q) == n, 'input ''q'' is the wrong size'); - assert(isempty(l) || length(l) == m, 'input ''u'' is the wrong size'); - assert(isempty(u) || length(u) == m, 'input ''l'' is the wrong size'); - assert(isempty(Px) || isempty(Px_idx) || length(Px) == length(Px_idx), ... - 'inputs ''Px'' and ''Px_idx'' must be the same size'); - assert(isempty(Ax) || isempty(Ax_idx) || length(Ax) == length(Ax_idx), ... - 'inputs ''Ax'' and ''Ax_idx'' must be the same size'); - - % Adjust index of Px_idx and Ax_idx to match 0-based indexing - % in C - if (~isempty(Px_idx)) - Px_idx = Px_idx - 1; - end - if (~isempty(Ax_idx)) - Ax_idx = Ax_idx - 1; - end - - % Convert infinity values to OSQP_INFTY - if (~isempty(u)) - u = min(u, osqp.constant('OSQP_INFTY')); - end - if (~isempty(l)) - l = max(l, -osqp.constant('OSQP_INFTY')); - end - - %write the new problem data. C-mex does not protect - %against unknown fields, but will handle empty values - osqp_mex('update', this.objectHandle, ... - q, l, u, Px, Px_idx, length(Px), Ax, Ax_idx, length(Ax)); - - end - - %% - function varargout = setup(this, varargin) - % SETUP configure solver with problem data - % - % setup(P,q,A,l,u,options) - - nargin = length(varargin); - - %dimension checks on user data. Mex function does not - %perform any checks on inputs, so check everything here - assert(nargin >= 5, 'incorrect number of inputs'); - [P,q,A,l,u] = deal(varargin{1:5}); - - % - % Get problem dimensions - % - - % Get number of variables n - if (isempty(P)) - if (~isempty(q)) - n = length(q); - else - if (~isempty(A)) - n = size(A, 2); - else - error('The problem does not have any variables'); - end - end - else - n = size(P, 1); - end - - % Get number of constraints m - if (isempty(A)) - m = 0; - else - m = size(A, 1); - assert(size(A, 2) == n, 'Incorrect dimension of A'); - end - - % - % Create sparse matrices and full vectors if they are empty - % - - if (isempty(P)) - P = sparse(n, n); - else - P = sparse(P); - end - if (~istriu(P)) - P = triu(P); - end - if (isempty(q)) - q = zeros(n, 1); - else - q = full(q(:)); - end - - % Create proper constraints if they are not passed - if (isempty(A) && (~isempty(l) || ~isempty(u))) || ... - (~isempty(A) && (isempty(l) && isempty(u))) - error('A must be supplied together with at least one bound l or u'); - end - - if (~isempty(A) && isempty(l)) - l = -Inf(m, 1); - end - - if (~isempty(A) && isempty(u)) - u = Inf(m, 1); - end - - if (isempty(A)) - A = sparse(m, n); - l = -Inf(m, 1); - u = Inf(m, 1); - else - l = full(l(:)); - u = full(u(:)); - A = sparse(A); - end - - - % - % Check vector dimensions (not checked from the C solver) - % - - assert(length(q) == n, 'Incorrect dimension of q'); - assert(length(l) == m, 'Incorrect dimension of l'); - assert(length(u) == m, 'Incorrect dimension of u'); - - % - % Convert infinity values to OSQP_INFINITY - % - u = min(u, osqp.constant('OSQP_INFTY')); - l = max(l, -osqp.constant('OSQP_INFTY')); - - - %make a settings structure from the remainder of the arguments. - %'true' means that this is a settings initialization, so all - %parameter/values are allowed. No extra inputs will result - %in default settings being passed back - theSettings = validateSettings(this,true,varargin{6:end}); - - [varargout{1:nargout}] = osqp_mex('setup', this.objectHandle, n,m,P,q,A,l,u,theSettings); - - end - - - %% - - function warm_start(this, varargin) - % WARM_START warm start primal and/or dual variables - % - % warm_start('x', x, 'y', y) - % - % or warm_start('x', x) - % or warm_start('y', y) - - - % Get problem dimensions - [n, m] = get_dimensions(this); - - % Get data - allowedFields = {'x','y'}; - - if(isempty(varargin)) - return; - elseif(length(varargin) == 1) - if(~isstruct(varargin{1})) - error('Single input should be a structure with new problem data'); - else - newData = varargin{1}; - end - else % param / value style assumed - newData = struct(varargin{:}); - end - - %check for unknown fields - newFields = fieldnames(newData); - badFieldsIdx = find(~ismember(newFields,allowedFields)); - if(~isempty(badFieldsIdx)) - error('Unrecognized input field ''%s'' detected',newFields{badFieldsIdx(1)}); - end - - %get all of the terms. Nonexistent fields will be passed - %as empty mxArrays - try x = double(full(newData.x(:))); catch x = []; end - try y = double(full(newData.y(:))); catch y = []; end - - % Check dimensions - assert(isempty(x) || length(x) == n, 'input ''x'' is the wrong size'); - assert(isempty(y) || length(y) == m, 'input ''y'' is the wrong size'); - - - % Decide which function to call - if (~isempty(x) && isempty(y)) - osqp_mex('warm_start_x', this.objectHandle, x); - return; - end - - if (isempty(x) && ~isempty(y)) - osqp_mex('warm_start_y', this.objectHandle, y); - end - - if (~isempty(x) && ~isempty(y)) - osqp_mex('warm_start', this.objectHandle, x, y); - end - - if (isempty(x) && isempty(y)) - error('Unrecognized fields'); - end - - end - - %% - function varargout = solve(this, varargin) - % SOLVE solve the QP - - nargoutchk(0,1); %either return nothing (but still solve), or a single output structure - [out.x, out.y, out.prim_inf_cert, out.dual_inf_cert, out.info] = osqp_mex('solve', this.objectHandle); - if(nargout) - varargout{1} = out; - end - return; - end - - %% - function codegen(this, target_dir, varargin) - % CODEGEN generate C code for the parametric problem - % - % codegen(target_dir,options) - - % Parse input arguments - p = inputParser; - defaultProject = ''; - expectedProject = {'', 'Makefile', 'MinGW Makefiles', 'Unix Makefiles', 'CodeBlocks', 'Xcode'}; - defaultParams = 'vectors'; - expectedParams = {'vectors', 'matrices'}; - defaultMexname = 'emosqp'; - defaultFloat = false; - defaultLong = true; - defaultFW = false; - - addRequired(p, 'target_dir', @isstr); - addParameter(p, 'project_type', defaultProject, ... - @(x) ischar(validatestring(x, expectedProject))); - addParameter(p, 'parameters', defaultParams, ... - @(x) ischar(validatestring(x, expectedParams))); - addParameter(p, 'mexname', defaultMexname, @isstr); - addParameter(p, 'FLOAT', defaultFloat, @islogical); - addParameter(p, 'LONG', defaultLong, @islogical); - addParameter(p, 'force_rewrite', defaultFW, @islogical); - - parse(p, target_dir, varargin{:}); - - % Set internal variables - if strcmp(p.Results.parameters, 'vectors') - embedded = 1; - else - embedded = 2; - end - if p.Results.FLOAT - float_flag = 'ON'; - else - float_flag = 'OFF'; - end - if p.Results.LONG - long_flag = 'ON'; - else - long_flag = 'OFF'; - end - if strcmp(p.Results.project_type, 'Makefile') - if (ispc) - project_type = 'MinGW Makefiles'; % Windows - elseif (ismac || isunix) - project_type = 'Unix Makefiles'; % Unix - end - else - project_type = p.Results.project_type; - end - - % Check whether the specified directory already exists - if exist(target_dir, 'dir') - if p.Results.force_rewrite - rmdir(target_dir, 's'); - else - while(1) - prompt = sprintf('Directory "%s" already exists. Do you want to replace it? y/n [y]: ', target_dir); - str = input(prompt, 's'); - - if any(strcmpi(str, {'','y'})) - rmdir(target_dir, 's'); - break; - elseif strcmpi(str, 'n') - return; - end - end - end - end - - % Import OSQP path - [osqp_path,~,~] = fileparts(which('osqp.m')); - - % Add codegen directory to path - addpath(fullfile(osqp_path, 'codegen')); - - % Path of osqp module - cg_dir = fullfile(osqp_path, 'codegen'); - files_to_generate_path = fullfile(cg_dir, 'files_to_generate'); - - % Get workspace structure - work = osqp_mex('get_workspace', this.objectHandle); - - % Make target directory - fprintf('Creating target directories...\t\t\t\t\t'); - target_configure_dir = fullfile(target_dir, 'configure'); - target_include_dir = fullfile(target_dir, 'include'); - target_src_dir = fullfile(target_dir, 'src'); - - if ~exist(target_dir, 'dir') - mkdir(target_dir); - end - if ~exist(target_configure_dir, 'dir') - mkdir(target_configure_dir); - end - if ~exist(target_include_dir, 'dir') - mkdir(target_include_dir); - end - if ~exist(target_src_dir, 'dir') - mkdir(fullfile(target_src_dir, 'osqp')); - end - fprintf('[done]\n'); - - % Copy source files to target directory - fprintf('Copying OSQP source files...\t\t\t\t\t'); - cdir = fullfile(cg_dir, 'sources', 'src'); - cfiles = dir(fullfile(cdir, '*.c')); - for i = 1 : length(cfiles) - if embedded == 1 - % Do not copy kkt.c if embedded is 1 - if ~strcmp(cfiles(i).name, 'kkt.c') - copyfile(fullfile(cdir, cfiles(i).name), ... - fullfile(target_src_dir, 'osqp', cfiles(i).name)); - end - else - copyfile(fullfile(cdir, cfiles(i).name), ... - fullfile(target_src_dir, 'osqp', cfiles(i).name)); - end - end - configure_dir = fullfile(cg_dir, 'sources', 'configure'); - configure_files = dir(fullfile(configure_dir, '*.h.in')); - for i = 1 : length(configure_files) - copyfile(fullfile(configure_dir, configure_files(i).name), ... - fullfile(target_configure_dir, configure_files(i).name)); - end - hdir = fullfile(cg_dir, 'sources', 'include'); - hfiles = dir(fullfile(hdir, '*.h')); - for i = 1 : length(hfiles) - if embedded == 1 - % Do not copy kkt.h if embedded is 1 - if ~strcmp(hfiles(i).name, 'kkt.h') - copyfile(fullfile(hdir, hfiles(i).name), ... - fullfile(target_include_dir, hfiles(i).name)); - end - else - copyfile(fullfile(hdir, hfiles(i).name), ... - fullfile(target_include_dir, hfiles(i).name)); - end - end - - % Copy cmake files - copyfile(fullfile(cdir, 'CMakeLists.txt'), ... - fullfile(target_src_dir, 'osqp', 'CMakeLists.txt')); - copyfile(fullfile(hdir, 'CMakeLists.txt'), ... - fullfile(target_include_dir, 'CMakeLists.txt')); - fprintf('[done]\n'); - - % Copy example.c - copyfile(fullfile(files_to_generate_path, 'example.c'), target_src_dir); - - % Render CMakeLists.txt - fidi = fopen(fullfile(files_to_generate_path, 'CMakeLists.txt'),'r'); - fido = fopen(fullfile(target_dir, 'CMakeLists.txt'),'w'); - while ~feof(fidi) - l = fgetl(fidi); % read line - % Replace EMBEDDED_FLAG in CMakeLists.txt by a numerical value - newl = strrep(l, 'EMBEDDED_FLAG', num2str(embedded)); - fprintf(fido, '%s\n', newl); - end - fclose(fidi); - fclose(fido); - - % Render workspace.h and workspace.c - work_hfile = fullfile(target_include_dir, 'workspace.h'); - work_cfile = fullfile(target_src_dir, 'osqp', 'workspace.c'); - fprintf('Generating workspace.h/.c...\t\t\t\t\t\t'); - render_workspace(work, work_hfile, work_cfile, embedded); - fprintf('[done]\n'); - - % Create project - if ~isempty(project_type) - - % Extend path for CMake mac (via Homebrew) - PATH = getenv('PATH'); - if ((ismac) && (isempty(strfind(PATH, '/usr/local/bin')))) - setenv('PATH', [PATH ':/usr/local/bin']); - end - - fprintf('Creating project...\t\t\t\t\t\t\t\t'); - orig_dir = pwd; - cd(target_dir); - mkdir('build') - cd('build'); - cmd = sprintf('cmake -G "%s" ..', project_type); - [status, output] = system(cmd); - if(status) - fprintf('\n'); - fprintf(output); - error('Error configuring CMake environment'); - else - fprintf('[done]\n'); - end - cd(orig_dir); - end - - % Make mex interface to the generated code - mex_cfile = fullfile(files_to_generate_path, 'emosqp_mex.c'); - make_emosqp(target_dir, mex_cfile, embedded, float_flag, long_flag); - - % Rename the mex file - old_mexfile = ['emosqp_mex.', mexext]; - new_mexfile = [p.Results.mexname, '.', mexext]; - movefile(old_mexfile, new_mexfile); - - end - - end -end - - - -function currentSettings = validateSettings(this,isInitialization,varargin) - -%don't allow these fields to be changed -unmodifiableFields = {'scaling', 'linsys_solver'}; - -%get the current settings -if(isInitialization) - currentSettings = osqp_mex('default_settings', this.objectHandle); -else - currentSettings = osqp_mex('current_settings', this.objectHandle); -end - -%no settings passed -> return defaults -if(isempty(varargin)) - return; -end - -%check for structure style input -if(isstruct(varargin{1})) - newSettings = varargin{1}; - assert(length(varargin) == 1, 'too many input arguments'); -else - newSettings = struct(varargin{:}); -end - -%get the osqp settings fields -currentFields = fieldnames(currentSettings); - -%get the requested fields in the update -newFields = fieldnames(newSettings); - -%check for unknown parameters -badFieldsIdx = find(~ismember(newFields,currentFields)); -if(~isempty(badFieldsIdx)) - error('Unrecognized solver setting ''%s'' detected',newFields{badFieldsIdx(1)}); -end - -%convert linsys_solver string to integer -if ismember('linsys_solver',newFields) - if ~ischar(newSettings.linsys_solver) - error('Setting linsys_solver is required to be a string.'); - end - % Convert linsys_solver to number - newSettings.linsys_solver = string_to_linsys_solver(newSettings.linsys_solver); -end - - -%check for disallowed fields if this in not an initialization call -if(~isInitialization) - badFieldsIdx = find(ismember(newFields,unmodifiableFields)); - for i = badFieldsIdx(:)' - if(~isequal(newSettings.(newFields{i}),currentSettings.(newFields{i}))) - error('Solver setting ''%s'' can only be changed at solver initialization.', newFields{i}); - end - end -end - - -%check that everything is a nonnegative scalar (this check is already -%performed in C) -% for i = 1:length(newFields) -% val = double(newSettings.(newFields{i})); -% assert(isscalar(val) & isnumeric(val) & val >= 0, ... -% 'Solver setting ''%s'' not specified as nonnegative scalar', newFields{i}); -% end - -%everything checks out - merge the newSettings into the current ones -for i = 1:length(newFields) - currentSettings.(newFields{i}) = double(newSettings.(newFields{i})); -end - - -end - -function [linsys_solver_string] = linsys_solver_to_string(linsys_solver) -% Convert linear systme solver integer to stringh -switch linsys_solver - case osqp.constant('OSQP_UNKNOWN_SOLVER') - linsys_solver_string = 'unknown solver'; - case osqp.constant('OSQP_DIRECT_SOLVER') - linsys_solver_string = 'direct solver'; - case osqp.constant('OSQP_INDIRECT_SOLVER') - linsys_solver_string = 'indirect solver'; - otherwise - error('Unrecognized linear system solver.'); -end -end - - - -function [linsys_solver] = string_to_linsys_solver(linsys_solver_string) -linsys_solver_string = lower(linsys_solver_string); -switch linsys_solver_string - case 'unknown solver' - linsys_solver = osqp.constant('OSQP_UNKNOWN_SOLVER'); - case 'direct solver' - linsys_solver = osqp.constant('OSQP_DIRECT_SOLVER'); - case 'indirect solver' - linsys_solver = osqp.constant('OSQP_INDIRECT_SOLVER'); - % Default solver: QDLDL - case '' - linsys_solver = osqp.constant('OSQP_DIRECT_SOLVER'); - otherwise - warning('Linear system solver not recognized. Using default solver OSQP_DIRECT_SOLVER.') - linsys_solver = osqp.constant('OSQP_DIRECT_SOLVER'); -end -end - - diff --git a/run_osqp_tests.m b/run_osqp_tests.m index 29d9d18..2bfca1c 100644 --- a/run_osqp_tests.m +++ b/run_osqp_tests.m @@ -1,7 +1,7 @@ import matlab.unittest.TestSuite; -[osqp_path,~,~] = fileparts(which('osqp.m')); -unittest_dir = fullfile(osqp_path, 'unittests'); +[osqp_classpath,~,~] = fileparts( mfilename( 'fullpath' ) ); +unittest_dir = fullfile(osqp_classpath, 'unittests'); suiteFolder = TestSuite.fromFolder(unittest_dir); % Solve individual test file diff --git a/unittests/basic_tests.m b/unittests/basic_tests.m index dedb073..c938a7a 100644 --- a/unittests/basic_tests.m +++ b/unittests/basic_tests.m @@ -124,6 +124,9 @@ function test_update_max_iter(testCase) opts.max_iter = 30; testCase.solver.update_settings(opts); + set = testCase.solver.current_settings(); + testCase.verifyEqual(set.max_iter, 30) + % Solve again results = testCase.solver.solve(); @@ -139,6 +142,9 @@ function test_update_early_termination(testCase) opts.check_termination = 0; testCase.solver.update_settings(opts); + set = testCase.solver.current_settings(); + testCase.verifyEqual(set.check_termination, 0) + % Solve again results = testCase.solver.solve(); @@ -184,6 +190,11 @@ function test_update_time_limit(testCase) 'max_iter', 2e9,... 'check_termination', 0); + set = testCase.solver.current_settings(); + testCase.verifyEqual(set.check_termination, 0) + testCase.verifyEqual(set.time_limit, 1e-6) + testCase.verifyEqual(set.max_iter, 2e9) + results = testCase.solver.solve(); testCase.verifyEqual(results.info.status_val, ... testCase.solver.constant('OSQP_TIME_LIMIT_REACHED')) diff --git a/unittests/warm_start_tests.m b/unittests/warm_start_tests.m index f6ab098..db5ff7b 100644 --- a/unittests/warm_start_tests.m +++ b/unittests/warm_start_tests.m @@ -33,8 +33,7 @@ function setup_problem(testCase) end methods (Test) - function test_warm_start(testCase) - + function test_warm_start_zeros(testCase) % big example rng(4) testCase.n = 100; @@ -63,12 +62,99 @@ function test_warm_start(testCase) testCase.solver.warm_start('x', zeros(testCase.n, 1), 'y', zeros(testCase.m, 1)); results = testCase.solver.solve(); testCase.verifyEqual(results.info.iter, tot_iter, 'AbsTol', testCase.tol) + end + + function test_warm_start_optimal(testCase) + % big example + rng(4) + testCase.n = 100; + testCase.m = 200; + Pt = sprandn(testCase.n, testCase.n, 0.6); + testCase.P = Pt' * Pt; + testCase.q = randn(testCase.n, 1); + testCase.A = sprandn(testCase.m, testCase.n, 0.8); + testCase.u = 2*rand(testCase.m, 1); + testCase.l = -2*rand(testCase.m, 1); + + % Setup solver + testCase.solver = osqp; + testCase.solver.setup(testCase.P, testCase.q, ... + testCase.A, testCase.l, testCase.u, testCase.options); + + % Solve with OSQP + results = testCase.solver.solve(); + + % Store optimal values + x_opt = results.x; + y_opt = results.y; + tot_iter = results.info.iter; % Warm start with optimal values and check that number of iterations is < 10 testCase.solver.warm_start('x', x_opt, 'y', y_opt); results = testCase.solver.solve(); testCase.verifyThat(results.info.iter, matlab.unittest.constraints.IsLessThan(10)); + end + + function test_warm_start_duals(testCase) + % big example + rng(4) + testCase.n = 100; + testCase.m = 200; + Pt = sprandn(testCase.n, testCase.n, 0.6); + testCase.P = Pt' * Pt; + testCase.q = randn(testCase.n, 1); + testCase.A = sprandn(testCase.m, testCase.n, 0.8); + testCase.u = 2*rand(testCase.m, 1); + testCase.l = -2*rand(testCase.m, 1); + % Setup solver + testCase.solver = osqp; + testCase.solver.setup(testCase.P, testCase.q, ... + testCase.A, testCase.l, testCase.u, testCase.options); + + % Solve with OSQP + results = testCase.solver.solve(); + + % Store optimal values + x_opt = results.x; + y_opt = results.y; + tot_iter = results.info.iter; + + % Warm start with zeros for dual variables + testCase.solver.warm_start('y', zeros(testCase.m, 1)); + results = testCase.solver.solve(); + testCase.verifyEqual(results.y, y_opt, 'AbsTol', testCase.tol) + end + + function test_warm_start_primal(testCase) + % big example + rng(4) + testCase.n = 100; + testCase.m = 200; + Pt = sprandn(testCase.n, testCase.n, 0.6); + testCase.P = Pt' * Pt; + testCase.q = randn(testCase.n, 1); + testCase.A = sprandn(testCase.m, testCase.n, 0.8); + testCase.u = 2*rand(testCase.m, 1); + testCase.l = -2*rand(testCase.m, 1); + + % Setup solver + testCase.solver = osqp; + testCase.solver.setup(testCase.P, testCase.q, ... + testCase.A, testCase.l, testCase.u, testCase.options); + + % Solve with OSQP + results = testCase.solver.solve(); + + % Store optimal values + x_opt = results.x; + y_opt = results.y; + tot_iter = results.info.iter; + + % Warm start with zeros for primal variables + testCase.solver.warm_start('x', zeros(testCase.n, 1)); + results = testCase.solver.solve(); + testCase.verifyEqual(results.x, x_opt, 'AbsTol', testCase.tol) end end