Commit b1909d5c authored by Davis King's avatar Davis King

merged

parents 08a89c80 28d76d01
......@@ -368,6 +368,31 @@ namespace mex_binding
struct is_column_major_matrix<matrix<T,num_rows,num_cols,mem_manager,column_major_layout> >
{ static const bool value = true; };
// -------------------------------------------------------
string escape_percent(const string& str)
{
string temp;
for(auto c : str)
{
if (c != '%')
{
temp += c;
}
else
{
temp += c;
temp += c;
}
}
return temp;
}
string escape_percent(const std::ostringstream& sout)
{
return escape_percent(sout.str());
}
// -------------------------------------------------------
template <
......@@ -386,14 +411,14 @@ namespace mex_binding
std::ostringstream sout;
sout << "Argument " << arg_idx+1 << " expects a matrix with " << matrix_type::NR << " rows but got one with " << src.nc();
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str());
escape_percent(sout).c_str());
}
if (matrix_type::NC != 0 && matrix_type::NC != src.nr())
{
std::ostringstream sout;
sout << "Argument " << arg_idx+1 << " expects a matrix with " << matrix_type::NC << " columns but got one with " << src.nr();
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str());
escape_percent(sout).c_str());
}
......@@ -429,7 +454,7 @@ namespace mex_binding
std::ostringstream sout;
sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str());
escape_percent(sout).c_str());
}
......@@ -451,7 +476,7 @@ namespace mex_binding
std::ostringstream sout;
sout << "Error, input argument " << arg_idx+1 << " must be a non-negative number.";
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str());
escape_percent(sout).c_str());
}
else
{
......@@ -473,7 +498,7 @@ namespace mex_binding
std::ostringstream sout;
sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str());
escape_percent(sout).c_str());
}
......@@ -500,7 +525,7 @@ namespace mex_binding
std::ostringstream sout;
sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str());
escape_percent(sout).c_str());
}
......@@ -567,7 +592,7 @@ namespace mex_binding
std::ostringstream sout;
sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str());
escape_percent(sout).c_str());
}
// -------------------------------------------------------
......@@ -584,7 +609,7 @@ namespace mex_binding
std::ostringstream sout;
sout << "mex_function has some bug in it related to processing input argument " << arg_idx+1;
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
sout.str().c_str());
escape_percent(sout).c_str());
}
template <typename MM>
......@@ -2913,7 +2938,7 @@ namespace mex_binding
<< " and " << expected_nrhs << " input arguments, got " << nrhs << ".";
mexErrMsgIdAndTxt("mex_function:nrhs",
sout.str().c_str());
escape_percent(sout).c_str());
}
if (nlhs > expected_nlhs)
......@@ -2922,7 +2947,7 @@ namespace mex_binding
sout << "Expected at most " << expected_nlhs << " output arguments, got " << nlhs << ".";
mexErrMsgIdAndTxt("mex_function:nlhs",
sout.str().c_str());
escape_percent(sout).c_str());
}
call_mex_function_helper<sig_traits<funct>::num_args> helper;
......@@ -4988,7 +5013,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
catch (mex_binding::invalid_args_exception& e)
{
mexErrMsgIdAndTxt("mex_function:validate_and_populate_arg",
("Input" + e.msg).c_str());
mex_binding::escape_percent("Input" + e.msg).c_str());
}
catch (mex_binding::user_hit_ctrl_c& )
{
......@@ -4997,7 +5022,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
catch (std::exception& e)
{
mexErrMsgIdAndTxt("mex_function:error",
e.what());
mex_binding::escape_percent(e.what()).c_str());
}
cout << flush;
......
......@@ -20,7 +20,8 @@ namespace dlib
explicit rls(
double forget_factor_,
double C_ = 1000
double C_ = 1000,
bool apply_forget_factor_to_C_ = false
)
{
// make sure requires clause is not broken
......@@ -36,6 +37,7 @@ namespace dlib
C = C_;
forget_factor = forget_factor_;
apply_forget_factor_to_C = apply_forget_factor_to_C_;
}
rls(
......@@ -43,6 +45,7 @@ namespace dlib
{
C = 1000;
forget_factor = 1;
apply_forget_factor_to_C = false;
}
double get_c(
......@@ -57,6 +60,12 @@ namespace dlib
return forget_factor;
}
bool should_apply_forget_factor_to_C (
) const
{
return apply_forget_factor_to_C;
}
template <typename EXP>
void train (
const matrix_exp<EXP>& x,
......@@ -84,20 +93,25 @@ namespace dlib
// multiply by forget factor and incorporate x*trans(x) into R.
const double l = 1.0/forget_factor;
const double temp = 1 + l*trans(x)*R*x;
matrix<double,0,1> tmp = R*x;
tmp = R*x;
R = l*R - l*l*(tmp*trans(tmp))/temp;
// Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
// identity matrix back in to keep the regularization alive.
if (forget_factor != 1 && !apply_forget_factor_to_C)
add_eye_to_inv(R, (1-forget_factor)/C);
// R should always be symmetric. This line improves numeric stability of this algorithm.
if (cnt%100 == 0)
R = 0.5*(R + trans(R));
++cnt;
w = w + R*x*(y - trans(x)*w);
}
const matrix<double,0,1>& get_w(
) const
{
......@@ -145,26 +159,38 @@ namespace dlib
friend inline void serialize(const rls& item, std::ostream& out)
{
int version = 1;
int version = 2;
serialize(version, out);
serialize(item.w, out);
serialize(item.R, out);
serialize(item.C, out);
serialize(item.forget_factor, out);
serialize(item.cnt, out);
serialize(item.apply_forget_factor_to_C, out);
}
friend inline void deserialize(rls& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
if (!(1 <= version && version <= 2))
throw dlib::serialization_error("Unknown version number found while deserializing rls object.");
if (version >= 1)
{
deserialize(item.w, in);
deserialize(item.R, in);
deserialize(item.C, in);
deserialize(item.forget_factor, in);
}
item.cnt = 0;
item.apply_forget_factor_to_C = false;
if (version >= 2)
{
deserialize(item.cnt, in);
deserialize(item.apply_forget_factor_to_C, in);
}
}
private:
......@@ -189,6 +215,13 @@ namespace dlib
matrix<double> R;
double C;
double forget_factor;
int cnt = 0;
bool apply_forget_factor_to_C;
// This object is here only to avoid reallocation during training. It don't
// logically contribute to the state of this object.
matrix<double,0,1> tmp;
};
// ----------------------------------------------------------------------------------------
......
......@@ -37,7 +37,8 @@ namespace dlib
explicit rls(
double forget_factor,
double C = 1000
double C = 1000,
bool apply_forget_factor_to_C = false
);
/*!
requires
......@@ -47,6 +48,7 @@ namespace dlib
- #get_w().size() == 0
- #get_c() == C
- #get_forget_factor() == forget_factor
- #should_apply_forget_factor_to_C() == apply_forget_factor_to_C
!*/
rls(
......@@ -56,6 +58,7 @@ namespace dlib
- #get_w().size() == 0
- #get_c() == 1000
- #get_forget_factor() == 1
- #should_apply_forget_factor_to_C() == false
!*/
double get_c(
......@@ -80,6 +83,18 @@ namespace dlib
zero the faster old examples are forgotten.
!*/
bool should_apply_forget_factor_to_C (
) const;
/*!
ensures
- If this function returns false then it means we are optimizing the
objective function discussed in the WHAT THIS OBJECT REPRESENTS section
above. However, if it returns true then we will allow the forget factor
(get_forget_factor()) to be applied to the C value which causes the
algorithm to slowly increase C and convert into a textbook version of RLS
without regularization. The main reason you might want to do this is
because it can make the algorithm run significantly faster.
!*/
template <typename EXP>
void train (
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment