Commit 754610f2 authored by Davis King's avatar Davis King

Added the option to set a prior to svm_rank_trainer.

parent 461abe65
...@@ -297,6 +297,8 @@ namespace dlib ...@@ -297,6 +297,8 @@ namespace dlib
) )
{ {
last_weight_1 = should_last_weight_be_1; last_weight_1 = should_last_weight_be_1;
if (last_weight_1)
prior.set_size(0);
} }
void set_oca ( void set_oca (
...@@ -326,6 +328,33 @@ namespace dlib ...@@ -326,6 +328,33 @@ namespace dlib
) )
{ {
learn_nonnegative_weights = value; learn_nonnegative_weights = value;
if (learn_nonnegative_weights)
prior.set_size(0);
}
// Stores a previously trained decision function as the prior used by
// subsequent calls to train().
//
// prior_ is expected to look like something this trainer's train() produced:
// exactly one basis vector with alpha(0) equal to 1.  DLIB_ASSERT verifies
// this.
//
// Setting a prior also switches off the "nonnegative weights" and "last
// weight forced to 1" options; their setters discard the prior when enabled,
// so the two features are never active together with a prior.
void set_prior (
    const trained_function_type& prior_
)
{
    // make sure requires clause is not broken
    DLIB_ASSERT(prior_.basis_vectors.size() == 1 &&
                prior_.alpha(0) == 1,
        "\t void svm_rank_trainer::set_prior()"
        << "\n\t The supplied prior could not have been created by this object's train() method."
        << "\n\t prior_.basis_vectors.size(): " << prior_.basis_vectors.size()
        << "\n\t prior_.alpha(0): " << prior_.alpha(0)
        << "\n\t this: " << this
        );

    // Only the single basis vector (the learned weight vector) is kept.
    prior = prior_.basis_vectors(0);

    // Disable the options that cannot be combined with a prior.
    last_weight_1 = false;
    learn_nonnegative_weights = false;
}
// Reports whether a prior is currently stored.  An empty prior vector
// (size 0) means no prior is set.
bool has_prior (
) const
{
    const bool prior_is_empty = (prior.size() == 0);
    return !prior_is_empty;
}
void set_c ( void set_c (
...@@ -379,10 +408,30 @@ namespace dlib ...@@ -379,10 +408,30 @@ namespace dlib
force_weight_1_idx = num_dims-1; force_weight_1_idx = num_dims-1;
} }
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations), if (has_prior())
{
if (is_matrix<sample_type>::value)
{
// make sure requires clause is not broken
DLIB_CASSERT(num_dims == (unsigned long)prior.size(),
"\t decision_function svm_rank_trainer::train(samples)"
<< "\n\t The dimension of the training vectors must match the dimension of\n"
<< "\n\t those used to create the prior."
<< "\n\t num_dims: " << num_dims
<< "\n\t prior.size(): " << prior.size()
);
}
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations),
w,
prior);
}
else
{
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations),
w, w,
num_nonnegative, num_nonnegative,
force_weight_1_idx); force_weight_1_idx);
}
// put the solution into a decision function and then return it // put the solution into a decision function and then return it
...@@ -415,6 +464,7 @@ namespace dlib ...@@ -415,6 +464,7 @@ namespace dlib
unsigned long max_iterations; unsigned long max_iterations;
bool learn_nonnegative_weights; bool learn_nonnegative_weights;
bool last_weight_1; bool last_weight_1;
matrix<scalar_type,0,1> prior;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -58,6 +58,7 @@ namespace dlib ...@@ -58,6 +58,7 @@ namespace dlib
- #get_max_iterations() == 10000 - #get_max_iterations() == 10000
- #learns_nonnegative_weights() == false - #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false - #forces_last_weight_to_1() == false
- #has_prior() == false
!*/ !*/
explicit svm_rank_trainer ( explicit svm_rank_trainer (
...@@ -76,6 +77,7 @@ namespace dlib ...@@ -76,6 +77,7 @@ namespace dlib
- #get_max_iterations() == 10000 - #get_max_iterations() == 10000
- #learns_nonnegative_weights() == false - #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false - #forces_last_weight_to_1() == false
- #has_prior() == false
!*/ !*/
void set_epsilon ( void set_epsilon (
...@@ -146,6 +148,8 @@ namespace dlib ...@@ -146,6 +148,8 @@ namespace dlib
/*! /*!
ensures ensures
- #forces_last_weight_to_1() == should_last_weight_be_1 - #forces_last_weight_to_1() == should_last_weight_be_1
- if (should_last_weight_be_1 == true) then
- #has_prior() == false
!*/ !*/
void set_oca ( void set_oca (
...@@ -190,6 +194,39 @@ namespace dlib ...@@ -190,6 +194,39 @@ namespace dlib
/*! /*!
ensures ensures
- #learns_nonnegative_weights() == value - #learns_nonnegative_weights() == value
- if (value == true) then
- #has_prior() == false
!*/
void set_prior (
const trained_function_type& prior
);
/*!
requires
- prior == a function produced by a call to this class's train() function.
Therefore, it must be the case that:
- prior.basis_vectors.size() == 1
- prior.alpha(0) == 1
ensures
- Subsequent calls to train() will try to learn a function similar to the
given prior.
- #has_prior() == true
- #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false
!*/
bool has_prior (
) const
/*!
ensures
- returns true if a prior has been set and false otherwise. Having a prior
set means that you have called set_prior() and supplied a previously
trained function as a reference. In this case, any call to train() will
try to learn a function that matches the behavior of the prior as closely
as possible but also fits the supplied training data. In more technical
detail, having a prior means we replace the ||w||^2 regularizer with one
of the form ||w-prior||^2 where w is the set of parameters for a learned
function.
!*/ !*/
void set_c ( void set_c (
...@@ -219,6 +256,9 @@ namespace dlib ...@@ -219,6 +256,9 @@ namespace dlib
/*! /*!
requires requires
- is_ranking_problem(samples) == true - is_ranking_problem(samples) == true
- if (has_prior()) then
- The vectors in samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures ensures
- trains a ranking support vector classifier given the training samples. - trains a ranking support vector classifier given the training samples.
- returns a decision function F with the following properties: - returns a decision function F with the following properties:
...@@ -237,6 +277,9 @@ namespace dlib ...@@ -237,6 +277,9 @@ namespace dlib
/*! /*!
requires requires
- is_ranking_problem(std::vector<ranking_pair<sample_type> >(1, sample)) == true - is_ranking_problem(std::vector<ranking_pair<sample_type> >(1, sample)) == true
- if (has_prior()) then
- The vectors in sample must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures ensures
- This is just a convenience routine for calling the above train() - This is just a convenience routine for calling the above train()
function. That is, it just copies sample into a std::vector object and function. That is, it just copies sample into a std::vector object and
......
...@@ -73,6 +73,40 @@ namespace ...@@ -73,6 +73,40 @@ namespace
} }
} }
// ----------------------------------------------------------------------------------------
void run_prior_test()
{
print_spinner();
typedef matrix<double,3,1> sample_type;
typedef linear_kernel<sample_type> kernel_type;
svm_rank_trainer<kernel_type> trainer;
ranking_pair<sample_type> data;
sample_type samp;
samp = 0, 0, 1; data.relevant.push_back(samp);
samp = 0, 1, 0; data.nonrelevant.push_back(samp);
trainer.set_c(10);
decision_function<kernel_type> df = trainer.train(data);
trainer.set_prior(df);
data.relevant.clear();
data.nonrelevant.clear();
samp = 1, 0, 0; data.relevant.push_back(samp);
samp = 0, 1, 0; data.nonrelevant.push_back(samp);
df = trainer.train(data);
dlog << LINFO << trans(df.basis_vectors(0));
DLIB_TEST(df.basis_vectors(0)(0) > 0);
DLIB_TEST(df.basis_vectors(0)(1) < 0);
DLIB_TEST(df.basis_vectors(0)(2) > 0);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void dotest1() void dotest1()
...@@ -355,6 +389,7 @@ namespace ...@@ -355,6 +389,7 @@ namespace
dotest_sparse_vectors(); dotest_sparse_vectors();
test_svmrank_weight_force_dense<true>(); test_svmrank_weight_force_dense<true>();
test_svmrank_weight_force_dense<false>(); test_svmrank_weight_force_dense<false>();
run_prior_test();
} }
} a; } a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment