Commit 182c811f authored by Davis King's avatar Davis King

Added the gradient_norm_stop_strategy

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403199
parent a6d56265
......@@ -77,6 +77,65 @@ namespace dlib
double _prev_funct_value;
};
// ----------------------------------------------------------------------------------------
class gradient_norm_stop_strategy
{
public:
gradient_norm_stop_strategy (
double min_norm = 1e-7
) : _min_norm(min_norm), _max_iter(0), _cur_iter(0)
{
DLIB_ASSERT (
min_norm >= 0,
"\t gradient_norm_stop_strategy(min_norm)"
<< "\n\t min_norm can't be negative"
<< "\n\t min_norm: " << min_norm
);
}
gradient_norm_stop_strategy (
double min_norm,
unsigned long max_iter
) : _min_norm(min_norm), _max_iter(max_iter), _cur_iter(0)
{
DLIB_ASSERT (
min_norm >= 0 && max_iter > 0,
"\t gradient_norm_stop_strategy(min_norm, max_iter)"
<< "\n\t min_norm can't be negative and max_iter can't be 0"
<< "\n\t min_norm: " << min_norm
<< "\n\t max_iter: " << max_iter
);
}
template <typename T>
bool should_continue_search (
const T& ,
const double ,
const T& funct_derivative
)
{
++_cur_iter;
// Check if we have hit the max allowable number of iterations. (but only
// check if _max_iter is enabled (i.e. not 0)).
if (_max_iter != 0 && _cur_iter > _max_iter)
return false;
// check if the gradient norm is too small
if (length(funct_derivative) < _min_norm)
return false;
return true;
}
private:
double _min_norm;
unsigned long _max_iter;
unsigned long _cur_iter;
};
// ----------------------------------------------------------------------------------------
}
......
......@@ -73,6 +73,64 @@ namespace dlib
};
// ----------------------------------------------------------------------------------------
class gradient_norm_stop_strategy
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents a strategy for deciding if an optimization
algorithm should terminate. This particular object looks at the
norm (i.e. the length) of the current gradient vector and stops
if it is smaller than a user given threshold.
!*/
public:
gradient_norm_stop_strategy (
double min_norm = 1e-7
);
/*!
requires
- min_norm >= 0
ensures
- This stop strategy object will only consider a search to be complete
if the current gradient norm is less than min_norm
!*/
gradient_norm_stop_strategy (
double min_norm,
unsigned long max_iter
);
/*!
requires
- min_norm >= 0
- max_iter > 0
ensures
- This stop strategy object will only consider a search to be complete
if the current gradient norm is less than min_norm or more than
max_iter iterations has been executed.
!*/
template <typename T>
bool should_continue_search (
const T& x,
const double funct_value,
const T& funct_derivative
);
/*!
requires
- this function is only called once per search iteration
- for some objective function f():
- x == the search point for the current iteration
- funct_value == f(x)
- funct_derivative == derivative(f)(x)
ensures
- returns true if the point x doest not satisfy the stopping condition and
false otherwise.
!*/
};
// ----------------------------------------------------------------------------------------
}
......
......@@ -158,6 +158,14 @@ namespace
wrap_function(apq<T>), wrap_function(der_apq<T>), x, minf);
DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
dlog << LINFO << "find_min() bgfs: got apq in " << total_count;
total_count = 0;
x = p;
find_min(bfgs_search_strategy(),
gradient_norm_stop_strategy(),
wrap_function(apq<T>), wrap_function(der_apq<T>), x, minf);
DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
dlog << LINFO << "find_min() bgfs(gn): got apq in " << total_count;
}
......@@ -288,6 +296,15 @@ namespace
dlog << LINFO << "find_min() lbfgs-4: got powell/noder2 in " << total_count;
total_count = 0;
x = p;
find_min_using_approximate_derivatives(lbfgs_search_strategy(4),
gradient_norm_stop_strategy(),
&powell, x, minf, 1e-10);
DLIB_TEST_MSG(dlib::equal(x,opt, 1e-1),opt-x);
dlog << LINFO << "find_min() lbfgs-4(gn): got powell/noder2 in " << total_count;
total_count = 0;
x = p;
find_min_using_approximate_derivatives(cg_search_strategy(),
......@@ -320,6 +337,15 @@ namespace
dlog << LINFO << "find_min() bfgs: got simple in " << total_count;
total_count = 0;
x = p;
find_min(bfgs_search_strategy(),
gradient_norm_stop_strategy(),
&simple, &der_simple, x, minf);
DLIB_TEST_MSG(dlib::equal(x,opt, 1e-5),opt-x);
dlog << LINFO << "find_min() bfgs(gn): got simple in " << total_count;
total_count = 0;
x = p;
find_min(lbfgs_search_strategy(3),
......@@ -417,6 +443,15 @@ namespace
dlog << LINFO << "find_min() bfgs: got rosen in " << total_count;
total_count = 0;
x = p;
find_min(bfgs_search_strategy(),
gradient_norm_stop_strategy(),
&rosen, &der_rosen, x, minf);
DLIB_TEST_MSG(dlib::equal(x,opt, 1e-7),opt-x);
dlog << LINFO << "find_min() bfgs(gn): got rosen in " << total_count;
total_count = 0;
x = p;
find_min(lbfgs_search_strategy(20),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment