From 46985273b10968e8117c88bf26d24d466c19c2ac Mon Sep 17 00:00:00 2001
From: Davis King <davis@dlib.net>
Date: Mon, 14 Jan 2013 23:46:54 -0500
Subject: [PATCH] made tests more robust

---
 dlib/test/cca.cpp | 36 ++++++++++++++++++++++++------------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/dlib/test/cca.cpp b/dlib/test/cca.cpp
index c05e2cd8..49b7a08d 100644
--- a/dlib/test/cca.cpp
+++ b/dlib/test/cca.cpp
@@ -45,6 +45,17 @@ namespace
         return temp;
     }
 
+// ----------------------------------------------------------------------------------------
+
+    matrix<double> rm_zeros (
+        const matrix<double>& m
+    )
+    {
+        // Do this to avoid trying to correlate super small numbers that are really just
+        // zero.  Doing this avoids some potential false alarms in the unit tests below.
+        return round_zeros(m, max(abs(m))*1e-14);
+    }
+
 // ----------------------------------------------------------------------------------------
 
     void check_correlation (
@@ -102,25 +113,26 @@ namespace
             DLIB_TEST(Ltrans.nc() == Rtrans.nc());
             dlog << LINFO << "correlations: "<< trans(correlations);
 
-            const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations));
+            const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
             dlog << LINFO << "correlation error: "<< corr_error;
-            DLIB_TEST(corr_error < 1e-13);
+            DLIB_TEST_MSG(corr_error < 1e-13, Ltrans << "\n\n" << Rtrans);
 
             const double trans_error = max(abs(L*Ltrans - R*Rtrans));
             dlog << LINFO << "trans_error: "<< trans_error;
             DLIB_TEST(trans_error < 1e-10);
         }
         {
-            correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, min(m,n), max(n,n2));
+            correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, min(m,n), max(n,n2)+6, 4);
             DLIB_TEST(Ltrans.nc() == Rtrans.nc());
             dlog << LINFO << "correlations: "<< trans(correlations);
-
-            const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations));
-            dlog << LINFO << "correlation error: "<< corr_error;
-            DLIB_TEST(corr_error < 1e-12);
+            dlog << LINFO << "computed cors: " << trans(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)));
 
             const double trans_error = max(abs(L*Ltrans - R*Rtrans));
             dlog << LINFO << "trans_error: "<< trans_error;
+            const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
+            dlog << LINFO << "correlation error: "<< corr_error;
+            DLIB_TEST_MSG(corr_error < 1e-13, Ltrans << "\n\n" << Rtrans);
+
             DLIB_TEST(trans_error < 1e-10);
         }
 
@@ -152,7 +164,7 @@ namespace
             DLIB_TEST(Ltrans.nc() == Rtrans.nc());
             dlog << LINFO << "correlations: "<< trans(correlations);
 
-            const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations));
+            const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
             dlog << LINFO << "correlation error: "<< corr_error;
             DLIB_TEST(corr_error < 1e-13);
         }
@@ -161,7 +173,7 @@ namespace
             DLIB_TEST(Ltrans.nc() == Rtrans.nc());
             dlog << LINFO << "correlations: "<< trans(correlations);
 
-            const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations));
+            const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
             dlog << LINFO << "correlation error: "<< corr_error;
             DLIB_TEST(corr_error < 1e-13);
         }
@@ -192,7 +204,7 @@ namespace
         {
             correlations = cca(L, R, Ltrans, Rtrans, rank);
             DLIB_TEST(Ltrans.nc() == Rtrans.nc());
-            const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations));
+            const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
             dlog << LINFO << "correlation error: "<< corr_error;
             DLIB_TEST(corr_error < 1e-13);
 
@@ -205,7 +217,7 @@ namespace
         {
             correlations = cca(mat_to_sparse(L), mat_to_sparse(R), Ltrans, Rtrans, rank);
             DLIB_TEST(Ltrans.nc() == Rtrans.nc());
-            const double corr_error = max(abs(compute_correlations(tmp(L*Ltrans), tmp(R*Rtrans)) - correlations));
+            const double corr_error = max(abs(compute_correlations(rm_zeros(L*Ltrans), rm_zeros(R*Rtrans)) - correlations));
             dlog << LINFO << "correlation error: "<< corr_error;
             DLIB_TEST(corr_error < 1e-13);
 
@@ -277,7 +289,7 @@ namespace
         DLIB_TEST(v.nc() == rank);
         DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
         DLIB_TEST(max(abs(trans(v)*v - identity_matrix<double>(u.nc()))) < 1e-13);
-        DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-11);
+        DLIB_TEST(max(abs(tmp(A - u*diagm(w)*trans(v)))) < 1e-10);
 
         svd_fast(A, u, w, v, rank+5, 0);
         DLIB_TEST(max(abs(trans(u)*u - identity_matrix<double>(u.nc()))) < 1e-13);
-- 
2.18.0