Commit cea4f73d authored by Davis King

Improved numeric stability of rls. This change also makes it keep the
regularization at a constant level even when exponential forgetting is
used.

parent 609e8de7
@@ -81,8 +81,16 @@ namespace dlib
     w = 0;
 }

+// multiply by forget factor and incorporate x*trans(x) into R.
 const double l = 1.0/forget_factor;
-R = l*R - (l*l*R*x*trans(x)*trans(R))/(1 + l*trans(x)*R*x);
+const double temp = 1 + l*trans(x)*R*x;
+matrix<double,0,1> tmp = R*x;
+R = l*R - l*l*(tmp*trans(tmp))/temp;
+
+// Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
+// identity matrix back in to keep the regularization alive.
+add_eye_to_inv(R, (1-forget_factor)/C);

 // R should always be symmetric.  This line improves numeric stability of this algorithm.
 R = 0.5*(R + trans(R));
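To see why the extra add_eye_to_inv() call keeps the regularization at a constant level, as the commit message claims, it helps to read the update in terms of the information matrix inv(R). The lines above multiply inv(R) by forget_factor, add x*trans(x) to it, and then add (1-forget_factor)/C of the identity back in. So if inv(R) carried a regularization term of I/C before the update, it carries exactly I/C after it as well, whereas the old update let that term decay towards zero. A minimal standalone sketch, not part of this commit, that only tracks the coefficient of the identity inside inv(R):

    #include <iostream>

    int main()
    {
        const double C = 1000, forget_factor = 0.8;
        double reg_old = 1.0/C;  // old update: the regularizer is forgotten along with the data
        double reg_new = 1.0/C;  // new update: (1-forget_factor)/C is added back on every step
        for (int i = 0; i < 50; ++i)
        {
            reg_old = forget_factor*reg_old;
            reg_new = forget_factor*reg_new + (1-forget_factor)/C;
        }
        std::cout << "old regularization after 50 steps: " << reg_old << "\n";  // ~1.4e-8, nearly gone
        std::cout << "new regularization after 50 steps: " << reg_new << "\n";  // still exactly 1/C
    }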
@@ -137,6 +145,23 @@ namespace dlib
 private:

+    void add_eye_to_inv(
+        matrix<double>& m,
+        double C
+    )
+    /*!
+        ensures
+            - Let m == inv(M)
+            - this function returns inv(M + C*identity_matrix<double>(m.nr()))
+    !*/
+    {
+        for (long r = 0; r < m.nr(); ++r)
+        {
+            m = m - colm(m,r)*trans(colm(m,r))/(1/C + m(r,r));
+        }
+    }
+
 matrix<double,0,1> w;
 matrix<double> R;
 double C;
...
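The new add_eye_to_inv() helper works entirely on the stored inverse: adding C to the r-th diagonal entry of M is the rank-one update M + C*e_r*trans(e_r), and by the Sherman-Morrison formula its inverse is inv(M) - colm(inv(M),r)*trans(colm(inv(M),r))/(1/C + inv(M)(r,r)), using the symmetry of inv(M). Looping over every r therefore adds C*identity_matrix to M without ever forming M itself. A small self-contained check of that identity, not part of the commit, using dlib's matrix utilities:

    #include <dlib/matrix.h>
    #include <dlib/rand.h>
    #include <iostream>
    using namespace dlib;

    int main()
    {
        dlib::rand rnd;
        const long n = 5;
        const double C = 10;

        // build a symmetric positive definite M and its inverse R
        matrix<double> A = randm(n,n,rnd);
        matrix<double> M = trans(A)*A + identity_matrix<double>(n);
        matrix<double> R = inv(M);

        // the same loop add_eye_to_inv() performs on the rls object's stored inverse
        for (long r = 0; r < R.nr(); ++r)
            R = R - colm(R,r)*trans(colm(R,r))/(1/C + R(r,r));

        // R should now equal inv(M + C*I), up to rounding error
        std::cout << max(abs(R - inv(M + C*identity_matrix<double>(n)))) << std::endl;
    }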
@@ -25,10 +25,11 @@ namespace dlib
     This object can also be configured to use exponential forgetting.  This is
     where each training example is weighted by pow(forget_factor, i), where i
     indicates the sample's age.  So older samples are weighted less in the
-    least squares solution and therefore become forgotten after some time.  Note
-    also that this forgetting applies to the regularizer as well.  So if forgetting
-    is used then this object slowly converts itself to an unregularized version
-    of recursive least squares.
+    least squares solution and therefore become forgotten after some time.
+    Therefore, with forgetting, this object solves the following optimization
+    problem at each step:
+        find w minimizing: 0.5*dot(w,w) + C*sum_i pow(forget_factor, i)*(y_i - trans(x_i)*w)^2
+    Where i starts at 0 and i==0 corresponds to the most recent training point.
 !*/

 public:
...
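Going by the documented behaviour above, here is a short usage sketch of exponential forgetting, not part of this commit. It assumes dlib::rls is pulled in via <dlib/svm.h> (the include is not shown in this diff) and uses only the rls(forget_factor, C) constructor, train(), and get_w() that appear in the test changes below; the data is made up:

    #include <dlib/svm.h>
    #include <dlib/rand.h>
    #include <iostream>
    using namespace dlib;

    int main()
    {
        dlib::rand rnd;
        const double forget_factor = 0.9;  // each step, older samples are down-weighted by 0.9
        const double C = 1000;             // regularization parameter
        rls r(forget_factor, C);

        // phase 1: y = 2*x(0); phase 2: y = -3*x(0).  With forgetting, the final weights
        // track the second phase instead of averaging the two.
        for (int i = 0; i < 400; ++i)
        {
            matrix<double,0,1> x = randm(3,1,rnd);
            const double y = (i < 200 ? 2.0 : -3.0)*x(0);
            r.train(x, y);
        }
        std::cout << "learned w: " << trans(r.get_w());  // first weight should be close to -3
    }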
@@ -25,7 +25,7 @@ namespace
 {
     dlib::rand rnd;

-    running_stats<double> rs1, rs2, rs3, rs4;
+    running_stats<double> rs1, rs2, rs3, rs4, rs5;

     for (int k = 0; k < 2; ++k)
     {
@@ -53,6 +53,32 @@ namespace
             rs1.add(length(r.get_w() - w));
         }

+        {
+            matrix<double> X = randm(size,num_vars,rnd);
+            matrix<double,0,1> Y = randm(size,1,rnd);
+            matrix<double,0,1> G(size,1);
+
+            const double C = 10000;
+            const double forget_factor = 0.8;
+            rls r(forget_factor, C);
+            for (long i = 0; i < Y.size(); ++i)
+            {
+                r.train(trans(rowm(X,i)), Y(i));
+                G(i) = std::pow(forget_factor, i/2.0);
+            }
+            G = flipud(G);
+
+            X = diagm(G)*X;
+            Y = diagm(G)*Y;
+
+            matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
+
+            rs5.add(length(r.get_w() - w));
+        }
+
         {
             matrix<double> X = randm(size,num_vars,rnd);
             matrix<double> Y = colm(X,0)*10;
@@ -124,20 +150,24 @@ namespace
dlog << LINFO << "rs2.mean(): " << rs2.mean(); dlog << LINFO << "rs2.mean(): " << rs2.mean();
dlog << LINFO << "rs3.mean(): " << rs3.mean(); dlog << LINFO << "rs3.mean(): " << rs3.mean();
dlog << LINFO << "rs4.mean(): " << rs4.mean(); dlog << LINFO << "rs4.mean(): " << rs4.mean();
dlog << LINFO << "rs5.mean(): " << rs5.mean();
dlog << LINFO << "rs1.max(): " << rs1.max(); dlog << LINFO << "rs1.max(): " << rs1.max();
dlog << LINFO << "rs2.max(): " << rs2.max(); dlog << LINFO << "rs2.max(): " << rs2.max();
dlog << LINFO << "rs3.max(): " << rs3.max(); dlog << LINFO << "rs3.max(): " << rs3.max();
dlog << LINFO << "rs4.max(): " << rs4.max(); dlog << LINFO << "rs4.max(): " << rs4.max();
dlog << LINFO << "rs5.max(): " << rs5.max();
DLIB_TEST_MSG(rs1.mean() < 1e-10, rs1.mean()); DLIB_TEST_MSG(rs1.mean() < 1e-10, rs1.mean());
DLIB_TEST_MSG(rs2.mean() < 1e-9, rs2.mean()); DLIB_TEST_MSG(rs2.mean() < 1e-9, rs2.mean());
DLIB_TEST_MSG(rs3.mean() < 1e-6, rs3.mean()); DLIB_TEST_MSG(rs3.mean() < 1e-6, rs3.mean());
DLIB_TEST_MSG(rs4.mean() < 1e-6, rs4.mean()); DLIB_TEST_MSG(rs4.mean() < 1e-6, rs4.mean());
DLIB_TEST_MSG(rs5.mean() < 1e-3, rs5.mean());
DLIB_TEST_MSG(rs1.max() < 1e-10, rs1.max()); DLIB_TEST_MSG(rs1.max() < 1e-10, rs1.max());
DLIB_TEST_MSG(rs2.max() < 1e-6, rs2.max()); DLIB_TEST_MSG(rs2.max() < 1e-6, rs2.max());
DLIB_TEST_MSG(rs3.max() < 0.001, rs3.max()); DLIB_TEST_MSG(rs3.max() < 0.001, rs3.max());
DLIB_TEST_MSG(rs4.max() < 0.01, rs4.max()); DLIB_TEST_MSG(rs4.max() < 0.01, rs4.max());
DLIB_TEST_MSG(rs5.max() < 0.1, rs5.max());
} }
...
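A note on the reference solution in the new rs5 test block above, with a small sketch (not from the commit) checking the algebra. Each squared residual should be weighted by pow(forget_factor, age), and scaling row i of both X and Y by g_i multiplies that row's squared residual by g_i^2, hence the square root in G(i) = pow(forget_factor, i/2.0). And because the rows of X are trained in order, the last row is the newest sample (age 0), hence the flipud(G) before scaling:

    #include <dlib/matrix.h>
    #include <dlib/rand.h>
    #include <cmath>
    #include <iostream>
    using namespace dlib;

    int main()
    {
        dlib::rand rnd;
        const long size = 5, num_vars = 3;
        const double forget_factor = 0.8;
        matrix<double> X = randm(size,num_vars,rnd);
        matrix<double,0,1> Y = randm(size,1,rnd);
        matrix<double,0,1> w = randm(num_vars,1,rnd);  // an arbitrary weight vector

        // direct evaluation of sum_i pow(forget_factor, age_i)*(y_i - trans(x_i)*w)^2
        double direct = 0;
        for (long i = 0; i < size; ++i)
        {
            const double age = size-1-i;  // the last row trained is the newest, so it has age 0
            const double res = Y(i) - dot(trans(rowm(X,i)), w);
            direct += std::pow(forget_factor, age)*res*res;
        }

        // the same quantity via the sqrt-weight row scaling used in the test
        matrix<double,0,1> G(size,1);
        for (long i = 0; i < size; ++i)
            G(i) = std::pow(forget_factor, i/2.0);
        G = flipud(G);
        const double scaled = length_squared(diagm(G)*(Y - X*w));

        std::cout << direct << " == " << scaled << std::endl;
    }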