Commit c63e4598 authored by Davis King's avatar Davis King

Added the ability to set a prior to the svm_c_linear_trainer.

parent 8c797ae9
...@@ -445,6 +445,8 @@ namespace dlib ...@@ -445,6 +445,8 @@ namespace dlib
) )
{ {
learn_nonnegative_weights = value; learn_nonnegative_weights = value;
if (learns_nonnegative_weights)
prior.set_size(0);
} }
bool forces_last_weight_to_1 ( bool forces_last_weight_to_1 (
...@@ -458,6 +460,33 @@ namespace dlib ...@@ -458,6 +460,33 @@ namespace dlib
) )
{ {
last_weight_1 = should_last_weight_be_1; last_weight_1 = should_last_weight_be_1;
if (last_weight_1)
prior.set_size(0);
}
void set_prior (
const trained_function_type& prior_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(prior_.basis_vectors.size() == 1 &&
prior_.alpha(0) == 1,
"\t void svm_c_linear_trainer::set_prior()"
<< "\n\t The supplied prior could not have been created by this object's train() method."
<< "\n\t prior_.basis_vectors.size(): " << prior_.basis_vectors.size()
<< "\n\t prior_.alpha(0): " << prior_.alpha(0)
<< "\n\t this: " << this
);
prior = join_cols(prior_.basis_vectors(0), mat((scalar_type)prior_.b));
learn_nonnegative_weights = false;
last_weight_1 = false;
}
// Reports whether a prior is currently stored in this trainer.
bool has_prior (
) const
{
// The prior vector is emptied (set_size(0)) whenever no prior is in effect, so
// a non-zero size means a prior has been installed.
return prior.size() != 0;
} }
void set_c ( void set_c (
...@@ -597,11 +626,21 @@ namespace dlib ...@@ -597,11 +626,21 @@ namespace dlib
} }
svm_objective = solver( if (has_prior())
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations), {
w, svm_objective = solver(
num_nonnegative, make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations),
force_weight_1_idx); w,
prior);
}
else
{
svm_objective = solver(
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations),
w,
num_nonnegative,
force_weight_1_idx);
}
// put the solution into a decision function and then return it // put the solution into a decision function and then return it
decision_function<kernel_type> df; decision_function<kernel_type> df;
...@@ -627,6 +666,7 @@ namespace dlib ...@@ -627,6 +666,7 @@ namespace dlib
unsigned long max_iterations; unsigned long max_iterations;
bool learn_nonnegative_weights; bool learn_nonnegative_weights;
bool last_weight_1; bool last_weight_1;
matrix<scalar_type,0,1> prior;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -54,6 +54,7 @@ namespace dlib ...@@ -54,6 +54,7 @@ namespace dlib
- #get_max_iterations() == 10000 - #get_max_iterations() == 10000
- #learns_nonnegative_weights() == false - #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false - #forces_last_weight_to_1() == false
- #has_prior() == false
!*/ !*/
explicit svm_c_linear_trainer ( explicit svm_c_linear_trainer (
...@@ -73,6 +74,7 @@ namespace dlib ...@@ -73,6 +74,7 @@ namespace dlib
- #get_max_iterations() == 10000 - #get_max_iterations() == 10000
- #learns_nonnegative_weights() == false - #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false - #forces_last_weight_to_1() == false
- #has_prior() == false
!*/ !*/
void set_epsilon ( void set_epsilon (
...@@ -170,6 +172,39 @@ namespace dlib ...@@ -170,6 +172,39 @@ namespace dlib
/*! /*!
ensures ensures
- #learns_nonnegative_weights() == value - #learns_nonnegative_weights() == value
- if (value == true) then
- #has_prior() == false
!*/
void set_prior (
const trained_function_type& prior
);
/*!
requires
- prior == a function produced by a call to this class's train() function.
Therefore, it must be the case that:
- prior.basis_vectors.size() == 1
- prior.alpha(0) == 1
ensures
- Subsequent calls to train() will try to learn a function similar to the
given prior.
- #has_prior() == true
- #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false
!*/
bool has_prior (
) const
/*!
ensures
- returns true if a prior has been set and false otherwise. Having a prior
set means that you have called set_prior() and supplied a previously
trained function as a reference. In this case, any call to train() will
try to learn a function that matches the behavior of the prior as closely
as possible but also fits the supplied training data. In more technical
detail, having a prior means we replace the ||w||^2 regularizer with one
of the form ||w-prior||^2 where w is the set of parameters for a learned
function.
!*/ !*/
bool forces_last_weight_to_1 ( bool forces_last_weight_to_1 (
...@@ -189,6 +224,8 @@ namespace dlib ...@@ -189,6 +224,8 @@ namespace dlib
/*! /*!
ensures ensures
- #forces_last_weight_to_1() == should_last_weight_be_1 - #forces_last_weight_to_1() == should_last_weight_be_1
- if (should_last_weight_be_1 == true) then
- #has_prior() == false
!*/ !*/
void set_c ( void set_c (
...@@ -263,6 +300,9 @@ namespace dlib ...@@ -263,6 +300,9 @@ namespace dlib
Also, x should contain sample_type objects. Also, x should contain sample_type objects.
- y == a matrix or something convertible to a matrix via mat(). - y == a matrix or something convertible to a matrix via mat().
Also, y should contain scalar_type objects. Also, y should contain scalar_type objects.
- if (has_prior()) then
- The vectors in x must have the same dimensionality as the vectors
used to train the prior given to set_prior().
ensures ensures
- trains a C support vector classifier given the training samples in x and - trains a C support vector classifier given the training samples in x and
labels in y. labels in y.
...@@ -294,6 +334,9 @@ namespace dlib ...@@ -294,6 +334,9 @@ namespace dlib
Also, x should contain sample_type objects. Also, x should contain sample_type objects.
- y == a matrix or something convertible to a matrix via mat(). - y == a matrix or something convertible to a matrix via mat().
Also, y should contain scalar_type objects. Also, y should contain scalar_type objects.
- if (has_prior()) then
- The vectors in x must have the same dimensionality as the vectors
used to train the prior given to set_prior().
ensures ensures
- trains a C support vector classifier given the training samples in x and - trains a C support vector classifier given the training samples in x and
labels in y. labels in y.
......
...@@ -32,6 +32,44 @@ namespace ...@@ -32,6 +32,44 @@ namespace
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void run_prior_test()
{
typedef matrix<double,3,1> sample_type;
typedef linear_kernel<sample_type> kernel_type;
svm_c_linear_trainer<kernel_type> trainer;
std::vector<sample_type> samples;
std::vector<double> labels;
sample_type samp;
samp = 0, 0, 1; samples.push_back(samp); labels.push_back(+1);
samp = 0, 1, 0; samples.push_back(samp); labels.push_back(-1);
trainer.set_c(10);
decision_function<kernel_type> df = trainer.train(samples, labels);
trainer.set_prior(df);
samples.clear();
labels.clear();
samp = 1, 0, 0; samples.push_back(samp); labels.push_back(+1);
samp = 0, 1, 0; samples.push_back(samp); labels.push_back(-1);
df = trainer.train(samples, labels);
samp = 0, 0, 1; samples.push_back(samp); labels.push_back(+1);
matrix<double,1,2> rs = test_binary_decision_function(df, samples, labels);
dlog << LINFO << rs;
DLIB_TEST(rs(0) == 1);
DLIB_TEST(rs(1) == 1);
dlog << LINFO << trans(df.basis_vectors(0));
DLIB_TEST(df.basis_vectors(0)(0) > 0);
DLIB_TEST(df.basis_vectors(0)(1) < 0);
DLIB_TEST(df.basis_vectors(0)(2) > 0);
}
void get_simple_points ( void get_simple_points (
std::vector<sample_type>& samples, std::vector<sample_type>& samples,
std::vector<double>& labels std::vector<double>& labels
...@@ -216,6 +254,7 @@ namespace ...@@ -216,6 +254,7 @@ namespace
{ {
test_dense(); test_dense();
test_sparse(); test_sparse();
run_prior_test();
// test mixed sparse and dense dot products // test mixed sparse and dense dot products
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment