Commit dfd9543c authored by Davis King's avatar Davis King

Made the elastic_net inputs be in terms of trans(X)*X and trans(X)*Y rather

than raw X and Y matrices.
parent 69a12074
...@@ -17,13 +17,16 @@ namespace dlib ...@@ -17,13 +17,16 @@ namespace dlib
template <typename EXP> template <typename EXP>
explicit elastic_net( explicit elastic_net(
const matrix_exp<EXP>& X_ const matrix_exp<EXP>& XX
) : eps(1e-5), max_iterations(50000), verbose(false) ) : eps(1e-5), max_iterations(50000), verbose(false)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(X_.size() > 0, DLIB_ASSERT(XX.size() > 0 &&
"\t elastic_net::elastic_net(X)" XX.nr() == XX.nc(),
<< " \n\t X can't be empty" "\t elastic_net::elastic_net(XX)"
<< " \n\t XX must be a non-empty square matrix."
<< " \n\t XX.nr(): " << XX.nr()
<< " \n\t XX.nc(): " << XX.nc()
<< " \n\t this: " << this << " \n\t this: " << this
); );
...@@ -32,13 +35,11 @@ namespace dlib ...@@ -32,13 +35,11 @@ namespace dlib
// rows then we can get rid of them by doing some SVD magic. Doing this doesn't // rows then we can get rid of them by doing some SVD magic. Doing this doesn't
// make the final results of anything change but makes all the matrices have // make the final results of anything change but makes all the matrices have
// dimensions that are X.nr() in size, which can be much smaller. // dimensions that are X.nr() in size, which can be much smaller.
matrix<double> XX;
XX = X_*trans(X_);
matrix<double,0,1> s; matrix<double,0,1> s;
svd3(XX,u,eig_vals,eig_vects); svd3(XX,u,eig_vals,eig_vects);
s = sqrt(eig_vals); s = sqrt(eig_vals);
X = eig_vects*diagm(s); X = eig_vects*diagm(s);
u = trans(X_)*tmp(eig_vects*inv(diagm(s))); u = eig_vects*inv(diagm(s));
...@@ -65,46 +66,48 @@ namespace dlib ...@@ -65,46 +66,48 @@ namespace dlib
template <typename EXP1, typename EXP2> template <typename EXP1, typename EXP2>
elastic_net( elastic_net(
const matrix_exp<EXP1>& X_, const matrix_exp<EXP1>& XX,
const matrix_exp<EXP2>& Y_ const matrix_exp<EXP2>& XY
) : elastic_net(X_) ) : elastic_net(XX)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(X_.size() > 0 && DLIB_ASSERT(XX.size() > 0 &&
is_col_vector(Y_) && XX.nr() == XX.nc() &&
X_.nc() == Y_.size() , is_col_vector(XY) &&
"\t elastic_net::elastic_net(X,Y)" XX.nc() == XY.size() ,
"\t elastic_net::elastic_net(XX,XY)"
<< " \n\t Invalid inputs were given to this function." << " \n\t Invalid inputs were given to this function."
<< " \n\t X_.size(): " << X_.size() << " \n\t XX.size(): " << XX.size()
<< " \n\t is_col_vector(Y_): " << is_col_vector(Y_) << " \n\t is_col_vector(XY): " << is_col_vector(XY)
<< " \n\t X_.nc(): " << X_.nc() << " \n\t XX.nr(): " << XX.nr()
<< " \n\t Y_.size(): " << Y_.size() << " \n\t XX.nc(): " << XX.nc()
<< " \n\t XY.size(): " << XY.size()
<< " \n\t this: " << this << " \n\t this: " << this
); );
set_y(Y_); set_xy(XY);
} }
long size ( long size (
) const { return u.nr(); } ) const { return u.nr(); }
template <typename EXP> template <typename EXP>
void set_y( void set_xy(
const matrix_exp<EXP>& Y_ const matrix_exp<EXP>& XY
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(is_col_vector(Y_) && DLIB_ASSERT(is_col_vector(XY) &&
Y_.size() == size(), XY.size() == size(),
"\t void elastic_net::set_y(Y)" "\t void elastic_net::set_y(Y)"
<< " \n\t Invalid inputs were given to this function." << " \n\t Invalid inputs were given to this function."
<< " \n\t is_col_vector(Y_): " << is_col_vector(Y_) << " \n\t is_col_vector(XY): " << is_col_vector(XY)
<< " \n\t size(): " << size() << " \n\t size(): " << size()
<< " \n\t Y_.size(): " << Y_.size() << " \n\t XY.size(): " << XY.size()
<< " \n\t this: " << this << " \n\t this: " << this
); );
Y = trans(u)*Y_; Y = trans(u)*XY;
// We can use the ynorm after it has been projected because the only place Y // We can use the ynorm after it has been projected because the only place Y
// appears in the algorithm is in terms of dot products with w and x vectors. // appears in the algorithm is in terms of dot products with w and x vectors.
// But those vectors are always in the span of X and therefore we only see the // But those vectors are always in the span of X and therefore we only see the
......
...@@ -44,52 +44,61 @@ namespace dlib ...@@ -44,52 +44,61 @@ namespace dlib
template <typename EXP> template <typename EXP>
explicit elastic_net( explicit elastic_net(
const matrix_exp<EXP>& X const matrix_exp<EXP>& XX
); );
/*! /*!
requires requires
- X.size() != 0 - XX.size() != 0
- XX.nr() == XX.nc()
ensures ensures
- #get_epsilon() == 1e-5 - #get_epsilon() == 1e-5
- #get_max_iterations() == 50000 - #get_max_iterations() == 50000
- this object will not be verbose unless be_verbose() is called - This object will not be verbose unless be_verbose() is called.
- #size() == X.nc() - #size() == XX.nc()
- #have_target_values() == false - #have_target_values() == false
- We interpret XX as trans(X)*X where X is as defined in the objective
function discussed above in WHAT THIS OBJECT REPRESENTS.
!*/ !*/
template <typename EXP1, typename EXP2> template <typename EXP1, typename EXP2>
elastic_net( elastic_net(
const matrix_exp<EXP1>& X, const matrix_exp<EXP1>& XX,
const matrix_exp<EXP2>& Y const matrix_exp<EXP2>& XY
); );
/*! /*!
requires requires
- X.size() != 0 - XX.size() != 0
- is_col_vector(Y) - XX.nr() == XX.nc()
- X.nc() == Y.size() - is_col_vector(XY)
- XX.nc() == Y.size()
ensures ensures
- constructs this object by calling the elastic_net(X) constructor and then - constructs this object by calling the elastic_net(XX) constructor and
calling this->set_y(Y). then calling this->set_xy(XY).
- #have_target_values() == true - #have_target_values() == true
- We interpret XX as trans(X)*X where X is as defined in the objective
function discussed above in WHAT THIS OBJECT REPRESENTS. Similarly, XY
should be trans(X)*Y.
!*/ !*/
long size ( long size (
) const; ) const;
/*! /*!
ensures ensures
- returns the number of samples loaded into this object. - returns the dimensionality of the data loaded into this object. That is,
how many elements are in the optimal w vector? This function returns
that number.
!*/ !*/
bool have_target_values ( bool have_target_values (
) const; ) const;
/*! /*!
ensures ensures
- returns true if set_y() has been called and false otherwise. - returns true if set_xy() has been called and false otherwise.
!*/ !*/
template <typename EXP> template <typename EXP>
void set_y( void set_xy(
const matrix_exp<EXP>& Y const matrix_exp<EXP>& XY
); );
/*! /*!
requires requires
...@@ -97,8 +106,9 @@ namespace dlib ...@@ -97,8 +106,9 @@ namespace dlib
- Y.size() == size() - Y.size() == size()
ensures ensures
- #have_target_values() == true - #have_target_values() == true
- Sets the target values, the Y variable in the objective function, to the - Sets the target values of the regression. Note that we expect the given
given Y. matrix, XY, to be equal to trans(X)*Y, where X and Y have the definitions
discussed above in WHAT THIS OBJECT REPRESENTS.
!*/ !*/
void set_epsilon( void set_epsilon(
...@@ -164,6 +174,7 @@ namespace dlib ...@@ -164,6 +174,7 @@ namespace dlib
ensures ensures
- Solves the optimization problem described in the WHAT THIS OBJECT - Solves the optimization problem described in the WHAT THIS OBJECT
REPRESENTS section above and returns the optimal w. REPRESENTS section above and returns the optimal w.
- The returned vector has size() elements.
- if (lasso_budget == infinity) then - if (lasso_budget == infinity) then
- The lasso constraint is ignored - The lasso constraint is ignored
!*/ !*/
......
...@@ -95,7 +95,7 @@ namespace ...@@ -95,7 +95,7 @@ namespace
double lasso_budget = sum(abs(w)); double lasso_budget = sum(abs(w));
double eps = 0.0000001; double eps = 0.0000001;
dlib::elastic_net solver(X,Y); dlib::elastic_net solver(X*trans(X),X*Y);
solver.set_epsilon(eps); solver.set_epsilon(eps);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment