Commit 46985273 authored by Davis King's avatar Davis King

made tests more robust

parent 70360c76
...@@ -45,6 +45,17 @@ namespace ...@@ -45,6 +45,17 @@ namespace
return temp; return temp;
} }
// ----------------------------------------------------------------------------------------
matrix<double> rm_zeros (
const matrix<double>& m
)
{
// Do this to avoid trying to correlate super small numbers that are really just
// zero. Doing this avoids some potential false alarms in the unit tests below.
return round_zeros(m, max(abs(m))*1e-14);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void check_correlation ( void check_correlation (
...@@ -102,25 +113,26 @@ namespace ...@@ -102,25 +113,26 @@ namespace
DLIB_TEST(Ltrans.nc() == Rtrans.nc()); DLIB_TEST(Ltrans.nc() == Rtrans.nc());
dlog << LINFO << "correlations: "<< trans(correlations); dlog << LINFO << "correlations: "<< trans(correlations);
const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations)); const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
dlog << LINFO << "correlation error: "<< corr_error; dlog << LINFO << "correlation error: "<< corr_error;
DLIB_TEST(corr_error < 1e-13); DLIB_TEST_MSG(corr_error < 1e-13, Ltrans << "\n\n" << Rtrans);
const double trans_error = max(abs(L*Ltrans - R*Rtrans)); const double trans_error = max(abs(L*Ltrans - R*Rtrans));
dlog << LINFO << "trans_error: "<< trans_error; dlog << LINFO << "trans_error: "<< trans_error;
DLIB_TEST(trans_error < 1e-10); DLIB_TEST(trans_error < 1e-10);
} }
{ {
correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, min(m,n), max(n,n2)); correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, min(m,n), max(n,n2)+6, 4);
DLIB_TEST(Ltrans.nc() == Rtrans.nc()); DLIB_TEST(Ltrans.nc() == Rtrans.nc());
dlog << LINFO << "correlations: "<< trans(correlations); dlog << LINFO << "correlations: "<< trans(correlations);
dlog << LINFO << "computed cors: " << trans(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)));
const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations));
dlog << LINFO << "correlation error: "<< corr_error;
DLIB_TEST(corr_error < 1e-12);
const double trans_error = max(abs(L*Ltrans - R*Rtrans)); const double trans_error = max(abs(L*Ltrans - R*Rtrans));
dlog << LINFO << "trans_error: "<< trans_error; dlog << LINFO << "trans_error: "<< trans_error;
const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
dlog << LINFO << "correlation error: "<< corr_error;
DLIB_TEST_MSG(corr_error < 1e-13, Ltrans << "\n\n" << Rtrans);
DLIB_TEST(trans_error < 1e-10); DLIB_TEST(trans_error < 1e-10);
} }
...@@ -152,7 +164,7 @@ namespace ...@@ -152,7 +164,7 @@ namespace
DLIB_TEST(Ltrans.nc() == Rtrans.nc()); DLIB_TEST(Ltrans.nc() == Rtrans.nc());
dlog << LINFO << "correlations: "<< trans(correlations); dlog << LINFO << "correlations: "<< trans(correlations);
const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations)); const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
dlog << LINFO << "correlation error: "<< corr_error; dlog << LINFO << "correlation error: "<< corr_error;
DLIB_TEST(corr_error < 1e-13); DLIB_TEST(corr_error < 1e-13);
} }
...@@ -161,7 +173,7 @@ namespace ...@@ -161,7 +173,7 @@ namespace
DLIB_TEST(Ltrans.nc() == Rtrans.nc()); DLIB_TEST(Ltrans.nc() == Rtrans.nc());
dlog << LINFO << "correlations: "<< trans(correlations); dlog << LINFO << "correlations: "<< trans(correlations);
const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations)); const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
dlog << LINFO << "correlation error: "<< corr_error; dlog << LINFO << "correlation error: "<< corr_error;
DLIB_TEST(corr_error < 1e-13); DLIB_TEST(corr_error < 1e-13);
} }
...@@ -192,7 +204,7 @@ namespace ...@@ -192,7 +204,7 @@ namespace
{ {
correlations = cca(L, R, Ltrans, Rtrans, rank); correlations = cca(L, R, Ltrans, Rtrans, rank);
DLIB_TEST(Ltrans.nc() == Rtrans.nc()); DLIB_TEST(Ltrans.nc() == Rtrans.nc());
const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations)); const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
dlog << LINFO << "correlation error: "<< corr_error; dlog << LINFO << "correlation error: "<< corr_error;
DLIB_TEST(corr_error < 1e-13); DLIB_TEST(corr_error < 1e-13);
...@@ -205,7 +217,7 @@ namespace ...@@ -205,7 +217,7 @@ namespace
{ {
correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, rank); correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, rank);
DLIB_TEST(Ltrans.nc() == Rtrans.nc()); DLIB_TEST(Ltrans.nc() == Rtrans.nc());
const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations)); const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
dlog << LINFO << "correlation error: "<< corr_error; dlog << LINFO << "correlation error: "<< corr_error;
DLIB_TEST(corr_error < 1e-13); DLIB_TEST(corr_error < 1e-13);
...@@ -277,7 +289,7 @@ namespace ...@@ -277,7 +289,7 @@ namespace
DLIB_TEST(v.nc() == rank); DLIB_TEST(v.nc() == rank);
DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13); DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13); DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13);
DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-11); DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-10);
svd_fast(A, u, w, v, rank+5, 0); svd_fast(A, u, w, v, rank+5, 0);
DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13); DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment