Commit fdd035f4 authored by Davis King's avatar Davis King

Removed a bunch of checks that prevented users from using references to

functions with the optimization code and forced the use of function pointers.
This was to avoid triggering a bug in gcc 4.0.  Since that compiler is no
longer officially supported by dlib I've removed these checks to increase
usability.
parent 583110af
...@@ -23,13 +23,6 @@ namespace dlib ...@@ -23,13 +23,6 @@ namespace dlib
class central_differences class central_differences
{ {
public: public:
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
central_differences(const funct& f_, double eps_ = 1e-7) : f(f_), eps(eps_){} central_differences(const funct& f_, double eps_ = 1e-7) : f(f_), eps(eps_){}
template <typename T> template <typename T>
...@@ -100,13 +93,6 @@ namespace dlib ...@@ -100,13 +93,6 @@ namespace dlib
template <typename funct> template <typename funct>
const central_differences<funct> derivative(const funct& f, double eps) const central_differences<funct> derivative(const funct& f, double eps)
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
DLIB_ASSERT ( DLIB_ASSERT (
eps > 0, eps > 0,
"\tcentral_differences derivative(f,eps)" "\tcentral_differences derivative(f,eps)"
...@@ -122,13 +108,6 @@ namespace dlib ...@@ -122,13 +108,6 @@ namespace dlib
class negate_function_object class negate_function_object
{ {
public: public:
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
negate_function_object(const funct& f_) : f(f_){} negate_function_object(const funct& f_) : f(f_){}
template <typename T> template <typename T>
...@@ -201,14 +180,6 @@ namespace dlib ...@@ -201,14 +180,6 @@ namespace dlib
double min_f double min_f
) )
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_function<funct_der>::value == false);
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
is_col_vector(x), is_col_vector(x),
...@@ -266,14 +237,6 @@ namespace dlib ...@@ -266,14 +237,6 @@ namespace dlib
double max_f double max_f
) )
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_function<funct_der>::value == false);
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
is_col_vector(x), is_col_vector(x),
...@@ -338,13 +301,6 @@ namespace dlib ...@@ -338,13 +301,6 @@ namespace dlib
double derivative_eps = 1e-7 double derivative_eps = 1e-7
) )
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
is_col_vector(x) && derivative_eps > 0, is_col_vector(x) && derivative_eps > 0,
...@@ -405,13 +361,6 @@ namespace dlib ...@@ -405,13 +361,6 @@ namespace dlib
double derivative_eps = 1e-7 double derivative_eps = 1e-7
) )
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
is_col_vector(x) && derivative_eps > 0, is_col_vector(x) && derivative_eps > 0,
......
...@@ -3353,14 +3353,6 @@ L210: ...@@ -3353,14 +3353,6 @@ L210:
const long max_f_evals const long max_f_evals
) )
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
// check the requirements. Also split the assert up so that the error message isn't huge. // check the requirements. Also split the assert up so that the error message isn't huge.
DLIB_CASSERT(is_col_vector(x) && is_col_vector(x_lower) && is_col_vector(x_upper) && DLIB_CASSERT(is_col_vector(x) && is_col_vector(x_lower) && is_col_vector(x_upper) &&
x.size() == x_lower.size() && x_lower.size() == x_upper.size() && x.size() == x_lower.size() && x_lower.size() == x_upper.size() &&
......
...@@ -19,13 +19,6 @@ namespace dlib ...@@ -19,13 +19,6 @@ namespace dlib
class line_search_funct class line_search_funct
{ {
public: public:
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
line_search_funct(const funct& f_, const T& start_, const T& direction_) line_search_funct(const funct& f_, const T& start_, const T& direction_)
: f(f_),start(start_), direction(direction_), matrix_r(0), scalar_r(0) : f(f_),start(start_), direction(direction_), matrix_r(0), scalar_r(0)
{} {}
...@@ -77,13 +70,6 @@ namespace dlib ...@@ -77,13 +70,6 @@ namespace dlib
template <typename funct, typename T> template <typename funct, typename T>
const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction) const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction)
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(), is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
...@@ -102,13 +88,6 @@ namespace dlib ...@@ -102,13 +88,6 @@ namespace dlib
template <typename funct, typename T> template <typename funct, typename T>
const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction, double& f_out) const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction, double& f_out)
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(), is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
...@@ -127,13 +106,6 @@ namespace dlib ...@@ -127,13 +106,6 @@ namespace dlib
template <typename funct, typename T> template <typename funct, typename T>
const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction, T& grad_out) const line_search_funct<funct,T> make_line_search_function(const funct& f, const T& start, const T& direction, T& grad_out)
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_matrix<T>::value); COMPILE_TIME_ASSERT(is_matrix<T>::value);
DLIB_ASSERT ( DLIB_ASSERT (
is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(), is_col_vector(start) && is_col_vector(direction) && start.size() == direction.size(),
...@@ -269,14 +241,6 @@ namespace dlib ...@@ -269,14 +241,6 @@ namespace dlib
unsigned long max_iter unsigned long max_iter
) )
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
COMPILE_TIME_ASSERT(is_function<funct_der>::value == false);
DLIB_ASSERT ( DLIB_ASSERT (
0 < rho && rho < sigma && sigma < 1 && max_iter > 0, 0 < rho && rho < sigma && sigma < 1 && max_iter > 0,
"\tdouble line_search()" "\tdouble line_search()"
...@@ -529,13 +493,6 @@ namespace dlib ...@@ -529,13 +493,6 @@ namespace dlib
const long max_iter = 100 const long max_iter = 100
) )
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
DLIB_CASSERT( eps > 0 && DLIB_CASSERT( eps > 0 &&
max_iter > 1 && max_iter > 1 &&
begin <= starting_point && starting_point <= end, begin <= starting_point && starting_point <= end,
...@@ -790,13 +747,6 @@ namespace dlib ...@@ -790,13 +747,6 @@ namespace dlib
const long max_iter = 100 const long max_iter = 100
) )
{ {
// You get an error on this line when you pass in a global function to this function.
// You have to either use a function object or pass a pointer to your global function
// by taking its address using the & operator. (This check is here because gcc 4.0
// has a bug that causes it to silently corrupt return values from functions that
// invoked through a reference)
COMPILE_TIME_ASSERT(is_function<funct>::value == false);
return -find_min_single_variable(negate_function(f), starting_point, begin, end, eps, max_iter); return -find_min_single_variable(negate_function(f), starting_point, begin, end, eps, max_iter);
} }
......
...@@ -313,7 +313,7 @@ namespace dlib ...@@ -313,7 +313,7 @@ namespace dlib
template <typename hessian_funct> template <typename hessian_funct>
newton_search_strategy_obj<hessian_funct> newton_search_strategy ( newton_search_strategy_obj<hessian_funct> newton_search_strategy (
const hessian_funct& hessian hessian_funct hessian
) { return newton_search_strategy_obj<hessian_funct>(hessian); } ) { return newton_search_strategy_obj<hessian_funct>(hessian); }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -313,7 +313,7 @@ namespace dlib ...@@ -313,7 +313,7 @@ namespace dlib
template <typename hessian_funct> template <typename hessian_funct>
newton_search_strategy_obj<hessian_funct> newton_search_strategy ( newton_search_strategy_obj<hessian_funct> newton_search_strategy (
const hessian_funct& hessian hessian_funct hessian
) { return newton_search_strategy_obj<hessian_funct>(hessian); } ) { return newton_search_strategy_obj<hessian_funct>(hessian); }
/*! /*!
ensures ensures
......
...@@ -35,8 +35,8 @@ namespace ...@@ -35,8 +35,8 @@ namespace
ch = chebyquad_start(2); ch = chebyquad_start(2);
solve_least_squares(objective_delta_stop_strategy(1e-13, 80), solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
&chebyquad_residual, chebyquad_residual,
derivative(&chebyquad_residual), derivative(chebyquad_residual),
range(0,ch.size()-1), range(0,ch.size()-1),
ch); ch);
...@@ -53,8 +53,8 @@ namespace ...@@ -53,8 +53,8 @@ namespace
ch = chebyquad_start(2); ch = chebyquad_start(2);
solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
&chebyquad_residual, chebyquad_residual,
derivative(&chebyquad_residual), derivative(chebyquad_residual),
range(0,ch.size()-1), range(0,ch.size()-1),
ch); ch);
...@@ -73,8 +73,8 @@ namespace ...@@ -73,8 +73,8 @@ namespace
ch = chebyquad_start(2); ch = chebyquad_start(2);
solve_least_squares(objective_delta_stop_strategy(1e-13, 80), solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
&chebyquad_residual, chebyquad_residual,
derivative(&chebyquad_residual), derivative(chebyquad_residual),
range(0,ch.size()-1), range(0,ch.size()-1),
ch); ch);
...@@ -92,8 +92,8 @@ namespace ...@@ -92,8 +92,8 @@ namespace
ch = chebyquad_start(2); ch = chebyquad_start(2);
solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
&chebyquad_residual, chebyquad_residual,
derivative(&chebyquad_residual), derivative(chebyquad_residual),
range(0,ch.size()-1), range(0,ch.size()-1),
ch); ch);
...@@ -112,8 +112,8 @@ namespace ...@@ -112,8 +112,8 @@ namespace
ch = chebyquad_start(4); ch = chebyquad_start(4);
solve_least_squares(objective_delta_stop_strategy(1e-13, 80), solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
&chebyquad_residual, chebyquad_residual,
derivative(&chebyquad_residual), derivative(chebyquad_residual),
range(0,ch.size()-1), range(0,ch.size()-1),
ch); ch);
...@@ -131,8 +131,8 @@ namespace ...@@ -131,8 +131,8 @@ namespace
ch = chebyquad_start(4); ch = chebyquad_start(4);
solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
&chebyquad_residual, chebyquad_residual,
derivative(&chebyquad_residual), derivative(chebyquad_residual),
range(0,ch.size()-1), range(0,ch.size()-1),
ch); ch);
...@@ -152,8 +152,8 @@ namespace ...@@ -152,8 +152,8 @@ namespace
ch = chebyquad_start(6); ch = chebyquad_start(6);
solve_least_squares(objective_delta_stop_strategy(1e-13, 80), solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
&chebyquad_residual, chebyquad_residual,
derivative(&chebyquad_residual), derivative(chebyquad_residual),
range(0,ch.size()-1), range(0,ch.size()-1),
ch); ch);
...@@ -174,8 +174,8 @@ namespace ...@@ -174,8 +174,8 @@ namespace
ch = chebyquad_start(6); ch = chebyquad_start(6);
solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
&chebyquad_residual, chebyquad_residual,
derivative(&chebyquad_residual), derivative(chebyquad_residual),
range(0,ch.size()-1), range(0,ch.size()-1),
ch); ch);
...@@ -195,8 +195,8 @@ namespace ...@@ -195,8 +195,8 @@ namespace
ch = chebyquad_start(8); ch = chebyquad_start(8);
solve_least_squares(objective_delta_stop_strategy(1e-13, 80), solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
&chebyquad_residual, chebyquad_residual,
derivative(&chebyquad_residual), derivative(chebyquad_residual),
range(0,ch.size()-1), range(0,ch.size()-1),
ch); ch);
...@@ -214,8 +214,8 @@ namespace ...@@ -214,8 +214,8 @@ namespace
ch = chebyquad_start(8); ch = chebyquad_start(8);
solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
&chebyquad_residual, chebyquad_residual,
derivative(&chebyquad_residual), derivative(chebyquad_residual),
range(0,ch.size()-1), range(0,ch.size()-1),
ch); ch);
...@@ -239,8 +239,8 @@ namespace ...@@ -239,8 +239,8 @@ namespace
ch = brown_start(); ch = brown_start();
solve_least_squares(objective_delta_stop_strategy(1e-13, 300), solve_least_squares(objective_delta_stop_strategy(1e-13, 300),
&brown_residual, brown_residual,
derivative(&brown_residual), derivative(brown_residual),
range(1,20), range(1,20),
ch); ch);
...@@ -258,8 +258,8 @@ namespace ...@@ -258,8 +258,8 @@ namespace
ch = brown_start(); ch = brown_start();
solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
&brown_residual, brown_residual,
derivative(&brown_residual), derivative(brown_residual),
range(1,20), range(1,20),
ch); ch);
...@@ -301,8 +301,8 @@ namespace ...@@ -301,8 +301,8 @@ namespace
ch = rosen_start<double>(); ch = rosen_start<double>();
solve_least_squares(objective_delta_stop_strategy(1e-13, 80), solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
&rosen_residual_double, rosen_residual_double,
&rosen_residual_derivative_double, rosen_residual_derivative_double,
range(1,20), range(1,20),
ch); ch);
...@@ -319,8 +319,8 @@ namespace ...@@ -319,8 +319,8 @@ namespace
ch = rosen_start<double>(); ch = rosen_start<double>();
solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
&rosen_residual_double, rosen_residual_double,
&rosen_residual_derivative_double, rosen_residual_derivative_double,
range(1,20), range(1,20),
ch); ch);
...@@ -340,8 +340,8 @@ namespace ...@@ -340,8 +340,8 @@ namespace
ch = rosen_start<double>(); ch = rosen_start<double>();
solve_least_squares(objective_delta_stop_strategy(1e-13, 80), solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
&rosen_residual_double, rosen_residual_double,
derivative(&rosen_residual_double), derivative(rosen_residual_double),
range(1,20), range(1,20),
ch); ch);
...@@ -358,8 +358,8 @@ namespace ...@@ -358,8 +358,8 @@ namespace
ch = rosen_start<float>(); ch = rosen_start<float>();
solve_least_squares(objective_delta_stop_strategy(1e-13, 80), solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
&rosen_residual_float, rosen_residual_float,
derivative(&rosen_residual_float), derivative(rosen_residual_float),
range(1,20), range(1,20),
ch); ch);
...@@ -376,8 +376,8 @@ namespace ...@@ -376,8 +376,8 @@ namespace
ch = rosen_start<float>(); ch = rosen_start<float>();
solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
&rosen_residual_float, rosen_residual_float,
derivative(&rosen_residual_float), derivative(rosen_residual_float),
range(1,20), range(1,20),
ch); ch);
...@@ -394,8 +394,8 @@ namespace ...@@ -394,8 +394,8 @@ namespace
ch = rosen_start<double>(); ch = rosen_start<double>();
solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80), solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
&rosen_residual_double, rosen_residual_double,
derivative(&rosen_residual_double), derivative(rosen_residual_double),
range(1,20), range(1,20),
ch); ch);
...@@ -412,8 +412,8 @@ namespace ...@@ -412,8 +412,8 @@ namespace
ch = rosen_big_start<double>(); ch = rosen_big_start<double>();
solve_least_squares(objective_delta_stop_strategy(1e-13, 80), solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
&rosen_big_residual_double, rosen_big_residual_double,
derivative(&rosen_big_residual_double), derivative(rosen_big_residual_double),
range(1,2), range(1,2),
ch); ch);
......
This diff is collapsed.
...@@ -111,7 +111,7 @@ int main() ...@@ -111,7 +111,7 @@ int main()
// the approximate derivative computed using central differences (via derivative()). // the approximate derivative computed using central differences (via derivative()).
// If this value is big then it means we probably typed the derivative function incorrectly. // If this value is big then it means we probably typed the derivative function incorrectly.
cout << "derivative error: " << length(residual_derivative(data_samples[0], params) - cout << "derivative error: " << length(residual_derivative(data_samples[0], params) -
derivative(&residual)(data_samples[0], params) ) << endl; derivative(residual)(data_samples[0], params) ) << endl;
...@@ -126,8 +126,8 @@ int main() ...@@ -126,8 +126,8 @@ int main()
// Use the Levenberg-Marquardt method to determine the parameters which // Use the Levenberg-Marquardt method to determine the parameters which
// minimize the sum of all squared residuals. // minimize the sum of all squared residuals.
solve_least_squares_lm(objective_delta_stop_strategy(1e-7).be_verbose(), solve_least_squares_lm(objective_delta_stop_strategy(1e-7).be_verbose(),
&residual, residual,
&residual_derivative, residual_derivative,
data_samples, data_samples,
x); x);
...@@ -144,8 +144,8 @@ int main() ...@@ -144,8 +144,8 @@ int main()
// If we didn't create the residual_derivative function then we could // If we didn't create the residual_derivative function then we could
// have used this method which numerically approximates the derivatives for you. // have used this method which numerically approximates the derivatives for you.
solve_least_squares_lm(objective_delta_stop_strategy(1e-7).be_verbose(), solve_least_squares_lm(objective_delta_stop_strategy(1e-7).be_verbose(),
&residual, residual,
derivative(&residual), derivative(residual),
data_samples, data_samples,
x); x);
...@@ -163,8 +163,8 @@ int main() ...@@ -163,8 +163,8 @@ int main()
// where the residuals don't go to zero at the solution. So in these cases // where the residuals don't go to zero at the solution. So in these cases
// it may provide a better answer. // it may provide a better answer.
solve_least_squares(objective_delta_stop_strategy(1e-7).be_verbose(), solve_least_squares(objective_delta_stop_strategy(1e-7).be_verbose(),
&residual, residual,
&residual_derivative, residual_derivative,
data_samples, data_samples,
x); x);
......
...@@ -175,7 +175,7 @@ int main() ...@@ -175,7 +175,7 @@ int main()
// the results are similar. If they are very different then you probably made a // the results are similar. If they are very different then you probably made a
// mistake. So the first thing we do is compare the results at a test point: // mistake. So the first thing we do is compare the results at a test point:
cout << "Difference between analytic derivative and numerical approximation of derivative: " cout << "Difference between analytic derivative and numerical approximation of derivative: "
<< length(derivative(&rosen)(starting_point) - rosen_derivative(starting_point)) << endl; << length(derivative(rosen)(starting_point) - rosen_derivative(starting_point)) << endl;
cout << "Find the minimum of the rosen function()" << endl; cout << "Find the minimum of the rosen function()" << endl;
...@@ -194,7 +194,7 @@ int main() ...@@ -194,7 +194,7 @@ int main()
find_min(bfgs_search_strategy(), // Use BFGS search algorithm find_min(bfgs_search_strategy(), // Use BFGS search algorithm
objective_delta_stop_strategy(1e-7), // Stop when the change in rosen() is less than 1e-7 objective_delta_stop_strategy(1e-7), // Stop when the change in rosen() is less than 1e-7
&rosen, &rosen_derivative, starting_point, -1); rosen, rosen_derivative, starting_point, -1);
// Once the function ends the starting_point vector will contain the optimum point // Once the function ends the starting_point vector will contain the optimum point
// of (1,1). // of (1,1).
cout << "rosen solution:\n" << starting_point << endl; cout << "rosen solution:\n" << starting_point << endl;
...@@ -207,7 +207,7 @@ int main() ...@@ -207,7 +207,7 @@ int main()
starting_point = -94, 5.2; starting_point = -94, 5.2;
find_min_using_approximate_derivatives(bfgs_search_strategy(), find_min_using_approximate_derivatives(bfgs_search_strategy(),
objective_delta_stop_strategy(1e-7), objective_delta_stop_strategy(1e-7),
&rosen, starting_point, -1); rosen, starting_point, -1);
// Again the correct minimum point is found and stored in starting_point // Again the correct minimum point is found and stored in starting_point
cout << "rosen solution:\n" << starting_point << endl; cout << "rosen solution:\n" << starting_point << endl;
...@@ -222,14 +222,14 @@ int main() ...@@ -222,14 +222,14 @@ int main()
find_min(lbfgs_search_strategy(10), // The 10 here is basically a measure of how much memory L-BFGS will use. find_min(lbfgs_search_strategy(10), // The 10 here is basically a measure of how much memory L-BFGS will use.
objective_delta_stop_strategy(1e-7).be_verbose(), // Adding be_verbose() causes a message to be objective_delta_stop_strategy(1e-7).be_verbose(), // Adding be_verbose() causes a message to be
// printed for each iteration of optimization. // printed for each iteration of optimization.
&rosen, &rosen_derivative, starting_point, -1); rosen, rosen_derivative, starting_point, -1);
cout << endl << "rosen solution: \n" << starting_point << endl; cout << endl << "rosen solution: \n" << starting_point << endl;
starting_point = -94, 5.2; starting_point = -94, 5.2;
find_min_using_approximate_derivatives(lbfgs_search_strategy(10), find_min_using_approximate_derivatives(lbfgs_search_strategy(10),
objective_delta_stop_strategy(1e-7), objective_delta_stop_strategy(1e-7),
&rosen, starting_point, -1); rosen, starting_point, -1);
cout << "rosen solution: \n"<< starting_point << endl; cout << "rosen solution: \n"<< starting_point << endl;
...@@ -237,10 +237,10 @@ int main() ...@@ -237,10 +237,10 @@ int main()
// In many cases, it is useful if we also provide second derivative information // In many cases, it is useful if we also provide second derivative information
// to the optimizers. Two examples of how we can do that are shown below. // to the optimizers. Two examples of how we can do that are shown below.
starting_point = 0.8, 1.3; starting_point = 0.8, 1.3;
find_min(newton_search_strategy(&rosen_hessian), find_min(newton_search_strategy(rosen_hessian),
objective_delta_stop_strategy(1e-7), objective_delta_stop_strategy(1e-7),
&rosen, rosen,
&rosen_derivative, rosen_derivative,
starting_point, starting_point,
-1); -1);
cout << "rosen solution: \n"<< starting_point << endl; cout << "rosen solution: \n"<< starting_point << endl;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment