Commit 5cf9b7e1 authored by Davis King's avatar Davis King

Added negate_function and find_max versions of the find_min functions.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403186
parent 0e63193e
...@@ -212,6 +212,17 @@ namespace dlib ...@@ -212,6 +212,17 @@ namespace dlib
public: public:
lbfgs_search_strategy(unsigned long max_size_) : max_size(max_size_), been_used(false) {} lbfgs_search_strategy(unsigned long max_size_) : max_size(max_size_), been_used(false) {}
lbfgs_search_strategy(const lbfgs_search_strategy& item)
{
max_size = item.max_size;
been_used = item.been_used;
prev_x = item.prev_x;
prev_derivative = item.prev_derivative;
prev_direction = item.prev_direction;
alpha = item.alpha;
dh_temp = item.dh_temp;
}
double get_wolfe_rho ( double get_wolfe_rho (
) const { return 0.01; } ) const { return 0.01; }
...@@ -393,6 +404,34 @@ namespace dlib ...@@ -393,6 +404,34 @@ namespace dlib
return central_differences<funct>(f,eps); return central_differences<funct>(f,eps);
} }
// ----------------------------------------------------------------------------------------
template <typename funct>
class negate_function_object
{
public:
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
negate_function_object(const funct& f_) : f(f_){}
template <typename T>
double operator()(const T& x) const
{
return -f(x);
}
private:
const funct& f;
};
template <typename funct>
const negate_function_object<funct> negate_function(const funct& f) { return negate_function_object<funct>(f); }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename funct, typename T> template <typename funct, typename T>
...@@ -835,6 +874,65 @@ namespace dlib ...@@ -835,6 +874,65 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
template <
typename search_strategy_type,
typename stop_strategy_type,
typename funct,
typename funct_der,
typename T
>
void find_max (
search_strategy_type search_strategy,
stop_strategy_type stop_strategy,
const funct& f,
const funct_der& der,
T& x,
double max_f
)
{
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_function<funct_der>::value == false);
COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT (
is_col_vector(x),
"\tdouble find_max()"
<< "\n\tYou have to supply column vectors to this function"
<< "\n\tx.nc(): " << x.nc()
);
T g, s;
double f_value = -f(x);
g = -der(x);
while(stop_strategy.should_continue_search(x, f_value, g))
{
s = search_strategy.get_next_direction(x, f_value, g);
double alpha = line_search(
negate_function(make_line_search_function(f,x,s, f_value)),
f_value,
negate_function(make_line_search_function(der,x,s, g)),
dot(g,s), // compute initial gradient for the line search
search_strategy.get_wolfe_rho(), search_strategy.get_wolfe_sigma(), -max_f);
x += alpha*s;
// we have to negate the outputs from the line search
g *= -1;
f_value *= -1;
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -892,6 +990,49 @@ namespace dlib ...@@ -892,6 +990,49 @@ namespace dlib
} }
// ----------------------------------------------------------------------------------------
template <
typename search_strategy_type,
typename stop_strategy_type,
typename funct,
typename T
>
void find_max_using_approximate_derivatives (
search_strategy_type search_strategy,
stop_strategy_type stop_strategy,
const funct& f,
T& x,
double max_f,
double derivative_eps = 1e-7
)
{
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT (
is_col_vector(x) && derivative_eps > 0,
"\tdouble find_max_using_approximate_derivatives()"
<< "\n\tYou have to supply column vectors to this function"
<< "\n\tx.nc(): " << x.nc()
<< "\n\tderivative_eps: " << derivative_eps
);
find_min_using_approximate_derivatives(
search_strategy,
stop_strategy,
negate_function(f),
x,
-max_f,
derivative_eps
);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment