Commit d9d6fa12 authored by Davis King

Added the ability to set a previously trained function as a prior to the svm_multiclass_linear_trainer.
parent a7047b35
......@@ -10,6 +10,7 @@
#include "../matrix.h"
#include "sparse_vector.h"
#include "function.h"
#include <algorithm>
namespace dlib
{
......@@ -46,13 +47,15 @@ namespace dlib
multiclass_svm_problem (
const std::vector<sample_type>& samples_,
const std::vector<label_type>& labels_,
const std::vector<label_type>& distinct_labels_,
const unsigned long dims_,
const unsigned long num_threads
) :
structural_svm_problem_threaded<matrix_type, std::vector<std::pair<unsigned long,typename matrix_type::type> > >(num_threads),
samples(samples_),
labels(labels_),
distinct_labels(select_all_distinct_labels(labels_)),
dims(max_index_plus_one(samples_)+1) // +1 for the bias
distinct_labels(distinct_labels_),
dims(dims_+1) // +1 for the bias
{}
virtual long get_num_dimensions (
......@@ -151,7 +154,7 @@ namespace dlib
const std::vector<sample_type>& samples;
const std::vector<label_type>& labels;
const std::vector<label_type> distinct_labels;
const std::vector<label_type>& distinct_labels;
const long dims;
};
......@@ -260,6 +263,7 @@ namespace dlib
)
{
learn_nonnegative_weights = value;
prior = trained_function_type();
}
void set_c (
......@@ -283,6 +287,20 @@ namespace dlib
return C;
}
// Installs a copy of prior_ as the reference function for later training and
// switches the nonnegative-weights option off.
void set_prior (const trained_function_type& prior_)
{
    // Keep our own copy of the previously trained function.
    prior = prior_;

    // Installing a prior always disables nonnegative weight learning.
    learn_nonnegative_weights = false;
}
// Reports whether a prior is currently installed, i.e. whether the stored
// prior function carries at least one label.
bool has_prior () const
{
    return !prior.labels.empty();
}
trained_function_type train (
const std::vector<sample_type>& all_samples,
const std::vector<label_type>& all_labels
......@@ -306,9 +324,33 @@ namespace dlib
<< "\n\t all_labels.size(): " << all_labels.size()
);
trained_function_type df;
df.labels = select_all_distinct_labels(all_labels);
if (has_prior())
{
df.labels.insert(df.labels.end(), prior.labels.begin(), prior.labels.end());
df.labels = select_all_distinct_labels(df.labels);
}
const long input_sample_dimensionality = max_index_plus_one(all_samples);
// If the samples are sparse then the right thing to do is to take the max
// dimensionality between the prior and the new samples. But if the samples
// are dense vectors then they definitely all have to have exactly the same
// dimensionality.
const long dims = std::max(df.weights.nc(),input_sample_dimensionality);
if (is_matrix<sample_type>::value && has_prior())
{
DLIB_ASSERT(input_sample_dimensionality == prior.weights.nc(),
"\t trained_function_type svm_multiclass_linear_trainer::train(all_samples,all_labels)"
<< "\n\t The training samples given to this function are not the same kind of training "
<< "\n\t samples used to create the prior."
<< "\n\t input_sample_dimensionality: " << input_sample_dimensionality
<< "\n\t prior.weights.nc(): " << prior.weights.nc()
);
}
typedef matrix<scalar_type,0,1> w_type;
w_type weights;
multiclass_svm_problem<w_type, sample_type, label_type> problem(all_samples, all_labels, num_threads);
multiclass_svm_problem<w_type, sample_type, label_type> problem(all_samples, all_labels, df.labels, dims, num_threads);
if (verbose)
problem.be_verbose();
......@@ -322,12 +364,33 @@ namespace dlib
num_nonnegative = problem.get_num_dimensions();
}
svm_objective = solver(problem, weights, num_nonnegative);
if (!has_prior())
{
svm_objective = solver(problem, weights, num_nonnegative);
}
else
{
matrix<scalar_type> temp(df.labels.size(),dims);
w_type b(df.labels.size());
temp = 0;
b = 0;
// Copy the prior into the temp and b matrices. We have to do this row
// by row copy because the new training data might have new labels we
// haven't seen before and therefore the sizes of these matrices could be
// different.
for (unsigned long i = 0; i < prior.labels.size(); ++i)
{
const long r = std::find(df.labels.begin(), df.labels.end(), prior.labels[i])-df.labels.begin();
set_rowm(temp,r) = rowm(prior.weights,i);
b(r) = prior.b(i);
}
const w_type prior_vect = reshape_to_column_vector(join_rows(temp,b));
svm_objective = solver(problem, weights, prior_vect);
}
trained_function_type df;
const long dims = max_index_plus_one(all_samples);
df.labels = select_all_distinct_labels(all_labels);
df.weights = colm(reshape(weights, df.labels.size(), dims+1), range(0,dims-1));
df.b = colm(reshape(weights, df.labels.size(), dims+1), dims);
return df;
......@@ -341,6 +404,8 @@ namespace dlib
bool verbose;
oca solver;
bool learn_nonnegative_weights;
trained_function_type prior;
};
// ----------------------------------------------------------------------------------------
......
......@@ -37,6 +37,7 @@ namespace dlib
- get_c() == 1
- this object will not be verbose unless be_verbose() is called
- #get_oca() == oca() (i.e. an instance of oca with default parameters)
- has_prior() == false
WHAT THIS OBJECT REPRESENTS
This object represents a tool for training a multiclass support
......@@ -176,6 +177,29 @@ namespace dlib
- #learns_nonnegative_weights() == value
!*/
void set_prior (
const trained_function_type& prior
);
/*!
ensures
- Stores a copy of the given previously trained function. Subsequent calls
  to train() will use it as the prior, i.e. as the reference function the
  newly learned function is regularized towards.
- #has_prior() == true
  (note that has_prior() reports true only when the stored prior contains
  at least one label)
- #learns_nonnegative_weights() == false
!*/
bool has_prior (
) const
/*!
ensures
- returns true if a prior has been set and false otherwise. Having a prior
set means that you have called set_prior() and supplied a previously
trained function as a reference. In this case, any call to train() will
try to learn a function that matches the behavior of the prior as close
as possible but also fits the supplied training data. In more technical
detail, having a prior means we replace the ||w||^2 regularizer with one
of the form ||w-prior||^2 where w is the set of parameters for a learned
function.
!*/
trained_function_type train (
const std::vector<sample_type>& all_samples,
const std::vector<label_type>& all_labels
......@@ -183,6 +207,10 @@ namespace dlib
/*!
requires
- is_learning_problem(all_samples, all_labels)
- All the vectors in all_samples must have the same dimensionality.
- if (has_prior()) then
- The vectors in all_samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures
- trains a multiclass SVM to solve the given multiclass classification problem.
- returns a multiclass_linear_decision_function F with the following properties:
......@@ -200,6 +228,10 @@ namespace dlib
/*!
requires
- is_learning_problem(all_samples, all_labels)
- All the vectors in all_samples must have the same dimensionality.
- if (has_prior()) then
- The vectors in all_samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures
- trains a multiclass SVM to solve the given multiclass classification problem.
- returns a multiclass_linear_decision_function F with the following properties:
......
......@@ -35,6 +35,63 @@ namespace
}
void test_prior ()
{
print_spinner();
typedef matrix<double,4,1> sample_type;
typedef linear_kernel<sample_type> kernel_type;
std::vector<sample_type> samples;
std::vector<int> labels;
for (int i = 0; i < 4; ++i)
{
if (i==2)
++i;
for (int iter = 0; iter < 5; ++iter)
{
sample_type samp;
samp = 0;
samp(i) = 1;
samples.push_back(samp);
labels.push_back(i);
}
}
svm_multiclass_linear_trainer<kernel_type,int> trainer;
multiclass_linear_decision_function<kernel_type,int> df = trainer.train(samples, labels);
//cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl;
//cout << df.weights << endl;
//cout << df.b << endl;
std::vector<sample_type> samples2;
std::vector<int> labels2;
int i = 2;
for (int iter = 0; iter < 5; ++iter)
{
sample_type samp;
samp = 0;
samp(i) = 1;
samples2.push_back(samp);
labels2.push_back(i);
samples.push_back(samp);
labels.push_back(i);
}
trainer.set_prior(df);
trainer.set_c(0.1);
df = trainer.train(samples2, labels2);
matrix<double> res = test_multiclass_decision_function(df, samples, labels);
dlog << LINFO << "test: \n" << res;
dlog << LINFO << df.weights;
dlog << LINFO << df.b;
DLIB_TEST((unsigned int)sum(diag(res))==samples.size());
}
template <typename sample_type>
void run_test()
{
......@@ -99,6 +156,8 @@ namespace
run_test<std::map<unsigned int, float> >();
run_test<std::vector<std::pair<unsigned int, float> > >();
run_test<std::vector<std::pair<unsigned long, double> > >();
test_prior();
}
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment