Commit b1909d5c authored by Davis King's avatar Davis King

merged

parents 08a89c80 28d76d01
...@@ -368,6 +368,31 @@ namespace mex_binding ...@@ -368,6 +368,31 @@ namespace mex_binding
struct is_column_major_matrix<matrix<T,num_rows,num_cols,mem_manager,column_major_layout> > struct is_column_major_matrix<matrix<T,num_rows,num_cols,mem_manager,column_major_layout> >
{ static const bool value = true; }; { static const bool value = true; };
// -------------------------------------------------------
string escape_percent(const string& str)
{
string temp;
for(auto c : str)
{
if (c != '%')
{
temp += c;
}
else
{
temp += c;
temp += c;
}
}
return temp;
}
string escape_percent(const std::ostringstream& sout)
{
return escape_percent(sout.str());
}
// ------------------------------------------------------- // -------------------------------------------------------
template < template <
...@@ -386,14 +411,14 @@ namespace mex_binding ...@@ -386,14 +411,14 @@ namespace mex_binding
std::ostringstream sout; std::ostringstream sout;
sout << "Argument " << arg_idx+1 << " expects a matrix with " << matrix_type::NR << " rows but got one with " << src.nc(); sout << "Argument " << arg_idx+1 << " expects a matrix with " << matrix_type::NR << " rows but got one with " << src.nc();
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str()); escape_percent(sout).c_str());
} }
if (matrix_type::NC != 0 && matrix_type::NC != src.nr()) if (matrix_type::NC != 0 && matrix_type::NC != src.nr())
{ {
std::ostringstream sout; std::ostringstream sout;
sout << "Argument " << arg_idx+1 << " expects a matrix with " << matrix_type::NC << " columns but got one with " << src.nr(); sout << "Argument " << arg_idx+1 << " expects a matrix with " << matrix_type::NC << " columns but got one with " << src.nr();
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str()); escape_percent(sout).c_str());
} }
...@@ -429,7 +454,7 @@ namespace mex_binding ...@@ -429,7 +454,7 @@ namespace mex_binding
std::ostringstream sout; std::ostringstream sout;
sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1; sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str()); escape_percent(sout).c_str());
} }
...@@ -451,7 +476,7 @@ namespace mex_binding ...@@ -451,7 +476,7 @@ namespace mex_binding
std::ostringstream sout; std::ostringstream sout;
sout << "Error, input argument " << arg_idx+1 << " must be a non-negative number."; sout << "Error, input argument " << arg_idx+1 << " must be a non-negative number.";
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str()); escape_percent(sout).c_str());
} }
else else
{ {
...@@ -473,7 +498,7 @@ namespace mex_binding ...@@ -473,7 +498,7 @@ namespace mex_binding
std::ostringstream sout; std::ostringstream sout;
sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1; sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str()); escape_percent(sout).c_str());
} }
...@@ -500,7 +525,7 @@ namespace mex_binding ...@@ -500,7 +525,7 @@ namespace mex_binding
std::ostringstream sout; std::ostringstream sout;
sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1; sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str()); escape_percent(sout).c_str());
} }
...@@ -567,7 +592,7 @@ namespace mex_binding ...@@ -567,7 +592,7 @@ namespace mex_binding
std::ostringstream sout; std::ostringstream sout;
sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1; sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str()); escape_percent(sout).c_str());
} }
// ------------------------------------------------------- // -------------------------------------------------------
...@@ -584,7 +609,7 @@ namespace mex_binding ...@@ -584,7 +609,7 @@ namespace mex_binding
std::ostringstream sout; std::ostringstream sout;
sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1; sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str()); escape_percent(sout).c_str());
} }
template <typename MM> template <typename MM>
...@@ -2913,7 +2938,7 @@ namespace mex_binding ...@@ -2913,7 +2938,7 @@ namespace mex_binding
<< " and " << expected_nrhs << " input arguments, got " << nrhs << "."; << " and " << expected_nrhs << " input arguments, got " << nrhs << ".";
mexErrMsgIdAndTxt("mex_function:nrhs", mexErrMsgIdAndTxt("mex_function:nrhs",
sout.str().c_str()); escape_percent(sout).c_str());
} }
if (nlhs > expected_nlhs) if (nlhs > expected_nlhs)
...@@ -2922,7 +2947,7 @@ namespace mex_binding ...@@ -2922,7 +2947,7 @@ namespace mex_binding
sout << "Expected at most " << expected_nlhs << " output arguments, got " << nlhs << "."; sout << "Expected at most " << expected_nlhs << " output arguments, got " << nlhs << ".";
mexErrMsgIdAndTxt("mex_function:nlhs", mexErrMsgIdAndTxt("mex_function:nlhs",
sout.str().c_str()); escape_percent(sout).c_str());
} }
call_mex_function_helper<sig_traits<funct>::num_args> helper; call_mex_function_helper<sig_traits<funct>::num_args> helper;
...@@ -4988,7 +5013,7 @@ void mexFunction( int nlhs, mxArray *plhs[], ...@@ -4988,7 +5013,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
catch (mex_binding::invalid_args_exception& e) catch (mex_binding::invalid_args_exception& e)
{ {
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg", mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
("Input" + e.msg).c_str()); mex_binding::escape_percent("Input" + e.msg).c_str());
} }
catch (mex_binding::user_hit_ctrl_c& ) catch (mex_binding::user_hit_ctrl_c& )
{ {
...@@ -4997,7 +5022,7 @@ void mexFunction( int nlhs, mxArray *plhs[], ...@@ -4997,7 +5022,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
catch (std::exception& e) catch (std::exception& e)
{ {
mexErrMsgIdAndTxt("mex_function:error", mexErrMsgIdAndTxt("mex_function:error",
e.what()); mex_binding::escape_percent(e.what()).c_str());
} }
cout << flush; cout << flush;
......
...@@ -20,7 +20,8 @@ namespace dlib ...@@ -20,7 +20,8 @@ namespace dlib
explicit rls( explicit rls(
double forget_factor_, double forget_factor_,
double C_ = 1000 double C_ = 1000,
bool apply_forget_factor_to_C_ = false
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -36,6 +37,7 @@ namespace dlib ...@@ -36,6 +37,7 @@ namespace dlib
C = C_; C = C_;
forget_factor = forget_factor_; forget_factor = forget_factor_;
apply_forget_factor_to_C = apply_forget_factor_to_C_;
} }
rls( rls(
...@@ -43,6 +45,7 @@ namespace dlib ...@@ -43,6 +45,7 @@ namespace dlib
{ {
C = 1000; C = 1000;
forget_factor = 1; forget_factor = 1;
apply_forget_factor_to_C = false;
} }
double get_c( double get_c(
...@@ -57,6 +60,12 @@ namespace dlib ...@@ -57,6 +60,12 @@ namespace dlib
return forget_factor; return forget_factor;
} }
bool should_apply_forget_factor_to_C (
) const
{
return apply_forget_factor_to_C;
}
template <typename EXP> template <typename EXP>
void train ( void train (
const matrix_exp<EXP>& x, const matrix_exp<EXP>& x,
...@@ -84,20 +93,25 @@ namespace dlib ...@@ -84,20 +93,25 @@ namespace dlib
// multiply by forget factor and incorporate x*trans(x) into R. // multiply by forget factor and incorporate x*trans(x) into R.
const double l = 1.0/forget_factor; const double l = 1.0/forget_factor;
const double temp = 1 + l*trans(x)*R*x; const double temp = 1 + l*trans(x)*R*x;
matrix<double,0,1> tmp = R*x; tmp = R*x;
R = l*R - l*l*(tmp*trans(tmp))/temp; R = l*R - l*l*(tmp*trans(tmp))/temp;
// Since we multiplied by the forget factor, we need to add (1-forget_factor) of the // Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
// identity matrix back in to keep the regularization alive. // identity matrix back in to keep the regularization alive.
if (forget_factor != 1 && !apply_forget_factor_to_C)
add_eye_to_inv(R, (1-forget_factor)/C); add_eye_to_inv(R, (1-forget_factor)/C);
// R should always be symmetric. This line improves numeric stability of this algorithm. // R should always be symmetric. This line improves numeric stability of this algorithm.
if (cnt%100 == 0)
R = 0.5*(R + trans(R)); R = 0.5*(R + trans(R));
++cnt;
w = w + R*x*(y - trans(x)*w); w = w + R*x*(y - trans(x)*w);
} }
const matrix<double,0,1>& get_w( const matrix<double,0,1>& get_w(
) const ) const
{ {
...@@ -145,26 +159,38 @@ namespace dlib ...@@ -145,26 +159,38 @@ namespace dlib
friend inline void serialize(const rls& item, std::ostream& out) friend inline void serialize(const rls& item, std::ostream& out)
{ {
int version = 1; int version = 2;
serialize(version, out); serialize(version, out);
serialize(item.w, out); serialize(item.w, out);
serialize(item.R, out); serialize(item.R, out);
serialize(item.C, out); serialize(item.C, out);
serialize(item.forget_factor, out); serialize(item.forget_factor, out);
serialize(item.cnt, out);
serialize(item.apply_forget_factor_to_C, out);
} }
friend inline void deserialize(rls& item, std::istream& in) friend inline void deserialize(rls& item, std::istream& in)
{ {
int version = 0; int version = 0;
deserialize(version, in); deserialize(version, in);
if (version != 1) if (!(1 <= version && version <= 2))
throw dlib::serialization_error("Unknown version number found while deserializing rls object."); throw dlib::serialization_error("Unknown version number found while deserializing rls object.");
if (version >= 1)
{
deserialize(item.w, in); deserialize(item.w, in);
deserialize(item.R, in); deserialize(item.R, in);
deserialize(item.C, in); deserialize(item.C, in);
deserialize(item.forget_factor, in); deserialize(item.forget_factor, in);
} }
item.cnt = 0;
item.apply_forget_factor_to_C = false;
if (version >= 2)
{
deserialize(item.cnt, in);
deserialize(item.apply_forget_factor_to_C, in);
}
}
private: private:
...@@ -189,6 +215,13 @@ namespace dlib ...@@ -189,6 +215,13 @@ namespace dlib
matrix<double> R; matrix<double> R;
double C; double C;
double forget_factor; double forget_factor;
int cnt = 0;
bool apply_forget_factor_to_C;
// This object is here only to avoid reallocation during training. It don't
// logically contribute to the state of this object.
matrix<double,0,1> tmp;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -37,7 +37,8 @@ namespace dlib ...@@ -37,7 +37,8 @@ namespace dlib
explicit rls( explicit rls(
double forget_factor, double forget_factor,
double C = 1000 double C = 1000,
bool apply_forget_factor_to_C = false
); );
/*! /*!
requires requires
...@@ -47,6 +48,7 @@ namespace dlib ...@@ -47,6 +48,7 @@ namespace dlib
- #get_w().size() == 0 - #get_w().size() == 0
- #get_c() == C - #get_c() == C
- #get_forget_factor() == forget_factor - #get_forget_factor() == forget_factor
- #should_apply_forget_factor_to_C() == apply_forget_factor_to_C
!*/ !*/
rls( rls(
...@@ -56,6 +58,7 @@ namespace dlib ...@@ -56,6 +58,7 @@ namespace dlib
- #get_w().size() == 0 - #get_w().size() == 0
- #get_c() == 1000 - #get_c() == 1000
- #get_forget_factor() == 1 - #get_forget_factor() == 1
- #should_apply_forget_factor_to_C() == false
!*/ !*/
double get_c( double get_c(
...@@ -80,6 +83,18 @@ namespace dlib ...@@ -80,6 +83,18 @@ namespace dlib
zero the faster old examples are forgotten. zero the faster old examples are forgotten.
!*/ !*/
bool should_apply_forget_factor_to_C (
) const;
/*!
ensures
- If this function returns false then it means we are optimizing the
objective function discussed in the WHAT THIS OBJECT REPRESENTS section
above. However, if it returns true then we will allow the forget factor
(get_forget_factor()) to be applied to the C value which causes the
algorithm to slowly increase C and convert into a textbook version of RLS
without regularization. The main reason you might want to do this is
because it can make the algorithm run significantly faster.
!*/
template <typename EXP> template <typename EXP>
void train ( void train (
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment