Tolerate QN linesearch failures when it's harmless #3791

Merged
146 changes: 106 additions & 40 deletions cpp/src/glm/qn/qn_solvers.cuh
@@ -68,6 +68,74 @@ inline size_t owlqn_workspace_size(const LBFGSParam<T>& param, const int n)
return lbfgs_workspace_size(param, n) + vec_size;
}

template <typename T>
inline bool update_and_check(const char* solver,
const LBFGSParam<T>& param,
int iter,
LINE_SEARCH_RETCODE lsret,
T& fx,
T& fxp,
ML::SimpleVec<T>& x,
ML::SimpleVec<T>& xp,
ML::SimpleVec<T>& grad,
ML::SimpleVec<T>& gradp,
std::vector<T>& fx_hist,
T* dev_scalar,
OPT_RETCODE& outcode,
cudaStream_t stream)
{
bool stop = false;
bool converged = false;
bool isLsValid = !isnan(fx) && !isinf(fx);
// Linesearch may fail to converge, but still come closer to the solution;
// if that is not the case, let `check_convergence` ("insufficient change")
// below terminate the loop.
bool isLsNonCritical = lsret == LS_INVALID_STEP_MIN || lsret == LS_MAX_ITERS_REACHED;
// If the error is not critical, check that the target function does not grow.
// Growth should not normally happen, but it can occur when the convergence
// thresholds are set too small.
bool isLsInDoubt = isLsValid && fx <= fxp + param.ftol && isLsNonCritical;
bool isLsSuccess = lsret == LS_SUCCESS || isLsInDoubt;

CUML_LOG_TRACE("%s iteration %d, fx=%f", solver, iter, fx);

// if the target is at least finite, we can check for convergence
if (isLsValid)
converged = check_convergence(param, iter, fx, x, grad, fx_hist, dev_scalar, stream);

if (!isLsSuccess && !converged) {
CUML_LOG_WARN(
"%s line search failed (code %d); stopping at the last valid step", solver, lsret);
outcode = OPT_LS_FAILED;
stop = true;
} else if (!isLsValid) {
CUML_LOG_ERROR(
"%s error fx=%f at iteration %d; stopping at the last valid step", solver, fx, iter);
outcode = OPT_NUMERIC_ERROR;
stop = true;
} else if (converged) {
CUML_LOG_DEBUG("%s converged", solver);
outcode = OPT_SUCCESS;
stop = true;
} else if (isLsInDoubt && fx + param.ftol >= fxp) {
// If a non-critical error happened during the line search, check whether the target
// improved at least a little; otherwise, stop to avoid spinning until the iteration limit.
CUML_LOG_WARN(
"%s stopped, because the line search failed to advance (step delta = %f)", solver, fx - fxp);
outcode = OPT_LS_FAILED;
stop = true;
}

// if the line search wasn't successful, undo the update.
if (!isLsSuccess || !isLsValid) {
fx = fxp;
x.copy_async(xp, stream);
grad.copy_async(gradp, stream);
}

return stop;
}
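
To make the tolerance rules above easier to follow, here is a minimal host-side sketch of the same decision logic, stripped of the CUDA/SimpleVec machinery. The enum names mirror the ones used in `update_and_check`, but the helper `classify_linesearch` itself is hypothetical and purely illustrative, not part of cuML.

#include <cmath>

// Hypothetical, simplified mirror of the decision logic in update_and_check:
// when is a line-search failure tolerated, and when does the solver stop?
enum LsCode { LS_SUCCESS, LS_INVALID_STEP_MIN, LS_MAX_ITERS_REACHED, LS_INVALID_STEP };
enum Outcome { ACCEPT_STEP, ACCEPT_WITH_WARNING, REJECT_AND_STOP };

template <typename T>
Outcome classify_linesearch(LsCode lsret, T fx, T fxp, T ftol)
{
  bool valid       = !std::isnan(fx) && !std::isinf(fx);        // fx is finite
  bool noncritical = lsret == LS_INVALID_STEP_MIN || lsret == LS_MAX_ITERS_REACHED;
  bool in_doubt    = valid && fx <= fxp + ftol && noncritical;  // fx did not grow (within ftol)

  if (lsret == LS_SUCCESS && valid) return ACCEPT_STEP;
  if (in_doubt) return ACCEPT_WITH_WARNING;  // tolerate: keep the step, let the convergence check decide
  return REJECT_AND_STOP;                    // critical failure or non-finite fx
}

In the real helper, the rejecting branch also restores `fx`, `x`, and `grad` from `fxp`, `xp`, and `gradp`, and an "in doubt" step that yields no measurable improvement (within `ftol`) still stops the solver rather than iterating to the limit.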

template <typename T, typename Function>
inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
Function& f, // function to minimize
@@ -131,35 +199,32 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
*k = 1;
int end = 0;
int n_vec = 0; // number of vector updates made in lbfgs_search_dir
OPT_RETCODE retcode;
LINE_SEARCH_RETCODE lsret;
for (; *k <= param.max_iterations; (*k)++) {
// Save the current x and gradient
xp.copy_async(x, stream);
gradp.copy_async(grad, stream);
fxp = fx;

// Line search to update x, fx and gradient
LINE_SEARCH_RETCODE lsret =
ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);

bool isLsSuccess = lsret == LS_SUCCESS;
CUML_LOG_TRACE("Iteration %d, fx=%f", *k, fx);

if (!isLsSuccess || isnan(fx) || isinf(fx)) {
fx = fxp;
x.copy_async(xp, stream);
grad.copy_async(gradp, stream);
if (!isLsSuccess) {
CUML_LOG_ERROR("L-BFGS line search failed");
return OPT_LS_FAILED;
}
CUML_LOG_ERROR("L-BFGS error fx=%f at iteration %d", fx, *k);
return OPT_NUMERIC_ERROR;
}

if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
CUML_LOG_DEBUG("L-BFGS converged");
return OPT_SUCCESS;
}
lsret = ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);

if (update_and_check("L-BFGS",
param,
*k,
lsret,
fx,
fxp,
x,
xp,
grad,
gradp,
fx_hist,
dev_scalar,
retcode,
stream))
return retcode;

// Update s and y
// s_{k+1} = x_{k+1} - x_k
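
For context on the `Update s and y` step (the remainder of this hunk is collapsed in the diff view): L-BFGS maintains curvature pairs built from the accepted step and the corresponding change in gradient. A host-side sketch with `std::vector`, purely illustrative and not the device `SimpleVec` code:

#include <cstddef>
#include <vector>

// Standard L-BFGS curvature-pair update:
//   s_{k+1} = x_{k+1} - x_k
//   y_{k+1} = grad_{k+1} - grad_k
template <typename T>
void curvature_pair(const std::vector<T>& x, const std::vector<T>& xp,
                    const std::vector<T>& grad, const std::vector<T>& gradp,
                    std::vector<T>& s, std::vector<T>& y)
{
  for (std::size_t i = 0; i < x.size(); ++i) {
    s[i] = x[i] - xp[i];        // step actually taken this iteration
    y[i] = grad[i] - gradp[i];  // change in gradient over that step
  }
}

This is one reason the pre-step values `xp` and `gradp` are saved at the top of the loop: they are needed both to roll back a rejected step and to form the curvature pair after an accepted one.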
@@ -282,37 +347,38 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T>& param,

int end = 0;
int n_vec = 0; // number of vector updates made in lbfgs_search_dir
OPT_RETCODE retcode;
LINE_SEARCH_RETCODE lsret;
for ((*k) = 1; (*k) <= param.max_iterations; (*k)++) {
// Save the current x and gradient
xp.copy_async(x, stream);
gradp.copy_async(grad, stream);
fxp = fx;

// Projected line search to update x, fx and gradient
LINE_SEARCH_RETCODE lsret = ls_backtrack_projected(
lsret = ls_backtrack_projected(
param, f_wrap, fx, x, grad, pseudo, step, drt, xp, l1_penalty, dev_scalar, stream);

bool isLsSuccess = lsret == LS_SUCCESS;
if (!isLsSuccess || isnan(fx) || isinf(fx)) {
fx = fxp;
x.copy_async(xp, stream);
grad.copy_async(gradp, stream);
if (!isLsSuccess) {
CUML_LOG_ERROR("QWL-QN line search failed");
return OPT_LS_FAILED;
}
CUML_LOG_ERROR("OWL-QN error fx=%f at iteration %d", fx, *k);
return OPT_NUMERIC_ERROR;
}
if (update_and_check("QWL-QN",
param,
*k,
lsret,
fx,
fxp,
x,
xp,
grad,
gradp,
fx_hist,
dev_scalar,
retcode,
stream))
return retcode;

// recompute pseudo
// pseudo.assign_binary(x, grad, pseudo_grad);
update_pseudo(x, grad, pseudo_grad, pg_limit, pseudo, stream);

if (check_convergence(param, *k, fx, x, pseudo, fx_hist, dev_scalar, stream)) {
CUML_LOG_DEBUG("OWL-QN converged");
return OPT_SUCCESS;
}

// Update s and y - We should only do this if there is no skipping condition

col_ref(S, svec, end);
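
As background for the `update_pseudo` call above: OWL-QN replaces the gradient with a pseudo-gradient that folds in the non-differentiable L1 penalty. The element-wise rule below is the standard OWL-QN formula in scalar form; it is only a sketch of what the device kernel presumably computes for the penalized coordinates (those covered by `pg_limit`), not the cuML implementation itself.

// Standard OWL-QN pseudo-gradient for a single penalized coordinate
// (illustrative scalar form; hypothetical helper, not cuML's update_pseudo).
template <typename T>
T pseudo_gradient(T x, T g, T l1_penalty)
{
  if (x > T(0)) return g + l1_penalty;
  if (x < T(0)) return g - l1_penalty;
  // x == 0: take the one-sided derivative that points downhill, if any
  if (g + l1_penalty < T(0)) return g + l1_penalty;
  if (g - l1_penalty > T(0)) return g - l1_penalty;
  return T(0);
}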