Commit af1af6d2 authored by Davis King

Fixed the decayed running stats objects so they use unbiased estimators.

parent 18489b11
......@@ -499,7 +499,7 @@ namespace dlib
sum_x = sum_x*forget + x;
sum_y = sum_y*forget + y;
n = n*forget + forget;
n = n*forget + 1;
}
T current_n (
......@@ -530,20 +530,20 @@ namespace dlib
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
DLIB_ASSERT(current_n() > 1,
"\tT running_scalar_covariance_decayed::covariance()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
return 1/n * (sum_xy - sum_y*sum_x/n);
return 1/(n-1) * (sum_xy - sum_y*sum_x/n);
}
T correlation (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
DLIB_ASSERT(current_n() > 1,
"\tT running_scalar_covariance_decayed::correlation()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
......@@ -560,13 +560,13 @@ namespace dlib
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
DLIB_ASSERT(current_n() > 1,
"\tT running_scalar_covariance_decayed::variance_x()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
T temp = 1/n * (sum_xx - sum_x*sum_x/n);
T temp = 1/(n-1) * (sum_xx - sum_x*sum_x/n);
// make sure the variance is never negative. This might
// happen due to numerical errors.
if (temp >= 0)
......@@ -579,13 +579,13 @@ namespace dlib
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
DLIB_ASSERT(current_n() > 1,
"\tT running_scalar_covariance_decayed::variance_y()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
T temp = 1/n * (sum_yy - sum_y*sum_y/n);
T temp = 1/(n-1) * (sum_yy - sum_y*sum_y/n);
// make sure the variance is never negative. This might
// happen due to numerical errors.
if (temp >= 0)
......@@ -598,7 +598,7 @@ namespace dlib
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
DLIB_ASSERT(current_n() > 1,
"\tT running_scalar_covariance_decayed::stddev_x()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
......@@ -611,7 +611,7 @@ namespace dlib
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
DLIB_ASSERT(current_n() > 1,
"\tT running_scalar_covariance_decayed::stddev_y()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
......@@ -673,7 +673,7 @@ namespace dlib
sum_x = sum_x*forget + x;
n = n*forget + forget;
n = n*forget + 1;
}
T current_n (
......@@ -695,13 +695,13 @@ namespace dlib
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
DLIB_ASSERT(current_n() > 1,
"\tT running_stats_decayed::variance()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
);
T temp = 1/n * (sum_xx - sum_x*sum_x/n);
T temp = 1/(n-1) * (sum_xx - sum_x*sum_x/n);
// make sure the variance is never negative. This might
// happen due to numerical errors.
if (temp >= 0)
......@@ -714,7 +714,7 @@ namespace dlib
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(current_n() > 0,
DLIB_ASSERT(current_n() > 1,
"\tT running_stats_decayed::stddev()"
<< "\n\tyou have to add some numbers to this object first"
<< "\n\tthis: " << this
......
......@@ -570,7 +570,7 @@ namespace dlib
) const;
/*!
requires
- current_n() > 0
- current_n() > 1
ensures
- returns the covariance between all the x and y samples presented
to this object via add()
......@@ -580,7 +580,7 @@ namespace dlib
) const;
/*!
requires
- current_n() > 0
- current_n() > 1
ensures
- returns the correlation coefficient between all the x and y samples
presented to this object via add()
......@@ -590,7 +590,7 @@ namespace dlib
) const;
/*!
requires
- current_n() > 0
- current_n() > 1
ensures
- returns the sample variance value of all x samples presented
to this object via add().
......@@ -600,7 +600,7 @@ namespace dlib
) const;
/*!
requires
- current_n() > 0
- current_n() > 1
ensures
- returns the sample variance value of all y samples presented
to this object via add().
......@@ -610,7 +610,7 @@ namespace dlib
) const;
/*!
requires
- current_n() > 0
- current_n() > 1
ensures
- returns the sample standard deviation of all x samples
presented to this object via add().
......@@ -620,7 +620,7 @@ namespace dlib
) const;
/*!
requires
- current_n() > 0
- current_n() > 1
ensures
- returns the sample standard deviation of all y samples
presented to this object via add().
......@@ -703,7 +703,7 @@ namespace dlib
) const;
/*!
requires
- current_n() > 0
- current_n() > 1
ensures
- returns the sample variance value of all x samples presented to this
object via add().
......@@ -713,7 +713,7 @@ namespace dlib
) const;
/*!
requires
- current_n() > 0
- current_n() > 1
ensures
- returns the sample standard deviation of all x samples presented to this
object via add().
......
......@@ -340,14 +340,12 @@ namespace
DLIB_TEST(std::abs(rsc2.correlation() - 0) < 1e-10);
const double s = 99/100.0;
const double ss = std::sqrt(s);;
DLIB_TEST_MSG(std::abs(rs.mean() - rscd1.mean_x()) < 1e-2, std::abs(rs.mean() - rscd1.mean_x()) << " " << rscd1.mean_x());
DLIB_TEST(std::abs(rs.mean() - rscd1.mean_y()) < 1e-2);
DLIB_TEST_MSG(std::abs(ss*rs.stddev() - rscd1.stddev_x()) < 1e-2, std::abs(ss*rs.stddev() - rscd1.stddev_x()));
DLIB_TEST(std::abs(ss*rs.stddev() - rscd1.stddev_y()) < 1e-2);
DLIB_TEST_MSG(std::abs(s*rs.variance() - rscd1.variance_x()) < 1e-2, std::abs(s*rs.variance() - rscd1.variance_x()));
DLIB_TEST(std::abs(s*rs.variance() - rscd1.variance_y()) < 1e-2);
DLIB_TEST_MSG(std::abs(rs.stddev() - rscd1.stddev_x()) < 1e-2, std::abs(rs.stddev() - rscd1.stddev_x()));
DLIB_TEST(std::abs(rs.stddev() - rscd1.stddev_y()) < 1e-2);
DLIB_TEST_MSG(std::abs(rs.variance() - rscd1.variance_x()) < 1e-2, std::abs(rs.variance() - rscd1.variance_x()));
DLIB_TEST(std::abs(rs.variance() - rscd1.variance_y()) < 1e-2);
DLIB_TEST(std::abs(rscd1.correlation() - 1) < 1e-2);
DLIB_TEST(std::abs(rscd2.correlation() - 0) < 1e-2);
......@@ -803,6 +801,68 @@ namespace
DLIB_TEST(equal_error_rate(vals2, vals1).first == 1);
}
// Monte Carlo check of running_stats_decayed: over many independent rounds,
// the mean() and variance() reported after each successive add() are summed
// per sample index and the per-index averages are compared against 1.
void test_running_stats_decayed()
{
    print_spinner();

    // tmp[i] / tmp_var[i] accumulate, across all rounds, the mean / variance
    // reported right after the (i+1)-th sample of a round has been added.
    std::vector<double> tmp(300);
    std::vector<double> tmp_var(tmp.size());
    dlib::rand rnd;
    const int num_rounds = 100000;
    for (int rounds = 0; rounds < num_rounds; ++rounds)
    {
        // Fresh object every round so each round is an independent trial.
        running_stats_decayed<double> rs(100);
        for (size_t i = 0; i < tmp.size(); ++i)
        {
            rs.add(rnd.get_random_gaussian() + 1);
            tmp[i] += rs.mean();

            // variance() requires current_n() > 1.  Calling it before that
            // holds (presumably right after the first add()) would trip its
            // DLIB_ASSERT in assert-enabled builds and divide by n-1 == 0
            // otherwise, so those indices are simply not accumulated.
            if (rs.current_n() > 1)
                tmp_var[i] += rs.variance();
        }
    }

    // The averaged mean and variance are both expected to be close to 1 at
    // every index.  The variance is only checked from the third sample on.
    for (size_t i = 0; i < tmp.size(); ++i)
    {
        DLIB_TEST(std::abs(1-tmp[i]/num_rounds) < 0.001);
        if (i > 1)
        {
            DLIB_TEST(std::abs(1-tmp_var[i]/num_rounds) < 0.01);
        }
    }
}
// Monte Carlo check of running_scalar_covariance_decayed: over many
// independent rounds, the means, variances and covariance reported after each
// successive add() are summed per sample index and the per-index averages are
// compared against their expected values.
void test_running_scalar_covariance_decayed()
{
    print_spinner();

    // Per-index accumulators over all rounds.  The x and y statistics are
    // averaged together since both coordinates are generated the same way.
    std::vector<double> tmp(300);
    std::vector<double> tmp_var(tmp.size());
    std::vector<double> tmp_covar(tmp.size());
    dlib::rand rnd;
    const int num_rounds = 500000;
    for (int rounds = 0; rounds < num_rounds; ++rounds)
    {
        // Fresh object every round so each round is an independent trial.
        running_scalar_covariance_decayed<double> rs(100);
        for (size_t i = 0; i < tmp.size(); ++i)
        {
            rs.add(rnd.get_random_gaussian() + 1, rnd.get_random_gaussian() + 1);
            tmp[i] += (rs.mean_y()+rs.mean_x())/2;

            // variance_x(), variance_y() and covariance() all require
            // current_n() > 1.  Calling them before that holds (presumably
            // right after the first add()) would trip their DLIB_ASSERTs in
            // assert-enabled builds and divide by n-1 == 0 otherwise, so
            // those indices are simply not accumulated.
            if (rs.current_n() > 1)
            {
                tmp_var[i] += (rs.variance_y()+rs.variance_x())/2;
                tmp_covar[i] += rs.covariance();
            }
        }
    }

    // The averaged mean and variance are expected to be close to 1 and the
    // averaged covariance close to 0 (x and y come from separate draws).
    // Second-order statistics are only checked from the third sample on.
    for (size_t i = 0; i < tmp.size(); ++i)
    {
        DLIB_TEST(std::abs(1-tmp[i]/num_rounds) < 0.001);
        if (i > 1)
        {
            DLIB_TEST(std::abs(1-tmp_var[i]/num_rounds) < 0.01);
            DLIB_TEST(std::abs(tmp_covar[i]/num_rounds) < 0.001);
        }
    }
}
void test_event_corr()
{
print_spinner();
......@@ -841,6 +901,8 @@ namespace
test_average_precision();
test_lda();
test_event_corr();
test_running_stats_decayed();
test_running_scalar_covariance_decayed();
}
} a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment