Commit 3b0f4ff1 authored by Davis King's avatar Davis King

Added more unit tests for the forces_last_weight_to_1 stuff.

parent 9ab59297
...@@ -223,7 +223,7 @@ namespace ...@@ -223,7 +223,7 @@ namespace
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename K> template <typename K, bool use_dcd_trainer>
class simple_rank_trainer class simple_rank_trainer
{ {
public: public:
...@@ -250,18 +250,35 @@ namespace ...@@ -250,18 +250,35 @@ namespace
} }
} }
svm_c_linear_dcd_trainer<K> trainer; if (use_dcd_trainer)
trainer.set_c(1.0/samples.size()); {
trainer.set_epsilon(1e-10); svm_c_linear_dcd_trainer<K> trainer;
trainer.force_last_weight_to_1(true); trainer.set_c(1.0/samples.size());
//trainer.be_verbose(); trainer.set_epsilon(1e-10);
return trainer.train(samples, labels); trainer.force_last_weight_to_1(true);
//trainer.be_verbose();
return trainer.train(samples, labels);
}
else
{
svm_c_linear_trainer<K> trainer;
trainer.set_c(1.0);
trainer.set_epsilon(1e-13);
trainer.force_last_weight_to_1(true);
//trainer.be_verbose();
decision_function<K> df = trainer.train(samples, labels);
DLIB_TEST_MSG(df.b == 0, df.b);
return df;
}
} }
}; };
template <bool use_dcd_trainer>
void test_svmrank_weight_force_dense() void test_svmrank_weight_force_dense()
{ {
print_spinner(); print_spinner();
dlog << LINFO << "use_dcd_trainer: "<< use_dcd_trainer;
typedef matrix<double,10,1> sample_type; typedef matrix<double,10,1> sample_type;
typedef linear_kernel<sample_type> kernel_type; typedef linear_kernel<sample_type> kernel_type;
...@@ -291,7 +308,7 @@ namespace ...@@ -291,7 +308,7 @@ namespace
dlog << LINFO << "ranking accuracy: " << acc1; dlog << LINFO << "ranking accuracy: " << acc1;
DLIB_TEST(std::abs(acc1 - 1) == 0); DLIB_TEST(std::abs(acc1 - 1) == 0);
simple_rank_trainer<kernel_type> strainer; simple_rank_trainer<kernel_type,use_dcd_trainer> strainer;
decision_function<kernel_type> df2; decision_function<kernel_type> df2;
df2 = strainer.train(pair); df2 = strainer.train(pair);
dlog << LINFO << "weights: "<< trans(df2.basis_vectors(0)); dlog << LINFO << "weights: "<< trans(df2.basis_vectors(0));
...@@ -325,7 +342,8 @@ namespace ...@@ -325,7 +342,8 @@ namespace
test_count_ranking_inversions(); test_count_ranking_inversions();
dotest1(); dotest1();
dotest_sparse_vectors(); dotest_sparse_vectors();
test_svmrank_weight_force_dense(); test_svmrank_weight_force_dense<true>();
test_svmrank_weight_force_dense<false>();
} }
} a; } a;
......
...@@ -250,6 +250,7 @@ namespace ...@@ -250,6 +250,7 @@ namespace
typedef linear_kernel<sample_type> kernel_type; typedef linear_kernel<sample_type> kernel_type;
svm_c_linear_trainer<kernel_type> linear_trainer_cpa;
svm_c_linear_dcd_trainer<kernel_type> linear_trainer; svm_c_linear_dcd_trainer<kernel_type> linear_trainer;
...@@ -257,7 +258,9 @@ namespace ...@@ -257,7 +258,9 @@ namespace
const double C = 1; const double C = 1;
linear_trainer.set_epsilon(1e-10); linear_trainer.set_epsilon(1e-10);
linear_trainer_cpa.set_epsilon(1e-11);
linear_trainer_cpa.force_last_weight_to_1(force_weight);
linear_trainer.force_last_weight_to_1(force_weight); linear_trainer.force_last_weight_to_1(force_weight);
linear_trainer.include_bias(have_bias); linear_trainer.include_bias(have_bias);
...@@ -268,7 +271,7 @@ namespace ...@@ -268,7 +271,7 @@ namespace
// make an instance of a sample vector so we can use it below // make an instance of a sample vector so we can use it below
sample_type sample; sample_type sample;
decision_function<kernel_type> df; decision_function<kernel_type> df, df2;
running_stats<double> rs; running_stats<double> rs;
...@@ -299,11 +302,22 @@ namespace ...@@ -299,11 +302,22 @@ namespace
labels.push_back(label); labels.push_back(label);
linear_trainer.set_c(C); linear_trainer.set_c(C);
linear_trainer_cpa.set_c(C*samples.size());
df = linear_trainer.train(samples, labels, state); df = linear_trainer.train(samples, labels, state);
if (force_weight) if (force_weight)
{
DLIB_TEST(std::abs(df.basis_vectors(0)(9) - 1) < 1e-8); DLIB_TEST(std::abs(df.basis_vectors(0)(9) - 1) < 1e-8);
DLIB_TEST(std::abs(df.b) < 1e-8);
if (samples.size() > 1)
{
df2 = linear_trainer_cpa.train(samples, labels);
DLIB_TEST_MSG( max(abs(df.basis_vectors(0) - df2.basis_vectors(0))) < 1e-7, max(abs(df.basis_vectors(0) - df2.basis_vectors(0))));
DLIB_TEST( std::abs(df.b - df2.b) < 1e-7);
}
}
if (!have_bias) if (!have_bias)
DLIB_TEST(std::abs(df.b) < 1e-8); DLIB_TEST(std::abs(df.b) < 1e-8);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment