Commit 7831ae02 authored by Davis King

This is the first checkin of a bunch of code refactoring I have been doing to this component.

Most of these changes are just rearrangements of the old code.  However, I also changed
the starting value of a line search and also removed some unneeded calls to the objective
function and its derivative in the course of doing the refactoring.  This has all resulted
in a significant reduction in calls to the objective function.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403182
parent 46401b06
...@@ -12,6 +12,129 @@ ...@@ -12,6 +12,129 @@
namespace dlib namespace dlib
{ {
// ----------------------------------------------------------------------------------------
class cg_strategy
{
    /*!
        Search-direction strategy based on the Polak-Ribiere conjugate gradient
        method.  The object keeps the gradient and direction from the previous
        call so that each new direction can be built from them.
    !*/
public:
    cg_strategy() : been_used(false) {}

    // rho value handed to the line search by callers of this strategy
    double get_wolfe_rho (
    ) const { return 0.001; }

    // sigma value handed to the line search by callers of this strategy
    double get_wolfe_sigma (
    ) const { return 0.01; }

    template <typename T>
    const matrix<double,0,1>& get_next_direction (
        const T& ,
        const double ,
        const T& funct_derivative
    )
    /*!
        requires
            - funct_derivative is the gradient of the objective function at the
              current point (the first two arguments, the point itself and the
              function value there, are not used by this strategy)
        ensures
            - returns the next direction to search along.  The very first call
              yields the steepest descent direction; later calls yield the
              Polak-Ribiere conjugate gradient direction.
            - the returned reference refers to a member of this object and is
              overwritten by the next call.
    !*/
    {
        if (!been_used)
        {
            // Nothing to be conjugate to yet, so begin with steepest descent.
            prev_direction = -funct_derivative;
            been_used = true;
        }
        else
        {
            // Polak-Ribiere (4.1.12) conjugate gradient, Fletcher page 83.
            const double prev_grad_sqr = trans(prev_derivative)*prev_derivative;
            const bool prev_grad_vanished =
                std::abs(prev_grad_sqr) < std::numeric_limits<double>::epsilon();

            if (prev_grad_vanished)
            {
                // The denominator of the update is effectively zero, so fall
                // back to the direction of steepest descent.
                prev_direction = -funct_derivative;
            }
            else
            {
                const double beta = trans(funct_derivative-prev_derivative)*funct_derivative/(prev_grad_sqr);
                prev_direction = -funct_derivative + beta*prev_direction;
            }
        }

        // Remember this gradient for the next update.
        prev_derivative = funct_derivative;
        return prev_direction;
    }

private:
    bool been_used;                      // false until the first direction has been produced
    matrix<double,0,1> prev_derivative;  // gradient passed to the previous call
    matrix<double,0,1> prev_direction;   // direction returned by the previous call
};
// ----------------------------------------------------------------------------------------
class bfgs_strategy
{
    /*!
        Search-direction strategy based on the BFGS quasi-Newton method.  The
        object maintains H, an approximation of the inverse Hessian, and
        refines it on every call from the change in position and gradient
        since the previous call.
    !*/
public:
    bfgs_strategy() : been_used(false) {}

    // rho value handed to the line search by callers of this strategy
    double get_wolfe_rho (
    ) const { return 0.01; }

    // sigma value handed to the line search by callers of this strategy
    double get_wolfe_sigma (
    ) const { return 0.9; }

    template <typename T>
    const matrix<double,0,1>& get_next_direction (
        const T& x,
        const double ,
        const T& funct_derivative
    )
    /*!
        requires
            - x is the current point and funct_derivative is the gradient of
              the objective function at x (the function value argument is not
              used by this strategy)
        ensures
            - returns the next direction to search along, -H*funct_derivative.
              On the first call H is the identity, so the result is the
              steepest descent direction.
            - the returned reference refers to a member of this object and is
              overwritten by the next call.
    !*/
    {
        if (been_used == false)
        {
            been_used = true;
            H = identity_matrix<double>(x.size());
        }
        else
        {
            // update H with the BFGS formula from (3.2.12) on page 55 of Fletcher
            delta = (x-prev_x);
            gamma = funct_derivative-prev_derivative;

            Hg = H*gamma;
            gH = trans(trans(gamma)*H);
            // Reuse Hg rather than multiplying by H a third time.
            const double gHg = trans(gamma)*Hg;
            const double dg = trans(delta)*gamma;

            // The update keeps H positive definite only when delta'*gamma > 0.
            // Requiring dg > 0 (rather than merely dg != 0) stops a negative
            // curvature estimate from making H indefinite, which could turn
            // -H*g into an ascent direction.  NaN values fail these
            // comparisons and so also take the reset branch.
            if (gHg < std::numeric_limits<double>::infinity() &&
                dg < std::numeric_limits<double>::infinity() &&
                dg > 0)
            {
                H += (1 + gHg/dg)*delta*trans(delta)/(dg) - (delta*trans(gH) + Hg*trans(delta))/(dg);
            }
            else
            {
                // The update is unusable, so restart from the identity.
                H = identity_matrix<double>(H.nr());
            }
        }

        // Remember the current state for the next update.
        prev_x = x;
        prev_direction = -H*funct_derivative;
        prev_derivative = funct_derivative;
        return prev_direction;
    }

private:
    bool been_used;                      // false until the first direction has been produced
    matrix<double,0,1> prev_x;           // point passed to the previous call
    matrix<double,0,1> prev_derivative;  // gradient passed to the previous call
    matrix<double,0,1> prev_direction;   // direction returned by the previous call
    matrix<double> H;                    // current approximation of the inverse Hessian
    matrix<double,0,1> delta, gamma, Hg, gH;  // scratch vectors reused between calls
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// Functions that transform other functions // Functions that transform other functions
...@@ -94,7 +217,16 @@ namespace dlib ...@@ -94,7 +217,16 @@ namespace dlib
// invoked through a reference) // invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false); COMPILE_TIME_ASSERT(is_function<funct>::value == false);
line_search_funct(const funct& f_, const T& start_, const T& direction_) : f(f_),start(start_), direction(direction_) line_search_funct(const funct& f_, const T& start_, const T& direction_)
: f(f_),start(start_), direction(direction_), matrix_r(0), scalar_r(0)
{}
line_search_funct(const funct& f_, const T& start_, const T& direction_, T& r)
: f(f_),start(start_), direction(direction_), matrix_r(&r), scalar_r(0)
{}
line_search_funct(const funct& f_, const T& start_, const T& direction_, double& r)
: f(f_),start(start_), direction(direction_), matrix_r(0), scalar_r(&r)
{} {}
double operator()(const double& x) const double operator()(const double& x) const
...@@ -102,10 +234,21 @@ namespace dlib ...@@ -102,10 +234,21 @@ namespace dlib
return get_value(f(start + x*direction)); return get_value(f(start + x*direction));
} }
// TODO figure out some requirements for these two functions and add them to the abstract
const T& get_last_scalar_eval (
) const { return *scalar_r; }
const T& get_last_gradient_eval (
) const { return *matrix_r; }
private: private:
double get_value (const double& r) const double get_value (const double& r) const
{ {
// save a copy of this value for later
if (scalar_r)
*scalar_r = r;
return r; return r;
} }
...@@ -115,12 +258,18 @@ namespace dlib ...@@ -115,12 +258,18 @@ namespace dlib
// U should be a matrix type // U should be a matrix type
COMPILE_TIME_ASSERT(is_matrix<U>::value); COMPILE_TIME_ASSERT(is_matrix<U>::value);
// save a copy of this value for later
if (matrix_r)
*matrix_r = r;
return trans(r)*direction; return trans(r)*direction;
} }
const funct& f; const funct& f;
const T& start; const T& start;
const T& direction; const T& direction;
T* matrix_r;
double* scalar_r;
}; };
template <typename funct, typename T> template <typename funct, typename T>
...@@ -135,15 +284,67 @@ namespace dlib ...@@ -135,15 +284,67 @@ namespace dlib
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
start.nc() == 1 && direction.nc() == 1, is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
"\tline_search_funct make_line_search_function(f,start,direction)" "\tline_search_funct make_line_search_function(f,start,direction)"
<< "\n\tYou have to supply column vectors to this function" << "\n\tYou have to supply column vectors to this function"
<< "\n\tstart.nc(): " << start.nc() << "\n\tstart.nc(): " << start.nc()
<< "\n\tdirection.nc(): " << direction.nc() << "\n\tdirection.nc(): " << direction.nc()
<< "\n\tstart.nr(): " << start.nr()
<< "\n\tdirection.nr(): " << direction.nr()
); );
return line_search_funct<funct,T>(f,start,direction); return line_search_funct<funct,T>(f,start,direction);
} }
// ----------------------------------------------------------------------------------------
// Builds a line_search_funct for f along `direction` starting from `start`.
// f_out is forwarded to the line_search_funct constructor that takes a double&;
// that object uses it to keep a copy of the most recent scalar value it evaluated.
// start, direction and f_out are held by reference/pointer, so they must outlive
// the returned object.
template <typename funct, typename T>
const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction, double& f_out)
{
    // You get an error on this line when you pass in a global function to this function.
    // You have to either use a function object or pass a pointer to your global function
    // by taking its address using the & operator.  (This check is here because gcc 4.0
    // has a bug that causes it to silently corrupt return values from functions that
    // are invoked through a reference)
    COMPILE_TIME_ASSERT(is_function<funct>::value == false);

    COMPILE_TIME_ASSERT(is_matrix<T>::value);
    DLIB_ASSERT (
        is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
        "\tline_search_funct make_line_search_function(f,start,direction,f_out)"
        << "\n\tYou have to supply column vectors of the same size to this function"
        << "\n\tstart.nc():     " << start.nc()
        << "\n\tdirection.nc(): " << direction.nc()
        << "\n\tstart.nr():     " << start.nr()
        << "\n\tdirection.nr(): " << direction.nr()
    );
    return line_search_funct<funct,T>(f,start,direction, f_out);
}
// ----------------------------------------------------------------------------------------
// Builds a line_search_funct for f along `direction` starting from `start`.
// grad_out is forwarded to the line_search_funct constructor that takes a T&;
// that object uses it to keep a copy of the most recent matrix (gradient) value
// it evaluated.  start, direction and grad_out are held by reference/pointer, so
// they must outlive the returned object.
template <typename funct, typename T>
const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction, T& grad_out)
{
    // You get an error on this line when you pass in a global function to this function.
    // You have to either use a function object or pass a pointer to your global function
    // by taking its address using the & operator.  (This check is here because gcc 4.0
    // has a bug that causes it to silently corrupt return values from functions that
    // are invoked through a reference)
    COMPILE_TIME_ASSERT(is_function<funct>::value == false);

    COMPILE_TIME_ASSERT(is_matrix<T>::value);
    DLIB_ASSERT (
        is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
        "\tline_search_funct make_line_search_function(f,start,direction,grad_out)"
        << "\n\tYou have to supply column vectors of the same size to this function"
        << "\n\tstart.nc():     " << start.nc()
        << "\n\tdirection.nc(): " << direction.nc()
        << "\n\tstart.nr():     " << start.nr()
        << "\n\tdirection.nr(): " << direction.nr()
    );
    return line_search_funct<funct,T>(f,start,direction,grad_out);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// Functions that perform unconstrained optimization // Functions that perform unconstrained optimization
...@@ -200,11 +401,12 @@ namespace dlib ...@@ -200,11 +401,12 @@ namespace dlib
> >
double line_search ( double line_search (
const funct& f, const funct& f,
const double f0,
const funct_der& der, const funct_der& der,
const double d0,
double rho, double rho,
double sigma, double sigma,
double minf, double minf
double& f0_out
) )
{ {
// You get an error on this line when you pass in a global function to this function. // You get an error on this line when you pass in a global function to this function.
...@@ -235,31 +437,20 @@ namespace dlib ...@@ -235,31 +437,20 @@ namespace dlib
const double tau2 = 1.0/10.0; const double tau2 = 1.0/10.0;
const double tau3 = 1.0/2.0; const double tau3 = 1.0/2.0;
const double f0 = f(0);
const double d0 = der(0);
// Stop right away and return a step size of 0 if the gradient is 0 at the starting point
if (std::abs(d0) < std::numeric_limits<double>::epsilon()) if (std::abs(d0) < std::numeric_limits<double>::epsilon())
return 0; return 0;
//DLIB_CASSERT(d0 < 0,d0);
// Figure out a reasonable upper bound on how large alpha can get.
const double mu = (minf-f0)/(rho*d0); const double mu = (minf-f0)/(rho*d0);
f0_out = f0; double alpha = 1;
if (mu < 0)
alpha = -alpha;
double f1 = f(mu); alpha = put_in_range(0, 0.65*mu, alpha);
double d1 = der(mu);
// pick the initial alpha by guessing at the minimum alpha
double alpha = mu*poly_min_extrap(f0, d0, f1, d1);
alpha = put_in_range(0.1*mu, 0.9*mu, alpha);
DLIB_CASSERT(alpha < std::numeric_limits<double>::infinity(),
"alpha: " << alpha << " mu: " << mu << " f0: " << f0 << " d0: " << d0 << " f1: " << f1 << " d1: " << d1
);
using namespace std;
//cout << "alpha: " << alpha << " mu: " << mu << " f0: " << f0 << " d0: " << d0 << " f1: " << f1 << " d1: " << d1 << endl;
double last_alpha = 0; double last_alpha = 0;
double last_val = f0; double last_val = f0;
...@@ -269,16 +460,15 @@ namespace dlib ...@@ -269,16 +460,15 @@ namespace dlib
// that contains a reasonable solution to the line search // that contains a reasonable solution to the line search
double a, b; double a, b;
// the value of f(a) // These variables will hold the values and derivatives of f(a) and f(b)
double a_val, b_val, a_val_der, b_val_der; double a_val, b_val, a_val_der, b_val_der;
// This thresh value represents the Wolfe curvature condition
const double thresh = std::abs(sigma*d0); const double thresh = std::abs(sigma*d0);
// do the bracketing stage to find the bracket range [a,b] // do the bracketing stage to find the bracket range [a,b]
while (true) while (true)
{ {
//cout << "alpha: " << alpha << " mu: " << mu << " f0: " << f0 << " d0: " << d0 << " f1: " << f1 << " d1: " << d1 << endl;
DLIB_CASSERT(alpha < std::numeric_limits<double>::infinity(), alpha);
const double val = f(alpha); const double val = f(alpha);
const double val_der = der(alpha); const double val_der = der(alpha);
...@@ -287,7 +477,7 @@ namespace dlib ...@@ -287,7 +477,7 @@ namespace dlib
if (val <= minf) if (val <= minf)
return alpha; return alpha;
if (val > f0 + alpha*d0 || val >= last_val) if (val > f0 + rho*alpha*d0 || val >= last_val)
{ {
a_val = last_val; a_val = last_val;
a_val_der = last_val_der; a_val_der = last_val_der;
...@@ -353,17 +543,9 @@ namespace dlib ...@@ -353,17 +543,9 @@ namespace dlib
} }
DLIB_CASSERT(alpha < std::numeric_limits<double>::infinity(),
"alpha: " << alpha << " mu: " << mu << " f0: " << f0 << " d0: " << d0 << " f1: " << f1 << " d1: " << d1
);
// Now do the sectioning phase from 2.6.4 // Now do the sectioning phase from 2.6.4
while (true) while (true)
{ {
//cout << "alpha: " << alpha << " mu: " << mu << " f0: " << f0 << " d0: " << d0 << " f1: " << f1 << " d1: " << d1 << endl;
DLIB_CASSERT(alpha < std::numeric_limits<double>::infinity(),
"alpha: " << alpha << " mu: " << mu << " f0: " << f0 << " d0: " << d0 << " f1: " << f1 << " d1: " << d1
);
double first = a + tau2*(b-a); double first = a + tau2*(b-a);
double last = b - tau3*(b-a); double last = b - tau3*(b-a);
...@@ -407,14 +589,18 @@ namespace dlib ...@@ -407,14 +589,18 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
typename strategy_type,
typename funct, typename funct,
typename funct_der, typename funct_der,
typename T typename T
> >
void find_min_quasi_newton ( void find_min (
strategy_type& strategy,
const funct& f, const funct& f,
const funct_der& der, const funct_der& der,
T& x, T& x,
...@@ -433,62 +619,141 @@ namespace dlib ...@@ -433,62 +619,141 @@ namespace dlib
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
min_delta >= 0 && x.nc() == 1, min_delta >= 0 && x.nc() == 1,
"\tdouble find_min_quasi_newton()" "\tdouble find_min()"
<< "\n\tYou have to supply column vectors to this function" << "\n\tYou have to supply column vectors to this function"
<< "\n\tmin_delta: " << min_delta << "\n\tmin_delta: " << min_delta
<< "\n\tx.nc(): " << x.nc() << "\n\tx.nc(): " << x.nc()
); );
T g, g2, s, Hg, gH;
double alpha = 10;
matrix<double,T::NR,T::NR> H(x.nr(),x.nr()); T g, s;
T delta, gamma; double alpha = 1;
H = identity_matrix<double>(H.nr());
double f_value = f(x);
g = der(x); g = der(x);
double old_f_value = f_value + 10;
double f_value = min_f - 1;
double old_f_value = 0;
// loop until the derivative is almost zero // loop until there isn't any large change in the function value
while(std::abs(old_f_value - f_value) > min_delta) while(std::abs(old_f_value - f_value) > min_delta )
{ {
old_f_value = f_value; s = strategy.get_next_direction(x, f_value, g);
s = -H*g; old_f_value = f_value;
alpha = line_search(make_line_search_function(f,x,s),make_line_search_function(der,x,s),0.01, 0.9,min_f, f_value); alpha = line_search(
make_line_search_function(f,x,s, f_value),
f_value,
make_line_search_function(der,x,s, g),
dot(g,s),
strategy.get_wolfe_rho(), strategy.get_wolfe_sigma(), min_f);
x += alpha*s; x += alpha*s;
}
}
g2 = der(x); // ----------------------------------------------------------------------------------------
// update H with the BFGS formula from (3.2.12) on page 55 of Fletcher template <
delta = alpha*s; typename strategy_type,
gamma = g2-g; typename funct,
typename T
>
void find_min_using_approximate_derivatives (
strategy_type& strategy,
const funct& f,
T& x,
double min_f,
double min_delta = 1e-7,
double derivative_eps = 1e-7
)
{
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
Hg = H*gamma; COMPILE_TIME_ASSERT(is_matrix<T>::value);
gH = trans(trans(gamma)*H); DLIB_ASSERT (
double gHg = trans(gamma)*H*gamma; min_delta >= 0 && x.nc() == 1 && derivative_eps > 0,
double dg = trans(delta)*gamma; "\tdouble find_min_using_approximate_derivatives()"
if (gHg < std::numeric_limits<double>::infinity() && dg < std::numeric_limits<double>::infinity() && << "\n\tYou have to supply column vectors to this function"
dg != 0) << "\n\tmin_delta: " << min_delta
{ << "\n\tx.nc(): " << x.nc()
H += (1 + gHg/dg)*delta*trans(delta)/(dg) - (delta*trans(gH) + Hg*trans(delta))/(dg); << "\n\tderivative_eps: " << derivative_eps
} );
else
{
H = identity_matrix<double>(H.nr());
}
g.swap(g2); T g, s;
double alpha = 1;
double f_value = f(x);
double old_f_value = f_value + 10;
// loop until there isn't any large change in the function value
while(std::abs(old_f_value - f_value) > min_delta )
{
g = derivative(f,derivative_eps)(x);
s = strategy.get_next_direction(x, f_value, g);
old_f_value = f_value;
alpha = line_search(
make_line_search_function(f,x,s,f_value),
f_value,
derivative(make_line_search_function(f,x,s),derivative_eps),
dot(g,s), // TODO. Maybe this line isn't better than the following one??
//derivative(make_line_search_function(f,x,s),derivative_eps)(0),
strategy.get_wolfe_rho(), strategy.get_wolfe_sigma(), min_f);
x += alpha*s;
} }
} }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename funct,
typename funct_der,
typename T
>
void find_min_quasi_newton (
const funct& f,
const funct_der& der,
T& x,
double min_f,
double min_delta = 1e-7
)
{
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_function<funct_der>::value == false);
COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT (
min_delta >= 0 && x.nc() == 1,
"\tdouble find_min_quasi_newton()"
<< "\n\tYou have to supply column vectors to this function"
<< "\n\tmin_delta: " << min_delta
<< "\n\tx.nc(): " << x.nc()
);
bfgs_strategy strategy;
find_min(strategy, f, der, x, min_f, min_delta);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -521,38 +786,8 @@ namespace dlib ...@@ -521,38 +786,8 @@ namespace dlib
<< "\n\tx.nc(): " << x.nc() << "\n\tx.nc(): " << x.nc()
); );
T g, g2, s; cg_strategy strategy;
double alpha = 0; find_min(strategy, f, der, x, min_f, min_delta);
g = der(x);
s = -g;
double f_value = min_f - 1;
double old_f_value = 0;
// loop until the derivative is almost zero
while(std::abs(old_f_value - f_value) > min_delta)
{
old_f_value = f_value;
alpha = line_search(make_line_search_function(f,x,s),make_line_search_function(der,x,s),0.001, 0.010,min_f, f_value);
x += alpha*s;
g2 = der(x);
const double temp = trans(g)*g;
// just stop if this value hits zero
if (std::abs(temp) < std::numeric_limits<double>::epsilon())
break;
// Use the Polak-Ribiere (4.1.12) conjugate gradient described by Fletcher on page 83
double b = trans(g2-g)*g2/(temp);
s = -g2 + b*s;
g.swap(g2);
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -579,64 +814,15 @@ namespace dlib ...@@ -579,64 +814,15 @@ namespace dlib
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
min_delta >= 0 && x.nc() == 1, min_delta >= 0 && x.nc() == 1,
"\tdouble find_min_quasi_newton()" "\tdouble find_min_quasi_newton2()"
<< "\n\tYou have to supply column vectors to this function" << "\n\tYou have to supply column vectors to this function"
<< "\n\tmin_delta: " << min_delta << "\n\tmin_delta: " << min_delta
<< "\n\tx.nc(): " << x.nc() << "\n\tx.nc(): " << x.nc()
<< "\n\tderivative_eps: " << derivative_eps << "\n\tderivative_eps: " << derivative_eps
); );
T g, g2, s, Hg, gH; bfgs_strategy strategy;
double alpha = 10; find_min_using_approximate_derivatives(strategy, f, x, min_f, min_delta, derivative_eps);
matrix<double,T::NR,T::NR> H(x.nr(),x.nr());
T delta, gamma;
H = identity_matrix<double>(H.nr());
g = derivative(f,derivative_eps)(x);
double f_value = min_f - 1;
double old_f_value = 0;
// loop until there isn't any large change in the function value
while(std::abs(old_f_value - f_value) > min_delta)
{
old_f_value = f_value;
s = -H*g;
alpha = line_search(
make_line_search_function(f,x,s),
derivative(make_line_search_function(f,x,s),derivative_eps),
0.01, 0.9,min_f, f_value);
x += alpha*s;
g2 = derivative(f,derivative_eps)(x);
// update H with the BFGS formula from (3.2.12) on page 55 of Fletcher
delta = alpha*s;
gamma = g2-g;
Hg = H*gamma;
gH = trans(trans(gamma)*H);
double gHg = trans(gamma)*H*gamma;
double dg = trans(delta)*gamma;
if (gHg < std::numeric_limits<double>::infinity() && dg < std::numeric_limits<double>::infinity() &&
dg != 0)
{
H += (1 + gHg/dg)*delta*trans(delta)/(dg) - (delta*trans(gH) + Hg*trans(delta))/(dg);
}
else
{
H = identity_matrix<double>(H.nr());
}
g.swap(g2);
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -663,47 +849,15 @@ namespace dlib ...@@ -663,47 +849,15 @@ namespace dlib
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
min_delta >= 0 && x.nc() == 1 && derivative_eps > 0, min_delta >= 0 && x.nc() == 1 && derivative_eps > 0,
"\tdouble find_min_conjugate_gradient()" "\tdouble find_min_conjugate_gradient2()"
<< "\n\tYou have to supply column vectors to this function" << "\n\tYou have to supply column vectors to this function"
<< "\n\tmin_delta: " << min_delta << "\n\tmin_delta: " << min_delta
<< "\n\tx.nc(): " << x.nc() << "\n\tx.nc(): " << x.nc()
<< "\n\tderivative_eps: " << derivative_eps << "\n\tderivative_eps: " << derivative_eps
); );
T g, g2, s; cg_strategy strategy;
double alpha = 0; find_min_using_approximate_derivatives(strategy, f, x, min_f, min_delta, derivative_eps);
g = derivative(f,derivative_eps)(x);
s = -g;
double f_value = min_f - 1;
double old_f_value = 0;
// loop until there isn't any large change in the function value
while(std::abs(old_f_value - f_value) > min_delta)
{
old_f_value = f_value;
alpha = line_search(
make_line_search_function(f,x,s),
derivative(make_line_search_function(f,x,s),derivative_eps),
0.001, 0.010,min_f, f_value);
x += alpha*s;
g2 = derivative(f,derivative_eps)(x);
// Use the Polak-Ribiere (4.1.12) conjugate gradient described by Fletcher on page 83
const double temp = trans(g)*g;
// just stop if this value hits zero
if (std::abs(temp) < std::numeric_limits<double>::epsilon())
break;
double b = trans(g2-g)*g2/(temp);
s = -g2 + b*s;
g.swap(g2);
}
} }
......
...@@ -88,9 +88,8 @@ namespace dlib ...@@ -88,9 +88,8 @@ namespace dlib
requires requires
- is_matrix<T>::value == true (i.e. T must be a dlib::matrix) - is_matrix<T>::value == true (i.e. T must be a dlib::matrix)
- f must take a dlib::matrix that is a column vector - f must take a dlib::matrix that is a column vector
- start.nc() == 1 - is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size()
- direction.nc() == 1 (i.e. start and direction should be column vectors of the same size)
(i.e. start and direction should be column vectors)
- f must return either a double or a column vector the same length and - f must return either a double or a column vector the same length and
type as start type as start
- f(start + 1.5*direction) should be a valid expression - f(start + 1.5*direction) should be a valid expression
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment