Commit c63e4598 authored by Davis King's avatar Davis King

Added the ability to set a prior to the svm_c_linear_trainer.

parent 8c797ae9
......@@ -445,6 +445,8 @@ namespace dlib
)
{
learn_nonnegative_weights = value;
if (learns_nonnegative_weights)
prior.set_size(0);
}
bool forces_last_weight_to_1 (
......@@ -458,6 +460,33 @@ namespace dlib
)
{
last_weight_1 = should_last_weight_be_1;
if (last_weight_1)
prior.set_size(0);
}
void set_prior (
const trained_function_type& prior_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(prior_.basis_vectors.size() == 1 &&
prior_.alpha(0) == 1,
"\t void svm_c_linear_trainer::set_prior()"
<< "\n\t The supplied prior could not have been created by this object's train() method."
<< "\n\t prior_.basis_vectors.size(): " << prior_.basis_vectors.size()
<< "\n\t prior_.alpha(0): " << prior_.alpha(0)
<< "\n\t this: " << this
);
prior = join_cols(prior_.basis_vectors(0), mat((scalar_type)prior_.b));
learn_nonnegative_weights = false;
last_weight_1 = false;
}
bool has_prior (
) const
{
return prior.size() != 0;
}
void set_c (
......@@ -597,11 +626,21 @@ namespace dlib
}
svm_objective = solver(
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations),
w,
num_nonnegative,
force_weight_1_idx);
if (has_prior())
{
svm_objective = solver(
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations),
w,
prior);
}
else
{
svm_objective = solver(
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations),
w,
num_nonnegative,
force_weight_1_idx);
}
// put the solution into a decision function and then return it
decision_function<kernel_type> df;
......@@ -627,6 +666,7 @@ namespace dlib
unsigned long max_iterations;
bool learn_nonnegative_weights;
bool last_weight_1;
matrix<scalar_type,0,1> prior;
};
// ----------------------------------------------------------------------------------------
......
......@@ -54,6 +54,7 @@ namespace dlib
- #get_max_iterations() == 10000
- #learns_nonnegative_weights() == false
- #force_last_weight_to_1() == false
- #has_prior() == false
!*/
explicit svm_c_linear_trainer (
......@@ -73,6 +74,7 @@ namespace dlib
- #get_max_iterations() == 10000
- #learns_nonnegative_weights() == false
- #force_last_weight_to_1() == false
- #has_prior() == false
!*/
void set_epsilon (
......@@ -170,6 +172,39 @@ namespace dlib
/*!
ensures
- #learns_nonnegative_weights() == value
- if (value == true) then
- #has_prior() == false
!*/
void set_prior (
const trained_function_type& prior
);
/*!
requires
- prior == a function produced by a call to this classes train() function.
Therefore, it must be the case that:
- prior.basis_vectors.size() == 1
- prior.alpha(0) == 1
ensures
- Subsequent calls to train() will try to learn a function similar to the
given prior.
- #has_prior() == true
- #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false
!*/
bool has_prior (
) const
/*!
ensures
- returns true if a prior has been set and false otherwise. Having a prior
set means that you have called set_prior() and supplied a previously
trained function as a reference. In this case, any call to train() will
try to learn a function that matches the behavior of the prior as close
as possible but also fits the supplied training data. In more technical
detail, having a prior means we replace the ||w||^2 regularizer with one
of the form ||w-prior||^2 where w is the set of parameters for a learned
function.
!*/
bool forces_last_weight_to_1 (
......@@ -189,6 +224,8 @@ namespace dlib
/*!
ensures
- #forces_last_weight_to_1() == should_last_weight_be_1
- if (should_last_weight_be_1 == true) then
- #has_prior() == false
!*/
void set_c (
......@@ -263,6 +300,9 @@ namespace dlib
Also, x should contain sample_type objects.
- y == a matrix or something convertible to a matrix via mat().
Also, y should contain scalar_type objects.
- if (has_prior()) then
- The vectors in x must have the same dimensionality as the vectors
used to train the prior given to set_prior().
ensures
- trains a C support vector classifier given the training samples in x and
labels in y.
......@@ -294,6 +334,9 @@ namespace dlib
Also, x should contain sample_type objects.
- y == a matrix or something convertible to a matrix via mat().
Also, y should contain scalar_type objects.
- if (has_prior()) then
- The vectors in x must have the same dimensionality as the vectors
used to train the prior given to set_prior().
ensures
- trains a C support vector classifier given the training samples in x and
labels in y.
......
......@@ -32,6 +32,44 @@ namespace
// ----------------------------------------------------------------------------------------
void run_prior_test()
{
typedef matrix<double,3,1> sample_type;
typedef linear_kernel<sample_type> kernel_type;
svm_c_linear_trainer<kernel_type> trainer;
std::vector<sample_type> samples;
std::vector<double> labels;
sample_type samp;
samp = 0, 0, 1; samples.push_back(samp); labels.push_back(+1);
samp = 0, 1, 0; samples.push_back(samp); labels.push_back(-1);
trainer.set_c(10);
decision_function<kernel_type> df = trainer.train(samples, labels);
trainer.set_prior(df);
samples.clear();
labels.clear();
samp = 1, 0, 0; samples.push_back(samp); labels.push_back(+1);
samp = 0, 1, 0; samples.push_back(samp); labels.push_back(-1);
df = trainer.train(samples, labels);
samp = 0, 0, 1; samples.push_back(samp); labels.push_back(+1);
matrix<double,1,2> rs = test_binary_decision_function(df, samples, labels);
dlog << LINFO << rs;
DLIB_TEST(rs(0) == 1);
DLIB_TEST(rs(1) == 1);
dlog << LINFO << trans(df.basis_vectors(0));
DLIB_TEST(df.basis_vectors(0)(0) > 0);
DLIB_TEST(df.basis_vectors(0)(1) < 0);
DLIB_TEST(df.basis_vectors(0)(2) > 0);
}
void get_simple_points (
std::vector<sample_type>& samples,
std::vector<double>& labels
......@@ -216,6 +254,7 @@ namespace
{
test_dense();
test_sparse();
run_prior_test();
// test mixed sparse and dense dot products
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment