Commit 4901ed60 authored by Davis King

Added a max iterations parameter to the line_search() function.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403771
parent 4f791bb4
...@@ -170,7 +170,8 @@ namespace dlib ...@@ -170,7 +170,8 @@ namespace dlib
f_value, f_value,
make_line_search_function(der,x,s, g), make_line_search_function(der,x,s, g),
dot(g,s), // compute initial gradient for the line search dot(g,s), // compute initial gradient for the line search
search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), min_f); search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), min_f,
search_strategy.get_max_line_search_iterations());
// Take the search step indicated by the above line search // Take the search step indicated by the above line search
x += alpha*s; x += alpha*s;
...@@ -230,7 +231,9 @@ namespace dlib ...@@ -230,7 +231,9 @@ namespace dlib
f_value, f_value,
negate_function(make_line_search_function(der,x,s, g)), negate_function(make_line_search_function(der,x,s, g)),
dot(g,s), // compute initial gradient for the line search dot(g,s), // compute initial gradient for the line search
search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), -max_f); search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), -max_f,
search_strategy.get_max_line_search_iterations()
);
// Take the search step indicated by the above line search // Take the search step indicated by the above line search
x += alpha*s; x += alpha*s;
...@@ -292,7 +295,9 @@ namespace dlib ...@@ -292,7 +295,9 @@ namespace dlib
derivative(make_line_search_function(f,x,s),derivative_eps), derivative(make_line_search_function(f,x,s),derivative_eps),
dot(g,s), // Sometimes the following line is a better way of determining the initial gradient. dot(g,s), // Sometimes the following line is a better way of determining the initial gradient.
//derivative(make_line_search_function(f,x,s),derivative_eps)(0), //derivative(make_line_search_function(f,x,s),derivative_eps)(0),
search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), min_f); search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), min_f,
search_strategy.get_max_line_search_iterations()
);
// Take the search step indicated by the above line search // Take the search step indicated by the above line search
x += alpha*s; x += alpha*s;
......
...@@ -247,7 +247,8 @@ namespace dlib ...@@ -247,7 +247,8 @@ namespace dlib
const double d0, const double d0,
double rho, double rho,
double sigma, double sigma,
double min_f double min_f,
unsigned long max_iter
) )
{ {
// You get an error on this line when you pass in a global function to this function. // You get an error on this line when you pass in a global function to this function.
...@@ -259,11 +260,12 @@ namespace dlib ...@@ -259,11 +260,12 @@ namespace dlib
COMPILE_TIME_ASSERT(is_function<funct_der>::value == false); COMPILE_TIME_ASSERT(is_function<funct_der>::value == false);
DLIB_ASSERT ( DLIB_ASSERT (
0 < rho && rho < sigma && sigma < 1, 0 < rho && rho < sigma && sigma < 1 && max_iter > 0,
"\tdouble line_search()" "\tdouble line_search()"
<< "\n\tYou have given invalid arguments to this function" << "\n\tYou have given invalid arguments to this function"
<< "\n\tsigma: " << sigma << "\n\t sigma: " << sigma
<< "\n\trho: " << rho << "\n\t rho: " << rho
<< "\n\t max_iter: " << max_iter
); );
// The bracketing phase of this function is implemented according to block 2.6.2 from // The bracketing phase of this function is implemented according to block 2.6.2 from
...@@ -311,9 +313,11 @@ namespace dlib ...@@ -311,9 +313,11 @@ namespace dlib
// This thresh value represents the Wolfe curvature condition // This thresh value represents the Wolfe curvature condition
const double thresh = std::abs(sigma*d0); const double thresh = std::abs(sigma*d0);
unsigned long itr = 0;
// do the bracketing stage to find the bracket range [a,b] // do the bracketing stage to find the bracket range [a,b]
while (true) while (true)
{ {
++itr;
const double val = f(alpha); const double val = f(alpha);
const double val_der = der(alpha); const double val_der = der(alpha);
...@@ -338,7 +342,7 @@ namespace dlib ...@@ -338,7 +342,7 @@ namespace dlib
return alpha; return alpha;
// if we are stuck not making progress then quit with the current alpha // if we are stuck not making progress then quit with the current alpha
if (last_alpha == alpha) if (last_alpha == alpha || itr >= max_iter)
return alpha; return alpha;
if (val_der >= 0) if (val_der >= 0)
...@@ -391,6 +395,7 @@ namespace dlib ...@@ -391,6 +395,7 @@ namespace dlib
// Now do the sectioning phase from 2.6.4 // Now do the sectioning phase from 2.6.4
while (true) while (true)
{ {
++itr;
double first = a + tau2*(b-a); double first = a + tau2*(b-a);
double last = b - tau3*(b-a); double last = b - tau3*(b-a);
...@@ -402,8 +407,8 @@ namespace dlib ...@@ -402,8 +407,8 @@ namespace dlib
const double val_der = der(alpha); const double val_der = der(alpha);
// we are done with the line search since we found a value smaller // we are done with the line search since we found a value smaller
// than the minimum f value // than the minimum f value or we ran out of iterations.
if (val <= min_f) if (val <= min_f || itr >= max_iter)
return alpha; return alpha;
// stop if the interval gets so small that it isn't shrinking any more due to rounding error // stop if the interval gets so small that it isn't shrinking any more due to rounding error
......
...@@ -137,7 +137,8 @@ namespace dlib ...@@ -137,7 +137,8 @@ namespace dlib
const double d0, const double d0,
double rho, double rho,
double sigma, double sigma,
double min_f double min_f,
unsigned long max_iter
) )
/*! /*!
requires requires
...@@ -147,11 +148,14 @@ namespace dlib ...@@ -147,11 +148,14 @@ namespace dlib
- der is the derivative of f - der is the derivative of f
- f0 == f(0) - f0 == f(0)
- d0 == der(0) - d0 == der(0)
- max_iter > 0
ensures ensures
- Performs a line search and uses the strong Wolfe conditions to decide when - Performs a line search and uses the strong Wolfe conditions to decide when
the search can stop. the search can stop.
- rho == the parameter of the Wolfe sufficient decrease condition - rho == the parameter of the Wolfe sufficient decrease condition
- sigma == the parameter of the Wolfe curvature condition - sigma == the parameter of the Wolfe curvature condition
- max_iter == the maximum number of iterations allowable. After this
many evaluations of f() line_search() is guaranteed to terminate.
- returns a value alpha such that f(alpha) is significantly closer to - returns a value alpha such that f(alpha) is significantly closer to
the minimum of f than f(0). the minimum of f than f(0).
- It is assumed that the minimum possible value of f(x) is min_f. So if - It is assumed that the minimum possible value of f(x) is min_f. So if
...@@ -210,7 +214,7 @@ namespace dlib ...@@ -210,7 +214,7 @@ namespace dlib
throws throws
- optimize_single_variable_failure - optimize_single_variable_failure
This exception is thrown if max_iter iterations are performed without This exception is thrown if max_iter iterations are performed without
determining the min point to the requsted accuracy of eps. determining the min point to the requested accuracy of eps.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -245,7 +249,7 @@ namespace dlib ...@@ -245,7 +249,7 @@ namespace dlib
throws throws
- optimize_single_variable_failure - optimize_single_variable_failure
This exception is thrown if max_iter iterations are performed without This exception is thrown if max_iter iterations are performed without
determining the max point to the requsted accuracy of eps. determining the max point to the requested accuracy of eps.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -26,6 +26,9 @@ namespace dlib ...@@ -26,6 +26,9 @@ namespace dlib
double get_wolfe_sigma ( double get_wolfe_sigma (
) const { return 0.01; } ) const { return 0.01; }
unsigned long get_max_line_search_iterations (
) const { return 100; }
template <typename T> template <typename T>
const matrix<double,0,1>& get_next_direction ( const matrix<double,0,1>& get_next_direction (
const T& , const T& ,
...@@ -78,6 +81,9 @@ namespace dlib ...@@ -78,6 +81,9 @@ namespace dlib
double get_wolfe_sigma ( double get_wolfe_sigma (
) const { return 0.9; } ) const { return 0.9; }
unsigned long get_max_line_search_iterations (
) const { return 100; }
template <typename T> template <typename T>
const matrix<double,0,1>& get_next_direction ( const matrix<double,0,1>& get_next_direction (
const T& x, const T& x,
...@@ -174,6 +180,9 @@ namespace dlib ...@@ -174,6 +180,9 @@ namespace dlib
double get_wolfe_sigma ( double get_wolfe_sigma (
) const { return 0.9; } ) const { return 0.9; }
unsigned long get_max_line_search_iterations (
) const { return 100; }
template <typename T> template <typename T>
const matrix<double,0,1>& get_next_direction ( const matrix<double,0,1>& get_next_direction (
const T& x, const T& x,
......
...@@ -59,6 +59,14 @@ namespace dlib ...@@ -59,6 +59,14 @@ namespace dlib
this search strategy is used with the line_search() function. this search strategy is used with the line_search() function.
!*/ !*/
unsigned long get_max_line_search_iterations (
) const;
/*!
ensures
- returns the value of the max iterations parameter that should be used when
this search strategy is used with the line_search() function.
!*/
template <typename T> template <typename T>
const matrix<double,0,1>& get_next_direction ( const matrix<double,0,1>& get_next_direction (
const T& x, const T& x,
...@@ -121,6 +129,14 @@ namespace dlib ...@@ -121,6 +129,14 @@ namespace dlib
this search strategy is used with the line_search() function. this search strategy is used with the line_search() function.
!*/ !*/
unsigned long get_max_line_search_iterations (
) const;
/*!
ensures
- returns the value of the max iterations parameter that should be used when
this search strategy is used with the line_search() function.
!*/
template <typename T> template <typename T>
const matrix<double,0,1>& get_next_direction ( const matrix<double,0,1>& get_next_direction (
const T& x, const T& x,
...@@ -187,6 +203,14 @@ namespace dlib ...@@ -187,6 +203,14 @@ namespace dlib
this search strategy is used with the line_search() function. this search strategy is used with the line_search() function.
!*/ !*/
unsigned long get_max_line_search_iterations (
) const;
/*!
ensures
- returns the value of the max iterations parameter that should be used when
this search strategy is used with the line_search() function.
!*/
template <typename T> template <typename T>
const matrix<double,0,1>& get_next_direction ( const matrix<double,0,1>& get_next_direction (
const T& x, const T& x,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment