Commit cea4f73d authored by Davis King

Improved numeric stability of rls. This change also makes it keep the
regularization at a constant level even when exponential forgetting is
used.

parent 609e8de7
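To see what "keep the regularization at a constant level" means numerically, here is a minimal sketch (plain C++, no dlib needed; the constants are illustrative). Each train() call scales the inverse-covariance matrix by forget_factor, so the identity/C regularization term used to decay away; adding (1-forget_factor)/C of the identity back on each step makes 1/C a fixed point of the recursion:

```cpp
#include <cstdio>

int main()
{
    const double C = 10000, ff = 0.8;   // C and forget_factor; values are illustrative
    double old_reg = 1/C;               // coefficient of the identity, old behavior
    double new_reg = 1/C;               // same coefficient with this commit's fix
    for (int t = 0; t < 20; ++t)
    {
        old_reg = ff*old_reg;                // decays like pow(ff,t): regularizer vanishes
        new_reg = ff*new_reg + (1 - ff)/C;   // fixed point at exactly 1/C
    }
    // prints old: ~1.15e-06   new: 0.0001
    std::printf("old: %g   new: %g\n", old_reg, new_reg);
}
```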
@@ -81,8 +81,16 @@ namespace dlib
                 w = 0;
             }
 
             // multiply by forget factor and incorporate x*trans(x) into R.
             const double l = 1.0/forget_factor;
-            R = l*R - (l*l*R*x*trans(x)*trans(R))/(1 + l*trans(x)*R*x);
+            const double temp = 1 + l*trans(x)*R*x;
+            matrix<double,0,1> tmp = R*x;
+            R = l*R - l*l*(tmp*trans(tmp))/temp;
+
+            // Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
+            // identity matrix back in to keep the regularization alive.
+            add_eye_to_inv(R, (1-forget_factor)/C);
+            // R should always be symmetric.  This line improves numeric stability of this algorithm.
+            R = 0.5*(R + trans(R));
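The rewritten update is algebraically identical to the old one when R is symmetric (so trans(R) == R): both are the Sherman-Morrison rank-one correction of l*R. The new form just computes R*x once and keeps the correction an explicit outer product tmp*trans(tmp), which is symmetric by construction. A quick equivalence check, a sketch assuming dlib's matrix library (the forget factor 0.8 is arbitrary):

```cpp
#include <dlib/matrix.h>
#include <iostream>

int main()
{
    using namespace dlib;
    matrix<double> A = randm(3,3);
    matrix<double> R = trans(A)*A + identity_matrix<double>(3);  // symmetric, like rls's R
    matrix<double,0,1> x = randm(3,1);
    const double l = 1.0/0.8;  // l = 1/forget_factor

    // Old form: rebuilds R*x twice inside one big expression.
    const double denom = 1 + l*trans(x)*R*x;
    matrix<double> old_R = l*R - (l*l*R*x*trans(x)*trans(R))/denom;

    // New form: one matrix-vector product, explicit rank-one symmetric update.
    const double temp = 1 + l*trans(x)*R*x;
    matrix<double,0,1> tmp = R*x;
    matrix<double> new_R = l*R - l*l*(tmp*trans(tmp))/temp;

    std::cout << max(abs(old_R - new_R)) << std::endl;  // should be ~1e-16
}
```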
@@ -137,6 +145,23 @@ namespace dlib
     private:
 
+        void add_eye_to_inv(
+            matrix<double>& m,
+            double C
+        )
+        /*!
+            ensures
+                - Let m == inv(M)
+                - this function updates m so that m == inv(M + C*identity_matrix<double>(m.nr()))
+        !*/
+        {
+            for (long r = 0; r < m.nr(); ++r)
+            {
+                m = m - colm(m,r)*trans(colm(m,r))/(1/C + m(r,r));
+            }
+        }
 
         matrix<double,0,1> w;
         matrix<double> R;
         double C;
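add_eye_to_inv works by applying the Sherman-Morrison formula once per coordinate: adding C to the r-th diagonal entry of M (i.e. adding C*e_r*trans(e_r)) changes inv(M) by -colm(m,r)*rowm(m,r)/(1/C + m(r,r)), and since m stays symmetric here, rowm(m,r) == trans(colm(m,r)). A standalone check of the ensures clause, a sketch assuming dlib (the matrix and C are arbitrary):

```cpp
#include <dlib/matrix.h>
#include <iostream>

// Copy of the routine above, pulled out of the class so it can be tested alone.
void add_eye_to_inv(dlib::matrix<double>& m, double C)
{
    for (long r = 0; r < m.nr(); ++r)
        m = m - dlib::colm(m,r)*dlib::trans(dlib::colm(m,r))/(1/C + m(r,r));
}

int main()
{
    using namespace dlib;
    matrix<double> A = randm(4,4);
    matrix<double> M = trans(A)*A + identity_matrix<double>(4);  // symmetric positive definite
    const double C = 0.25;

    matrix<double> m = inv(M);
    add_eye_to_inv(m, C);

    // Should print something near machine epsilon: m is now inv(M + C*I).
    std::cout << max(abs(m - inv(M + C*identity_matrix<double>(4)))) << std::endl;
}
```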
......
@@ -25,10 +25,11 @@ namespace dlib
                 This object can also be configured to use exponential forgetting.  This is
                 where each training example is weighted by pow(forget_factor, i), where i
                 indicates the sample's age.  So older samples are weighted less in the
-                least squares solution and therefore become forgotten after some time.  Note
-                also that this forgetting applies to the regularizer as well.  So if forgetting
-                is used then this object slowly converts itself to an unregularized version
-                of recursive least squares.
+                least squares solution and therefore become forgotten after some time.
+                Therefore, with forgetting, this object solves the following optimization
+                problem at each step:
+                    find w minimizing: 0.5*dot(w,w) + C*sum_i pow(forget_factor, i)*(y_i - trans(x_i)*w)^2
+                Where i starts at 0 and i==0 corresponds to the most recent training point.
         !*/
 
     public:
......
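For orientation, a minimal usage sketch of the class documented above, assuming dlib's headers (rls is assumed to come in via dlib/svm.h) and mirroring the forget_factor/C values the new test below uses:

```cpp
#include <dlib/svm.h>
#include <iostream>

int main()
{
    using namespace dlib;
    rls r(0.8, 10000);             // forget_factor = 0.8, C = 10000

    matrix<double,0,1> x(2);
    x = 1, 2;  r.train(x, 3.0);    // older sample: weighted by forget_factor
    x = 2, 1;  r.train(x, 1.0);    // most recent sample: weight 1

    std::cout << trans(r.get_w());    // current least squares solution w
    std::cout << r(x) << std::endl;   // prediction for x, i.e. dot(x, w)
}
```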
@@ -25,7 +25,7 @@ namespace
     {
         dlib::rand rnd;
-        running_stats<double> rs1, rs2, rs3, rs4;
+        running_stats<double> rs1, rs2, rs3, rs4, rs5;
 
         for (int k = 0; k < 2; ++k)
         {
@@ -53,6 +53,32 @@ namespace
                 rs1.add(length(r.get_w() - w));
             }
 
+            {
+                matrix<double> X = randm(size,num_vars,rnd);
+                matrix<double,0,1> Y = randm(size,1,rnd);
+
+                matrix<double,0,1> G(size,1);
+
+                const double C = 10000;
+                const double forget_factor = 0.8;
+                rls r(forget_factor, C);
+                for (long i = 0; i < Y.size(); ++i)
+                {
+                    r.train(trans(rowm(X,i)), Y(i));
+                    G(i) = std::pow(forget_factor, i/2.0);
+                }
+
+                G = flipud(G);
+                X = diagm(G)*X;
+                Y = diagm(G)*Y;
+
+                matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
+                rs5.add(length(r.get_w() - w));
+            }
+
             {
                 matrix<double> X = randm(size,num_vars,rnd);
                 matrix<double> Y = colm(X,0)*10;
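The new rs5 block above checks train() against a closed-form batch solution of the stated optimization problem. The key step is weighting row i by pow(forget_factor, i/2.0) and flipping so the newest row gets weight 1: the square roots multiply out inside trans(X)*X and trans(X)*Y, producing the pow(forget_factor, age) sample weights of the weighted normal equations. A small check of that identity, a sketch assuming dlib:

```cpp
#include <dlib/matrix.h>
#include <cmath>
#include <iostream>

int main()
{
    using namespace dlib;
    const double ff = 0.8;
    matrix<double> X = randm(5,3);

    matrix<double,0,1> g(5), g2(5);
    for (long i = 0; i < g.size(); ++i)
    {
        g(i)  = std::pow(ff, (g.size()-1-i)/2.0);  // sqrt weights; newest (last) row gets 1
        g2(i) = g(i)*g(i);                         // actual pow(ff, age) sample weights
    }

    matrix<double> lhs = trans(diagm(g)*X)*(diagm(g)*X);  // what the test effectively builds
    matrix<double> rhs = trans(X)*diagm(g2)*X;            // exponentially weighted correlation
    std::cout << max(abs(lhs - rhs)) << std::endl;        // ~1e-16
}
```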
@@ -124,20 +150,24 @@ namespace
         dlog << LINFO << "rs2.mean(): " << rs2.mean();
         dlog << LINFO << "rs3.mean(): " << rs3.mean();
         dlog << LINFO << "rs4.mean(): " << rs4.mean();
+        dlog << LINFO << "rs5.mean(): " << rs5.mean();
         dlog << LINFO << "rs1.max(): " << rs1.max();
         dlog << LINFO << "rs2.max(): " << rs2.max();
         dlog << LINFO << "rs3.max(): " << rs3.max();
         dlog << LINFO << "rs4.max(): " << rs4.max();
+        dlog << LINFO << "rs5.max(): " << rs5.max();
 
         DLIB_TEST_MSG(rs1.mean() < 1e-10, rs1.mean());
         DLIB_TEST_MSG(rs2.mean() < 1e-9, rs2.mean());
         DLIB_TEST_MSG(rs3.mean() < 1e-6, rs3.mean());
         DLIB_TEST_MSG(rs4.mean() < 1e-6, rs4.mean());
+        DLIB_TEST_MSG(rs5.mean() < 1e-3, rs5.mean());
         DLIB_TEST_MSG(rs1.max() < 1e-10, rs1.max());
         DLIB_TEST_MSG(rs2.max() < 1e-6, rs2.max());
         DLIB_TEST_MSG(rs3.max() < 0.001, rs3.max());
         DLIB_TEST_MSG(rs4.max() < 0.01, rs4.max());
+        DLIB_TEST_MSG(rs5.max() < 0.1, rs5.max());
     }
......