Commit 0e63193e authored by Davis King

Added the objective_delta_stop_strategy and generally cleaned up the code more.

Also optimized the derivative() function a bit.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403185
parent 7d63029a
......@@ -13,12 +13,64 @@
namespace dlib
{
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
/*!
    Stop strategy that ends a search once the objective function value stops
    changing by a meaningful amount between consecutive calls, or once an
    optional cap on the number of calls has been exceeded.
!*/
class objective_delta_stop_strategy
{
public:

    /*!
        min_delta: the search stops when the absolute change in the objective
                   value between two consecutive calls is less than this.
        No iteration cap is applied when this constructor is used.
    !*/
    objective_delta_stop_strategy (
        double min_delta = 1e-7
    ) :
        _been_used(false),
        _min_delta(min_delta),
        _max_iter(0),
        _cur_iter(0),
        _prev_funct_value(0)
    {}

    /*!
        min_delta: same meaning as in the single argument constructor.
        max_iter:  maximum number of calls that are allowed to return true.
                   A value of 0 disables the cap.
    !*/
    objective_delta_stop_strategy (
        double min_delta,
        unsigned long max_iter
    ) :
        _been_used(false),
        _min_delta(min_delta),
        _max_iter(max_iter),
        _cur_iter(0),
        _prev_funct_value(0)
    {}

    /*!
        Returns true if the search should keep going and false if it should
        stop.  Only funct_value is inspected; the first and third arguments
        are accepted but ignored.

        The previous objective value is only updated when true is returned.
    !*/
    template <typename T>
    bool should_continue_search (
        const T& ,
        const double funct_value,
        const T&
    )
    {
        ++_cur_iter;

        if (_been_used)
        {
            // The cap only applies when it is enabled (i.e. _max_iter is not 0).
            const bool iteration_cap_hit = (_max_iter != 0) && (_cur_iter > _max_iter);

            // Stop on either the cap or a too-small change in the objective.
            if (iteration_cap_hit || std::abs(funct_value - _prev_funct_value) < _min_delta)
                return false;
        }
        else
        {
            // First call: there is no previous value to compare against yet,
            // so neither stopping test is evaluated.
            _been_used = true;
        }

        _prev_funct_value = funct_value;
        return true;
    }

private:
    bool _been_used;            // true once should_continue_search() has returned true
    double _min_delta;          // threshold on |current - previous| objective value
    unsigned long _max_iter;    // iteration cap, 0 means "no cap"
    unsigned long _cur_iter;    // number of calls made to should_continue_search()
    double _prev_funct_value;   // objective value seen on the last call that returned true
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class cg_strategy
class cg_search_strategy
{
public:
cg_strategy() : been_used(false) {}
cg_search_strategy() : been_used(false) {}
double get_wolfe_rho (
) const { return 0.001; }
......@@ -77,10 +129,10 @@ namespace dlib
// ----------------------------------------------------------------------------------------
class bfgs_strategy
class bfgs_search_strategy
{
public:
bfgs_strategy() : been_used(false), been_used_twice(false) {}
bfgs_search_strategy() : been_used(false), been_used_twice(false) {}
double get_wolfe_rho (
) const { return 0.01; }
......@@ -155,10 +207,10 @@ namespace dlib
// ----------------------------------------------------------------------------------------
class lbfgs_strategy
class lbfgs_search_strategy
{
public:
lbfgs_strategy(unsigned long max_size_) : max_size(max_size_), been_used(false) {}
lbfgs_search_strategy(unsigned long max_size_) : max_size(max_size_), been_used(false) {}
double get_wolfe_rho (
) const { return 0.01; }
......@@ -291,13 +343,20 @@ namespace dlib
COMPILE_TIME_ASSERT(is_matrix<T>::value);
typename T::matrix_type der(x.size());
typename T::matrix_type e(x.size());
set_all_elements(e,0);
typename T::matrix_type e(x);
for (long i = 0; i < x.size(); ++i)
{
e(i) = 1;
der(i) = (f(x+e*eps)-f(x-e*eps))/(2*eps);
e(i) = 0;
const double old_val = e(i);
e(i) += eps;
const double delta_plus = f(e);
e(i) = old_val - eps;
const double delta_minus = f(e);
der(i) = (delta_plus - delta_minus)/(2*eps);
// and finally restore the old value of this element
e(i) = old_val;
}
return der;
......@@ -724,18 +783,19 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename strategy_type,
typename search_strategy_type,
typename stop_strategy_type,
typename funct,
typename funct_der,
typename T
>
void find_min (
strategy_type& strategy,
search_strategy_type search_strategy,
stop_strategy_type stop_strategy,
const funct& f,
const funct_der& der,
T& x,
double min_f,
double min_delta = 1e-7
double min_f
)
{
// You get an error on this line when you pass in a global function to this function.
......@@ -748,55 +808,47 @@ namespace dlib
COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT (
min_delta >= 0 && x.nc() == 1,
is_col_vector(x),
"\tdouble find_min()"
<< "\n\tYou have to supply column vectors to this function"
<< "\n\tmin_delta: " << min_delta
<< "\n\tx.nc(): " << x.nc()
);
T g, s;
double alpha = 1;
double f_value = f(x);
g = der(x);
double old_f_value = f_value + 10;
// loop until there isn't any large change in the function value
while(std::abs(old_f_value - f_value) > min_delta )
while(stop_strategy.should_continue_search(x, f_value, g))
{
s = strategy.get_next_direction(x, f_value, g);
old_f_value = f_value;
s = search_strategy.get_next_direction(x, f_value, g);
alpha = line_search(
double alpha = line_search(
make_line_search_function(f,x,s, f_value),
f_value,
make_line_search_function(der,x,s, g),
dot(g,s),
strategy.get_wolfe_rho(), strategy.get_wolfe_sigma(), min_f);
dot(g,s), // compute initial gradient for the line search
search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), min_f);
x += alpha*s;
}
}
// ----------------------------------------------------------------------------------------
template <
typename strategy_type,
typename search_strategy_type,
typename stop_strategy_type,
typename funct,
typename T
>
void find_min_using_approximate_derivatives (
strategy_type& strategy,
search_strategy_type search_strategy,
stop_strategy_type stop_strategy,
const funct& f,
T& x,
double min_f,
double min_delta = 1e-7,
double derivative_eps = 1e-7
)
{
......@@ -809,39 +861,33 @@ namespace dlib
COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT (
min_delta >= 0 && x.nc() == 1 && derivative_eps > 0,
is_col_vector(x) && derivative_eps > 0,
"\tdouble find_min_using_approximate_derivatives()"
<< "\n\tYou have to supply column vectors to this function"
<< "\n\tmin_delta: " << min_delta
<< "\n\tx.nc(): " << x.nc()
<< "\n\tderivative_eps: " << derivative_eps
);
T g, s;
double alpha = 1;
double f_value = f(x);
double old_f_value = f_value + 10;
g = derivative(f,derivative_eps)(x);
// loop until there isn't any large change in the function value
while(std::abs(old_f_value - f_value) > min_delta )
while(stop_strategy.should_continue_search(x, f_value, g))
{
g = derivative(f,derivative_eps)(x);
s = strategy.get_next_direction(x, f_value, g);
old_f_value = f_value;
s = search_strategy.get_next_direction(x, f_value, g);
alpha = line_search(
double alpha = line_search(
make_line_search_function(f,x,s,f_value),
f_value,
derivative(make_line_search_function(f,x,s),derivative_eps),
dot(g,s), // TODO. Maybe this line isn't better than the following one??
dot(g,s), // Sometimes the following line is a better way of determining the initial gradient.
//derivative(make_line_search_function(f,x,s),derivative_eps)(0),
strategy.get_wolfe_rho(), strategy.get_wolfe_sigma(), min_f);
search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), min_f);
x += alpha*s;
g = derivative(f,derivative_eps)(x);
}
}
......@@ -880,8 +926,11 @@ namespace dlib
<< "\n\tx.nc(): " << x.nc()
);
bfgs_strategy strategy;
find_min(strategy, f, der, x, min_f, min_delta);
find_min(
bfgs_search_strategy(),
objective_delta_stop_strategy(min_delta),
f, der, x, min_f
);
}
// ----------------------------------------------------------------------------------------
......@@ -916,8 +965,11 @@ namespace dlib
<< "\n\tx.nc(): " << x.nc()
);
cg_strategy strategy;
find_min(strategy, f, der, x, min_f, min_delta);
find_min(
cg_search_strategy(),
objective_delta_stop_strategy(min_delta),
f, der, x, min_f
);
}
// ----------------------------------------------------------------------------------------
......@@ -951,8 +1003,11 @@ namespace dlib
<< "\n\tderivative_eps: " << derivative_eps
);
bfgs_strategy strategy;
find_min_using_approximate_derivatives(strategy, f, x, min_f, min_delta, derivative_eps);
find_min_using_approximate_derivatives(
bfgs_search_strategy(),
objective_delta_stop_strategy(min_delta),
f, x, min_f, derivative_eps
);
}
// ----------------------------------------------------------------------------------------
......@@ -986,8 +1041,11 @@ namespace dlib
<< "\n\tderivative_eps: " << derivative_eps
);
cg_strategy strategy;
find_min_using_approximate_derivatives(strategy, f, x, min_f, min_delta, derivative_eps);
find_min_using_approximate_derivatives(
cg_search_strategy(),
objective_delta_stop_strategy(min_delta),
f, x, min_f, derivative_eps
);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment