Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added matlab binding for vector valued constraints #313

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 163 additions & 43 deletions src/octave/nlopt_optimize-mex.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@

/* Matlab MEX interface to NLopt, and in particular to nlopt_optimize */

#include <math.h>
#include <mex.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <mex.h>

#include "nlopt.h"

#define CHECK0(cond, msg) if (!(cond)) mexErrMsgTxt(msg);
#define CHECK0(cond, msg) \
if (!(cond)) \
mexErrMsgTxt(msg);

static double struct_val_default(const mxArray* s, const char* name, double dflt)
{
Expand Down Expand Up @@ -71,7 +73,8 @@ static mxArray *struct_funcval(const mxArray *s, const char *name)
static double* fill(double* arr, unsigned n, double val)
{
unsigned i;
for (i = 0; i < n; ++i) arr[i] = val;
for (i = 0; i < n; ++i)
arr[i] = val;
return arr;
}

Expand All @@ -97,8 +100,7 @@ static double user_function(unsigned n, const double *x,
d->plhs[0] = d->plhs[1] = NULL;
memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));

CHECK0(0 == mexCallMATLAB(gradient ? 2 : 1, d->plhs,
d->nrhs, d->prhs, d->f),
CHECK0(0 == mexCallMATLAB(gradient ? 2 : 1, d->plhs, d->nrhs, d->prhs, d->f),
"error calling user function");

CHECK0(mxIsNumeric(d->plhs[0]) && !mxIsComplex(d->plhs[0])
Expand All @@ -115,11 +117,62 @@ static double user_function(unsigned n, const double *x,
mxDestroyArray(d->plhs[1]);
}
d->neval++;
if (d->verbose) mexPrintf("nlopt_optimize eval #%d: %g\n", d->neval, f);
if (mxIsNaN(f)) nlopt_force_stop(d->opt);
if (d->verbose)
mexPrintf("nlopt_optimize eval #%d: %g\n", d->neval, f);
if (mxIsNaN(f))
nlopt_force_stop(d->opt);
return f;
}

static void user_mfunction(unsigned m, double* result, unsigned n, const double* x,
double* gradient, /* NULL if not needed */
void* d_)
{
user_function_data* d = (user_function_data*)d_;

d->plhs[0] = d->plhs[1] = NULL;
memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));

CHECK0(0 == mexCallMATLAB(gradient ? 2 : 1, d->plhs, d->nrhs, d->prhs, d->f),
"error calling user mfunction");

CHECK0(mxIsNumeric(d->plhs[0]) && !mxIsComplex(d->plhs[0])
&& mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == m,
"user mfunction must return an array of size equal to mfc_count parameter");
memcpy(result, mxGetPr(d->plhs[0]), n * sizeof(double));
mxDestroyArray(d->plhs[0]);

if (gradient) {
CHECK0(mxIsDouble(d->plhs[1]) && !mxIsComplex(d->plhs[1])
&& (mxGetM(d->plhs[1]) == m && mxGetN(d->plhs[1]) == n),
"gradient vector from user mfunction is the wrong size (mxn)");
double* ptr = mxGetPr(d->plhs[1]);
for (size_t j = 0; j < n; j++) {
for (size_t i = 0; i < m; i++) {
gradient[(i * n) + j] = ptr[i + (j * m)];
}
}
mxDestroyArray(d->plhs[1]);
}

d->neval++;

if (d->verbose) {
mexPrintf("nlopt_optimize eval #%d: ", d->neval);
for (size_t i = 0; i < n - 1; i++) {
mexPrintf("%g ", result[i]);
}
mexPrintf("%g\n", result[n - 1]);
}

for (size_t i = 0; i < n; i++) {
if (mxIsNaN(result[i])) {
nlopt_force_stop(d->opt);
break;
}
}
}

static void user_pre(unsigned n, const double* x, const double* v,
double* vpre, void* d_)
{
Expand All @@ -128,8 +181,7 @@ static void user_pre(unsigned n, const double *x, const double *v,
memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
memcpy(mxGetPr(d->prhs[d->xrhs + 1]), v, n * sizeof(double));

CHECK0(0 == mexCallMATLAB(1, d->plhs,
d->nrhs, d->prhs, d->f),
CHECK0(0 == mexCallMATLAB(1, d->plhs, d->nrhs, d->prhs, d->f),
"error calling user function");

CHECK0(mxIsDouble(d->plhs[0]) && !mxIsComplex(d->plhs[0])
Expand All @@ -139,10 +191,18 @@ static void user_pre(unsigned n, const double *x, const double *v,
memcpy(vpre, mxGetPr(d->plhs[0]), n * sizeof(double));
mxDestroyArray(d->plhs[0]);
d->neval++;
if (d->verbose) mexPrintf("nlopt_optimize precond eval #%d\n", d->neval);
if (d->verbose)
mexPrintf("nlopt_optimize precond eval #%d\n", d->neval);
}

#define CHECK1(cond, msg) if (!(cond)) { mxFree(tmp); nlopt_destroy(opt); nlopt_destroy(local_opt); mexWarnMsgTxt(msg); return NULL; };
#define CHECK1(cond, msg) \
if (!(cond)) { \
mxFree(tmp); \
nlopt_destroy(opt); \
nlopt_destroy(local_opt); \
mexWarnMsgTxt(msg); \
return NULL; \
};

nlopt_opt make_opt(const mxArray* opts, unsigned n)
{
Expand All @@ -160,21 +220,15 @@ nlopt_opt make_opt(const mxArray *opts, unsigned n)
opt = nlopt_create(algorithm, n);
CHECK1(opt, "nlopt: out of memory");

nlopt_set_lower_bounds(opt, struct_arrval(opts, "lower_bounds", n,
fill(tmp, n, -HUGE_VAL)));
nlopt_set_upper_bounds(opt, struct_arrval(opts, "upper_bounds", n,
fill(tmp, n, +HUGE_VAL)));
nlopt_set_lower_bounds(opt, struct_arrval(opts, "lower_bounds", n, fill(tmp, n, -HUGE_VAL)));
nlopt_set_upper_bounds(opt, struct_arrval(opts, "upper_bounds", n, fill(tmp, n, +HUGE_VAL)));

nlopt_set_stopval(opt, struct_val_default(opts, "stopval", -HUGE_VAL));
nlopt_set_ftol_rel(opt, struct_val_default(opts, "ftol_rel", 0.0));
nlopt_set_ftol_abs(opt, struct_val_default(opts, "ftol_abs", 0.0));
nlopt_set_xtol_rel(opt, struct_val_default(opts, "xtol_rel", 0.0));
nlopt_set_xtol_abs(opt, struct_arrval(opts, "xtol_abs", n,
fill(tmp, n, 0.0)));
nlopt_set_x_weights(opt, struct_arrval(opts, "x_weights", n,
fill(tmp, n, 1.0)));
nlopt_set_maxeval(opt, struct_val_default(opts, "maxeval", 0.0) < 0 ?
0 : struct_val_default(opts, "maxeval", 0.0));
nlopt_set_xtol_abs(opt, struct_arrval(opts, "xtol_abs", n, fill(tmp, n, 0.0)));
nlopt_set_maxeval(opt, struct_val_default(opts, "maxeval", 0.0) < 0 ? 0 : struct_val_default(opts, "maxeval", 0.0));
nlopt_set_maxtime(opt, struct_val_default(opts, "maxtime", 0.0));

nlopt_set_population(opt, struct_val_default(opts, "population", 0));
Expand All @@ -191,14 +245,21 @@ nlopt_opt make_opt(const mxArray *opts, unsigned n)
CHECK1(local_opt = make_opt(local_opts, n),
"error initializing local optimizer");
nlopt_set_local_optimizer(opt, local_opt);
nlopt_destroy(local_opt); local_opt = NULL;
nlopt_destroy(local_opt);
local_opt = NULL;
}

mxFree(tmp);
return opt;
}

#define CHECK(cond, msg) if (!(cond)) { mxFree(dh); mxFree(dfc); nlopt_destroy(opt); mexErrMsgTxt(msg); }
#define CHECK(cond, msg) \
if (!(cond)) { \
mxFree(dh); \
mxFree(dfc); \
nlopt_destroy(opt); \
mexErrMsgTxt(msg); \
}

void mexFunction(int nlhs, mxArray* plhs[],
int nrhs, const mxArray* prhs[])
Expand All @@ -207,7 +268,7 @@ void mexFunction(int nlhs, mxArray *plhs[],
double *x, *x0, opt_f;
nlopt_result ret;
mxArray *x_mx, *mx;
user_function_data d, dpre, *dfc = NULL, *dh = NULL;
user_function_data d, dpre, dmfc, dmh, *dfc = NULL, *dh = NULL;
nlopt_opt opt = NULL;

CHECK(nrhs == 2 && nlhs <= 3, "wrong number of arguments");
Expand All @@ -230,15 +291,15 @@ void mexFunction(int nlhs, mxArray *plhs[],

/* function f = prhs[1] */
mx = struct_funcval(prhs[0], "min_objective");
if (!mx) mx = struct_funcval(prhs[0], "max_objective");
if (!mx)
mx = struct_funcval(prhs[0], "max_objective");
CHECK(mx, "either opt.min_objective or opt.max_objective must exist");
if (mxIsChar(mx)) {
CHECK(mxGetString(mx, d.f, FLEN) == 0,
"error reading function name string (too long?)");
d.nrhs = 1;
d.xrhs = 0;
}
else {
} else {
d.prhs[0] = mx;
strcpy(d.f, "feval");
d.nrhs = 2;
Expand All @@ -254,8 +315,7 @@ void mexFunction(int nlhs, mxArray *plhs[],
"error reading function name string (too long?)");
dpre.nrhs = 2;
dpre.xrhs = 0;
}
else {
} else {
dpre.prhs[0] = mx;
strcpy(dpre.f, "feval");
dpre.nrhs = 3;
Expand All @@ -272,21 +332,82 @@ void mexFunction(int nlhs, mxArray *plhs[],
nlopt_set_precond_min_objective(opt, user_function, user_pre, &d);
else
nlopt_set_precond_max_objective(opt, user_function, user_pre, &d);
}
else {
} else {
dpre.nrhs = 0;
if (struct_funcval(prhs[0], "min_objective"))
nlopt_set_min_objective(opt, user_function, &d);
else
nlopt_set_max_objective(opt, user_function, &d);
}

if ((mx = mxGetField(prhs[0], 0, "mfc"))) {
int m;
double* mfc_tol;

m = (int)struct_val_default(prhs[0], "mfc_count", -1);
mfc_tol = struct_arrval(prhs[0], "mfc_tol", m, NULL);

CHECK(mxIsChar(mx) || mxIsFunctionHandle(mx),
"mfc must contain function handles or function names");
if (mxIsChar(mx)) {
CHECK(mxGetString(mx, dmfc.f, FLEN) == 0,
"error reading function name string (too long?)");
dmfc.nrhs = 1;
dmfc.xrhs = 0;
} else {
dmfc.prhs[0] = mx;
strcpy(dmfc.f, "feval");
dmfc.nrhs = 2;
dmfc.xrhs = 1;
}
dmfc.verbose = d.verbose > 1;
dmfc.opt = opt;
dmfc.neval = 0;
dmfc.prhs[dmfc.xrhs] = d.prhs[d.xrhs];

CHECK(nlopt_add_inequality_mconstraint(opt, m,
user_mfunction, &dmfc, mfc_tol ? mfc_tol : 0)
> 0,
"nlopt error adding multiple inequality constraints");
}

if ((mx = mxGetField(prhs[0], 0, "mh"))) {
int m;
double* mh_tol;

m = (int)struct_val_default(prhs[0], "mh_count", -1);
mh_tol = struct_arrval(prhs[0], "mh_tol", m, NULL);

CHECK(mxIsChar(mx) || mxIsFunctionHandle(mx),
"mh must contain function handles or function names");
if (mxIsChar(mx)) {
CHECK(mxGetString(mx, dmh.f, FLEN) == 0,
"error reading function name string (too long?)");
dmh.nrhs = 1;
dmh.xrhs = 0;
} else {
dmh.prhs[0] = mx;
strcpy(dmh.f, "feval");
dmh.nrhs = 2;
dmh.xrhs = 1;
}
dmh.verbose = d.verbose > 1;
dmh.opt = opt;
dmh.neval = 0;
dmh.prhs[dmh.xrhs] = d.prhs[d.xrhs];

CHECK(nlopt_add_equality_mconstraint(opt, m,
user_mfunction, &dmh, mh_tol ? mh_tol : 0)
> 0,
"nlopt error adding multiple equality constraints");
}

if ((mx = mxGetField(prhs[0], 0, "fc"))) {
int j, m;
double* fc_tol;

CHECK(mxIsCell(mx), "fc must be a Cell array");
m = mxGetM(mx) * mxGetN(mx);;
m = mxGetM(mx) * mxGetN(mx);
dfc = (user_function_data*)mxCalloc(m, sizeof(user_function_data));
fc_tol = struct_arrval(prhs[0], "fc_tol", m, NULL);

Expand All @@ -299,8 +420,7 @@ void mexFunction(int nlhs, mxArray *plhs[],
"error reading function name string (too long?)");
dfc[j].nrhs = 1;
dfc[j].xrhs = 0;
}
else {
} else {
dfc[j].prhs[0] = fc;
strcpy(dfc[j].f, "feval");
dfc[j].nrhs = 2;
Expand All @@ -313,17 +433,17 @@ void mexFunction(int nlhs, mxArray *plhs[],
CHECK(nlopt_add_inequality_constraint(opt, user_function,
dfc + j,
fc_tol ? fc_tol[j] : 0)
> 0, "nlopt error adding inequality constraint");
> 0,
"nlopt error adding inequality constraint");
}
}


if ((mx = mxGetField(prhs[0], 0, "h"))) {
int j, m;
double* h_tol;

CHECK(mxIsCell(mx), "h must be a Cell array");
m = mxGetM(mx) * mxGetN(mx);;
m = mxGetM(mx) * mxGetN(mx);
dh = (user_function_data*)mxCalloc(m, sizeof(user_function_data));
h_tol = struct_arrval(prhs[0], "h_tol", m, NULL);

Expand All @@ -336,8 +456,7 @@ void mexFunction(int nlhs, mxArray *plhs[],
"error reading function name string (too long?)");
dh[j].nrhs = 1;
dh[j].xrhs = 0;
}
else {
} else {
dh[j].prhs[0] = h;
strcpy(dh[j].f, "feval");
dh[j].nrhs = 2;
Expand All @@ -350,11 +469,11 @@ void mexFunction(int nlhs, mxArray *plhs[],
CHECK(nlopt_add_equality_constraint(opt, user_function,
dh + j,
h_tol ? h_tol[j] : 0)
> 0, "nlopt error adding equality constraint");
> 0,
"nlopt error adding equality constraint");
}
}


x_mx = mxCreateDoubleMatrix(mxGetM(prhs[1]), mxGetN(prhs[1]), mxREAL);
x = mxGetPr(x_mx);
memcpy(x, x0, sizeof(double) * n);
Expand All @@ -364,7 +483,8 @@ void mexFunction(int nlhs, mxArray *plhs[],
mxFree(dh);
mxFree(dfc);
mxDestroyArray(d.prhs[d.xrhs]);
if (dpre.nrhs > 0) mxDestroyArray(dpre.prhs[d.xrhs+1]);
if (dpre.nrhs > 0)
mxDestroyArray(dpre.prhs[d.xrhs + 1]);
nlopt_destroy(opt);

plhs[0] = x_mx;
Expand Down