Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LBFGS: avoid generation of NaNs, and add checks for finite values #368

Merged
merged 4 commits
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
### ensmallen ?.??.?: "???"
###### ????-??-??
* Fix CNE test tolerances
* LBFGS: avoid generation of NaNs, and add checks for finite values
([#368](https://github.com/mlpack/ensmallen/pull/368)).

* Fix CNE test tolerances
([#360](https://github.com/mlpack/ensmallen/pull/360)).

### ensmallen 2.19.1: "Eight Ball Deluxe"
Expand Down
43 changes: 35 additions & 8 deletions include/ensmallen_bits/lbfgs/lbfgs_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,17 @@ double L_BFGS::ChooseScalingFactor(const size_t iterationNum,
// Get s and y matrices once instead of multiple times.
const arma::Mat<CubeElemType>& sMat = s.slice(previousPos);
const arma::Mat<CubeElemType>& yMat = y.slice(previousPos);
scalingFactor = dot(sMat, yMat) / dot(yMat, yMat);

const CubeElemType tmp = arma::dot(yMat, yMat);
const CubeElemType denom = (tmp != CubeElemType(0)) ? tmp : CubeElemType(1);

scalingFactor = arma::dot(sMat, yMat) / denom;
}
else
{
scalingFactor = 1.0 / sqrt(dot(gradient, gradient));
const CubeElemType tmp = arma::norm(gradient, "fro");

scalingFactor = (tmp != CubeElemType(0)) ? (1.0 / tmp) : 1.0;
}

return scalingFactor;
Expand Down Expand Up @@ -129,11 +135,17 @@ void L_BFGS::SearchDirection(const MatType& gradient,
for (size_t i = iterationNum; i != limit; i--)
{
int translatedPosition = (i + (numBasis - 1)) % numBasis;
rho[iterationNum - i] = 1.0 / arma::dot(y.slice(translatedPosition),
s.slice(translatedPosition));
alpha[iterationNum - i] = rho[iterationNum - i] *
arma::dot(s.slice(translatedPosition), searchDirection);
searchDirection -= alpha[iterationNum - i] * y.slice(translatedPosition);

const arma::Mat<CubeElemType>& sMat = s.slice(translatedPosition);
const arma::Mat<CubeElemType>& yMat = y.slice(translatedPosition);

const CubeElemType tmp = arma::dot(yMat, sMat);

rho[iterationNum - i] = (tmp != CubeElemType(0)) ? (1.0 / tmp) : CubeElemType(1);

alpha[iterationNum - i] = rho[iterationNum - i] * arma::dot(sMat, searchDirection);

searchDirection -= alpha[iterationNum - i] * yMat;
}

searchDirection *= scalingFactor;
Expand Down Expand Up @@ -218,7 +230,8 @@ bool L_BFGS::LineSearch(FunctionType& function,
arma::dot(gradient, searchDirection);

// If it is not a descent direction, or the directional derivative is not
// finite (NaN or infinite), just report failure.
if (initialSearchDirectionDotGradient > 0.0)
if ( (initialSearchDirectionDotGradient > 0.0)
|| (std::isfinite(initialSearchDirectionDotGradient) == false) )
{
Warn << "L-BFGS line search direction is not a descent direction "
<< "(terminating)!" << std::endl;
Expand Down Expand Up @@ -250,6 +263,12 @@ bool L_BFGS::LineSearch(FunctionType& function,
newIterateTmp += stepSize * searchDirection;
functionValue = function.EvaluateWithGradient(newIterateTmp, gradient);

if (std::isnan(functionValue))
{
Warn << "L-BFGS objective value is NaN (terminating)!" << std::endl;
return false;
}

terminate |= Callback::EvaluateWithGradient(*this, function, newIterateTmp,
functionValue, gradient, callbacks...);

Expand Down Expand Up @@ -391,6 +410,7 @@ L_BFGS::Optimize(FunctionType& function,
//
// But don't do this on the first iteration to ensure we always take at
// least one descent step.
// TODO: to speed this up, investigate use of arma::norm2est() in Armadillo 12.4
if (arma::norm(gradient, 2) < minGradientNorm)
{
Info << "L-BFGS gradient norm too small (terminating successfully)."
Expand All @@ -416,6 +436,13 @@ L_BFGS::Optimize(FunctionType& function,
break;
}

if (std::isfinite(scalingFactor) == false)
{
Warn << "L-BFGS scaling factor is not finite. Stopping optimization."
<< std::endl;
break;
}

// Build an approximation to the Hessian and choose the search
// direction for the current iteration.
SearchDirection(gradient, itNum, scalingFactor, s, y, searchDirection);
Expand Down