Commit fa59a5c0 authored by Davis King's avatar Davis King

Made the spec more clear and also added some tests

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402433
parent b21526cb
...@@ -25,6 +25,10 @@ namespace dlib ...@@ -25,6 +25,10 @@ namespace dlib
/*! /*!
This is a function object that represents the derivative of some other This is a function object that represents the derivative of some other
function. function.
Note that if funct is a function of a double then the derivative of
funct is just a double but if funct is a function of a dlib::matrix (i.e. a
function of many variables) then its derivative is a gradient vector.
!*/ !*/
template < template <
...@@ -65,8 +69,10 @@ namespace dlib ...@@ -65,8 +69,10 @@ namespace dlib
> >
class line_search_funct; class line_search_funct;
/*! /*!
This is a function object l(double x) == f(start + x*direction). That is, it This object is a function object that represents a line search function.
represents the function l(double x) and takes f, start, and direction as arguments.
It represents a function with the signature:
double l(double x)
!*/ !*/
template < template <
...@@ -80,18 +86,24 @@ namespace dlib ...@@ -80,18 +86,24 @@ namespace dlib
); );
/*! /*!
requires requires
- f == a function that returns a scalar
- f must take a dlib::matrix that is a column vector
- is_matrix<T>::value == true (i.e. T must be a dlib::matrix) - is_matrix<T>::value == true (i.e. T must be a dlib::matrix)
- f must take a dlib::matrix that is a column vector
- start.nc() == 1 - start.nc() == 1
- direction.nc() == 1 - direction.nc() == 1
(i.e. start and direction should be column vectors) (i.e. start and direction should be column vectors)
- f(start + 1.5*direction) should be a valid expression that - f must return either a double or a column vector the same length and
evaluates to a double type as start
- f(start + 1.5*direction) should be a valid expression
ensures ensures
- returns a function that represents the function l(double x) - if (f returns a double) then
that is defined as: - returns a line search function that computes l(x) == f(start + x*direction)
- l(x) == f(start + x*direction) - else
- returns a line search function that computes l(x) == trans(f(start + x*direction))*direction
- We assume that f is the derivative of some other function and that what
f returns is a gradient vector.
So the following two expressions both create the derivative of l(x):
- derivative(make_line_search_function(funct,start,direction))
- make_line_search_function(derivative(funct),start,direction)
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -398,6 +398,39 @@ namespace ...@@ -398,6 +398,39 @@ namespace
p(3) = 1; p(3) = 1;
test_powell(p); test_powell(p);
{
matrix<double,2,1> m;
m(0) = -0.43;
m(1) = 0.919;
DLIB_CASSERT(dlib::equal(der_rosen(m) , derivative(&rosen)(m),1e-5),"");
DLIB_CASSERT(std::abs(derivative(make_line_search_function(&rosen,m,m))(0) -
make_line_search_function(derivative(&rosen),m,m)(0)) < 1e-5,"");
DLIB_CASSERT(std::abs(derivative(make_line_search_function(&rosen,m,m))(1) -
make_line_search_function(derivative(&rosen),m,m)(1)) < 1e-5,"");
DLIB_CASSERT(std::abs(derivative(make_line_search_function(&rosen,m,m))(0) -
make_line_search_function(&der_rosen,m,m)(0)) < 1e-5,"");
DLIB_CASSERT(std::abs(derivative(make_line_search_function(&rosen,m,m))(1) -
make_line_search_function(&der_rosen,m,m)(1)) < 1e-5,"");
}
{
matrix<double,2,1> m;
m(0) = 1;
m(1) = 2;
DLIB_CASSERT(dlib::equal(der_rosen(m) , derivative(&rosen)(m),1e-5),"");
DLIB_CASSERT(std::abs(derivative(make_line_search_function(&rosen,m,m))(0) -
make_line_search_function(derivative(&rosen),m,m)(0)) < 1e-5,"");
DLIB_CASSERT(std::abs(derivative(make_line_search_function(&rosen,m,m))(1) -
make_line_search_function(derivative(&rosen),m,m)(1)) < 1e-5,"");
DLIB_CASSERT(std::abs(derivative(make_line_search_function(&rosen,m,m))(0) -
make_line_search_function(&der_rosen,m,m)(0)) < 1e-5,"");
DLIB_CASSERT(std::abs(derivative(make_line_search_function(&rosen,m,m))(1) -
make_line_search_function(&der_rosen,m,m)(1)) < 1e-5,"");
}
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment