Commit 7c631857 authored by Davis King's avatar Davis King

Added running_scalar_covariance_decayd

parent 23785d53
...@@ -450,6 +450,183 @@ namespace dlib ...@@ -450,6 +450,183 @@ namespace dlib
T n; T n;
}; };
// ----------------------------------------------------------------------------------------
template <
typename T
>
class running_scalar_covariance_decayed
{
public:
explicit running_scalar_covariance_decayed(
T decay_halflife = 1000
)
{
DLIB_ASSERT(decay_halflife > 0);
sum_xy = 0;
sum_x = 0;
sum_y = 0;
sum_xx = 0;
sum_yy = 0;
forget = std::pow(0.5, 1/decay_halflife);
n = 0;
COMPILE_TIME_ASSERT ((
is_same_type<float,T>::value ||
is_same_type<double,T>::value ||
is_same_type<long double,T>::value
));
}
T forget_factor (
) const
{
return forget;
}
void add (
const T& x,
const T& y
)
{
sum_xy = sum_xy*forget + x*y;
sum_xx = sum_xx*forget + x*x;
sum_yy = sum_yy*forget + y*y;
sum_x = sum_x*forget + x;
sum_y = sum_y*forget + y;
n = n*forget + forget;
}
T current_n (
) const
{
return n;
}
T mean_x (
) const
{
if (n != 0)
return sum_x/n;
else
return 0;
}
T mean_y (
) const
{
if (n != 0)
return sum_y/n;
else
return 0;
}
T covariance (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
"\tT running_scalar_covariance_decayed::covariance()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
return 1/n * (sum_xy - sum_y*sum_x/n);
}
T correlation (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
"\tT running_scalar_covariance_decayed::correlation()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
return covariance() / std::sqrt(variance_x()*variance_y());
}
T variance_x (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
"\tT running_scalar_covariance_decayed::variance_x()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
T temp = 1/n * (sum_xx - sum_x*sum_x/n);
// make sure the variance is never negative. This might
// happen due to numerical errors.
if (temp >= 0)
return temp;
else
return 0;
}
T variance_y (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
"\tT running_scalar_covariance_decayed::variance_y()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
T temp = 1/n * (sum_yy - sum_y*sum_y/n);
// make sure the variance is never negative. This might
// happen due to numerical errors.
if (temp >= 0)
return temp;
else
return 0;
}
T stddev_x (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
"\tT running_scalar_covariance_decayed::stddev_x()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
return std::sqrt(variance_x());
}
T stddev_y (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
"\tT running_scalar_covariance_decayed::stddev_y()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
return std::sqrt(variance_y());
}
private:
T sum_xy;
T sum_x;
T sum_y;
T sum_xx;
T sum_yy;
T n;
T forget;
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
......
...@@ -414,6 +414,150 @@ namespace dlib ...@@ -414,6 +414,150 @@ namespace dlib
!*/ !*/
}; };
// ----------------------------------------------------------------------------------------
template <
typename T
>
class running_scalar_covariance_decayed
{
/*!
REQUIREMENTS ON T
- T must be a float, double, or long double type
INITIAL VALUE
- mean_x() == 0
- mean_y() == 0
- current_n() == 0
WHAT THIS OBJECT REPRESENTS
This object represents something that can compute the running covariance of
a stream of real number pairs. It is essentially the same as
running_scalar_covariance except that it forgets about data it has seen
after a certain period of time. It does this by exponentially decaying old
statistic.
!*/
public:
running_scalar_covariance_decayed(
T decay_halflife = 1000
);
/*!
requires
- decay_halflife > 0
ensures
- #forget_factor() == std::pow(0.5, 1/decay_halflife);
(i.e. after decay_halflife calls to add() the data given to the first add
will be down weighted by 0.5 in the statistics stored in this object).
!*/
T forget_factor (
) const;
/*!
ensures
- returns the exponential forget factor used to forget old statistics when
add() is called.
!*/
void add (
const T& x,
const T& y
);
/*!
ensures
- updates the statistics stored in this object so that
the new pair (x,y) is factored into them.
- #current_n() == current_n()*forget_factor() + forget_factor()
- Down weights old statistics by a factor of forget_factor().
!*/
T current_n (
) const;
/*!
ensures
- returns the effective number of points given to this object. As add()
is called this value will converge to a constant, the value of which is
based on the decay_halflife supplied to the constructor.
!*/
T mean_x (
) const;
/*!
ensures
- returns the mean value of all x samples presented to this object
via add().
!*/
T mean_y (
) const;
/*!
ensures
- returns the mean value of all y samples presented to this object
via add().
!*/
T covariance (
) const;
/*!
requires
- current_n() > 0
ensures
- returns the covariance between all the x and y samples presented
to this object via add()
!*/
T correlation (
) const;
/*!
requires
- current_n() > 0
ensures
- returns the correlation coefficient between all the x and y samples
presented to this object via add()
!*/
T variance_x (
) const;
/*!
requires
- current_n() > 0
ensures
- returns the unbiased sample variance value of all x samples presented
to this object via add().
!*/
T variance_y (
) const;
/*!
requires
- current_n() > 0
ensures
- returns the unbiased sample variance value of all y samples presented
to this object via add().
!*/
T stddev_x (
) const;
/*!
requires
- current_n() > 0
ensures
- returns the unbiased sample standard deviation of all x samples
presented to this object via add().
!*/
T stddev_y (
) const;
/*!
requires
- current_n() > 0
ensures
- returns the unbiased sample standard deviation of all y samples
presented to this object via add().
!*/
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
......
...@@ -312,6 +312,7 @@ namespace ...@@ -312,6 +312,7 @@ namespace
running_stats<double> rs, rs2; running_stats<double> rs, rs2;
running_scalar_covariance<double> rsc1, rsc2; running_scalar_covariance<double> rsc1, rsc2;
running_scalar_covariance_decayed<double> rscd1(1000000), rscd2(1000000);
for (double i = 0; i < 100; ++i) for (double i = 0; i < 100; ++i)
{ {
...@@ -320,6 +321,10 @@ namespace ...@@ -320,6 +321,10 @@ namespace
rsc1.add(i,i); rsc1.add(i,i);
rsc2.add(i,i); rsc2.add(i,i);
rsc2.add(i,-i); rsc2.add(i,-i);
rscd1.add(i,i);
rscd2.add(i,i);
rscd2.add(i,-i);
} }
// make sure the running_stats and running_scalar_covariance agree // make sure the running_stats and running_scalar_covariance agree
...@@ -335,6 +340,18 @@ namespace ...@@ -335,6 +340,18 @@ namespace
DLIB_TEST(std::abs(rsc2.correlation() - 0) < 1e-10); DLIB_TEST(std::abs(rsc2.correlation() - 0) < 1e-10);
const double s = 99/100.0;
const double ss = std::sqrt(s);;
DLIB_TEST_MSG(std::abs(rs.mean() - rscd1.mean_x()) < 1e-2, std::abs(rs.mean() - rscd1.mean_x()) << " " << rscd1.mean_x());
DLIB_TEST(std::abs(rs.mean() - rscd1.mean_y()) < 1e-2);
DLIB_TEST_MSG(std::abs(ss*rs.stddev() - rscd1.stddev_x()) < 1e-2, std::abs(ss*rs.stddev() - rscd1.stddev_x()));
DLIB_TEST(std::abs(ss*rs.stddev() - rscd1.stddev_y()) < 1e-2);
DLIB_TEST_MSG(std::abs(s*rs.variance() - rscd1.variance_x()) < 1e-2, std::abs(s*rs.variance() - rscd1.variance_x()));
DLIB_TEST(std::abs(s*rs.variance() - rscd1.variance_y()) < 1e-2);
DLIB_TEST(std::abs(rscd1.correlation() - 1) < 1e-2);
DLIB_TEST(std::abs(rscd2.correlation() - 0) < 1e-2);
// test serialization of running_stats // test serialization of running_stats
ostringstream sout; ostringstream sout;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment