Commit a0a692f9 authored by Davis King

Added the option to use an identity matrix prior to the vector_normalizer_frobmetric object.
parent f4e50eaa
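
For context, here is a minimal usage sketch of the new option, assembled from the API and unit test added in this commit; the include path and the surrounding main() scaffolding are my assumptions, not part of the commit.

    #include <dlib/statistics.h>   // assumed include path for vector_normalizer_frobmetric
    #include <vector>

    int main()
    {
        using namespace dlib;
        typedef matrix<double,0,1> sample_type;

        // One training triplet, mirroring the new unit test: the anchor should end
        // up closer to `near` than to `far` under the learned metric.
        frobmetric_training_sample<sample_type> sample;
        matrix<double,3,1> anchor, near, far;
        anchor = 0,0,0;
        near   = 1,0,0;
        far    = 0,1,0;
        sample.anchor_vect = anchor;
        sample.near_vects.push_back(near);
        sample.far_vects.push_back(far);

        std::vector<frobmetric_training_sample<sample_type> > samples;
        samples.push_back(sample);

        vector_normalizer_frobmetric<sample_type> trainer;
        trainer.set_c(100);
        // New in this commit: regularize the learned transform toward the identity
        // matrix instead of toward the zero matrix.
        trainer.set_uses_identity_matrix_prior(true);
        trainer.train(samples);

        // trainer.transform() now holds the learned linear transformation.
        return 0;
    }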
@@ -81,8 +81,9 @@ namespace dlib
         {
             objective (
                 const std::vector<compact_frobmetric_training_sample>& samples_,
-                matrix<double,0,0,mem_manager_type>& Aminus_
-            ) : samples(samples_), Aminus(Aminus_) {}
+                matrix<double,0,0,mem_manager_type>& Aminus_,
+                const matrix<double,0,1,mem_manager_type>& bias_
+            ) : samples(samples_), Aminus(Aminus_), bias(bias_) {}
 
             double operator()(const matrix<double,0,1,mem_manager_type>& u) const
             {
@@ -118,12 +119,13 @@ namespace dlib
                 // computation can make Aminus slightly non-symmetric.
                 Aminus = make_symmetric(Aminus);
-                return sum(u) - 0.5*sum(squared(Aminus));
+                return dot(u,bias) - 0.5*sum(squared(Aminus));
             }
 
         private:
             const std::vector<compact_frobmetric_training_sample>& samples;
             matrix<double,0,0,mem_manager_type>& Aminus;
+            const matrix<double,0,1,mem_manager_type>& bias;
         };
 
         struct derivative
@@ -131,8 +133,9 @@ namespace dlib
             derivative (
                 unsigned long num_triples_,
                 const std::vector<compact_frobmetric_training_sample>& samples_,
-                matrix<double,0,0,mem_manager_type>& Aminus_
-            ) : num_triples(num_triples_), samples(samples_), Aminus(Aminus_) {}
+                matrix<double,0,0,mem_manager_type>& Aminus_,
+                const matrix<double,0,1,mem_manager_type>& bias_
+            ) : num_triples(num_triples_), samples(samples_), Aminus(Aminus_), bias(bias_) {}
 
             matrix<double,0,1,mem_manager_type> operator()(const matrix<double,0,1,mem_manager_type>& ) const
             {
@@ -158,7 +161,8 @@ namespace dlib
                     {
                         for (unsigned long k = 0; k < samples[i].far_vects.size(); ++k)
                         {
-                            grad(idx++) = 1 + ufar[k]-unear[j];
+                            grad(idx) = bias(idx) + ufar[k]-unear[j];
+                            idx++;
                         }
                     }
                 }
@@ -170,6 +174,7 @@ namespace dlib
             const unsigned long num_triples;
             const std::vector<compact_frobmetric_training_sample>& samples;
             matrix<double,0,0,mem_manager_type>& Aminus;
+            const matrix<double,0,1,mem_manager_type>& bias;
         };
@@ -245,6 +250,20 @@ namespace dlib
             eps = 0.1;
             C = 1;
             max_iter = 5000;
+            _use_identity_matrix_prior = false;
+        }
+
+        bool uses_identity_matrix_prior (
+        ) const
+        {
+            return _use_identity_matrix_prior;
+        }
+
+        void set_uses_identity_matrix_prior (
+            bool use_prior
+        )
+        {
+            _use_identity_matrix_prior = use_prior;
         }
 
         void be_verbose(
@@ -402,27 +421,50 @@ namespace dlib
                 num_triples += samples[i].near_vects.size()*samples[i].far_vects.size();
 
             matrix<double,0,1,mem_manager_type> u(num_triples);
+            matrix<double,0,1,mem_manager_type> bias(num_triples);
             u = 0;
+            bias = 1;
 
             // precompute all the anchor_vect to far_vects/near_vects pairs
             std::vector<compact_frobmetric_training_sample> data(samples.size());
+            unsigned long cnt = 0;
+            std::vector<double> far_norm, near_norm;
             for (unsigned long i = 0; i < data.size(); ++i)
             {
+                far_norm.clear();
+                near_norm.clear();
                 data[i].far_vects.reserve(samples[i].far_vects.size());
                 data[i].near_vects.reserve(samples[i].near_vects.size());
 
                 for (unsigned long j = 0; j < samples[i].far_vects.size(); ++j)
+                {
                     data[i].far_vects.push_back(samples[i].anchor_vect - samples[i].far_vects[j]);
+                    if (_use_identity_matrix_prior)
+                        far_norm.push_back(length_squared(data[i].far_vects.back()));
+                }
+
                 for (unsigned long j = 0; j < samples[i].near_vects.size(); ++j)
+                {
                     data[i].near_vects.push_back(samples[i].anchor_vect - samples[i].near_vects[j]);
+                    if (_use_identity_matrix_prior)
+                        near_norm.push_back(length_squared(data[i].near_vects.back()));
+                }
+
+                // Note that this loop only executes if _use_identity_matrix_prior == true.
+                for (unsigned long j = 0; j < near_norm.size(); ++j)
+                {
+                    for (unsigned long k = 0; k < far_norm.size(); ++k)
+                    {
+                        bias(cnt++) = 1 - (far_norm[k] - near_norm[j]);
+                    }
+                }
             }
 
             // Now run the main part of the algorithm
             matrix<double,0,0,mem_manager_type> Aminus;
             find_max_box_constrained(lbfgs_search_strategy(10),
                                      custom_stop_strategy(C, eps, verbose, max_iter),
-                                     objective(data, Aminus),
-                                     derivative(num_triples, data, Aminus),
+                                     objective(data, Aminus, bias),
+                                     derivative(num_triples, data, Aminus, bias),
                                      u, 0, C/num_triples);
@@ -437,6 +479,9 @@ namespace dlib
                 if (eigs(i) < tol)
                     eigs(i) = 0;
             }
 
+            if (_use_identity_matrix_prior)
+                tform = matrix_cast<scalar_type>(identity_matrix(Aminus) + diagm(sqrt(eigs))*trans(ed.get_pseudo_v()));
+            else
                 tform = matrix_cast<scalar_type>(diagm(sqrt(eigs))*trans(ed.get_pseudo_v()));
 
             // Pre-apply the transform to m so we don't have to do it inside operator()
@@ -509,6 +554,7 @@ namespace dlib
         double eps;
         double C;
         unsigned long max_iter;
+        bool _use_identity_matrix_prior;
 
         // This is just a temporary variable that doesn't contribute to the
         // state of this object.
@@ -525,7 +571,7 @@ namespace dlib
         std::ostream& out
     )
     {
-        const int version = 1;
+        const int version = 2;
         serialize(version, out);
         serialize(item.m, out);
@@ -534,6 +580,7 @@ namespace dlib
         serialize(item.eps, out);
         serialize(item.C, out);
         serialize(item.max_iter, out);
+        serialize(item._use_identity_matrix_prior, out);
     }
 
 // ----------------------------------------------------------------------------------------
@@ -548,7 +595,7 @@ namespace dlib
     {
         int version = 0;
         deserialize(version, in);
-        if (version != 1)
+        if (version != 1 && version != 2)
             throw serialization_error("Unsupported version found while deserializing dlib::vector_normalizer_frobmetric.");
         deserialize(item.m, in);
@@ -557,6 +604,10 @@ namespace dlib
         deserialize(item.eps, in);
         deserialize(item.C, in);
         deserialize(item.max_iter, in);
+        if (version == 2)
+            deserialize(item._use_identity_matrix_prior, in);
+        else
+            item._use_identity_matrix_prior = false;
     }
 
 // ----------------------------------------------------------------------------------------
......
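
Reading the implementation changes above (this is my interpretation of the new code, not wording from the commit), the prior enters the dual solver only through the per-triplet bias vector: for a triplet with anchor a, near vector n_j, and far vector f_k, the constant margin of 1 in the old grad(idx++) line is replaced by

    b_{jk} \;=\; 1 - \left( \lVert a - f_k \rVert^2 - \lVert a - n_j \rVert^2 \right)

so each triplet only has to supply whatever part of the unit margin the plain identity metric does not already provide. When _use_identity_matrix_prior == false the bias vector stays at 1 and the previous behaviour is reproduced exactly.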
@@ -73,6 +73,7 @@ namespace dlib
                 - get_c() == 1
                 - get_max_iterations() == 5000
                 - This object is not verbose
+                - uses_identity_matrix_prior() == false
 
             WHAT THIS OBJECT REPRESENTS
                 This object is a tool for performing the FrobMetric distance metric
@@ -110,6 +111,27 @@ namespace dlib
                 - this object is properly initialized
         !*/
 
+        bool uses_identity_matrix_prior (
+        ) const;
+        /*!
+            ensures
+                - Normally this object will try and find a matrix transform() that
+                  minimizes sum(squared(transform())) but also fits the training data.
+                  However, if #uses_identity_matrix_prior() == true then it will instead
+                  try to find the transformation matrix that minimizes
+                  sum(squared(identity_matrix()-transform())).  That is, it will try to
+                  find the matrix most similar to the identity matrix that best fits the
+                  training data.
+        !*/
+
+        void set_uses_identity_matrix_prior (
+            bool use_prior
+        );
+        /*!
+            ensures
+                - #uses_identity_matrix_prior() == use_prior
+        !*/
+
         void be_verbose(
         );
         /*!
......
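
In equation form, the documented behaviour above amounts to swapping the regularizer applied to the learned transform T = transform() (a restatement of the spec comment, not wording from the commit):

    \text{without prior:}\;\; \min_T \lVert T \rVert_F^2
    \qquad\qquad
    \text{with prior:}\;\; \min_T \lVert I - T \rVert_F^2

in both cases subject to the learned metric respecting the anchor/near/far training triplets.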
@@ -692,10 +692,51 @@ namespace
     }
 
+    void prior_frobnorm_test()
+    {
+        frobmetric_training_sample<matrix<double,0,1> > sample;
+        std::vector<frobmetric_training_sample<matrix<double,0,1> > > samples;
+
+        matrix<double,3,1> x, near, far;
+        x    = 0,0,0;
+        near = 1,0,0;
+        far  = 0,1,0;
+
+        sample.anchor_vect = x;
+        sample.near_vects.push_back(near);
+        sample.far_vects.push_back(far);
+        samples.push_back(sample);
+
+        vector_normalizer_frobmetric<matrix<double,0,1> > trainer;
+        trainer.set_c(100);
+        print_spinner();
+        trainer.train(samples);
+
+        matrix<double,3,3> correct;
+        correct = 0, 0, 0,
+                  0, 1, 0,
+                  0, 0, 0;
+        dlog << LDEBUG << trainer.transform();
+        DLIB_TEST(max(abs(trainer.transform()-correct)) < 1e-8);
+
+        trainer.set_uses_identity_matrix_prior(true);
+        print_spinner();
+        trainer.train(samples);
+        correct = 1, 0, 0,
+                  0, 2, 0,
+                  0, 0, 1;
+        dlog << LDEBUG << trainer.transform();
+        DLIB_TEST(max(abs(trainer.transform()-correct)) < 1e-8);
+    }
+
     void perform_test (
     )
     {
+        prior_frobnorm_test();
+
         dlib::rand rnd;
         for (int i = 0; i < 5; ++i)
             test_vector_normalizer_frobmetric(rnd);
......
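
As a sanity check on the expected matrices in the new test (assuming my reading of the identity_matrix(Aminus) + ... branch above is right): the anchor-to-near and anchor-to-far difference vectors both have unit length, so the identity metric contributes equally to both distances, the bias stays at 1, and the dual problem is identical in both runs. The learned component is therefore the same, and only the returned transform changes by the added identity:

    T_{\text{no prior}} = \operatorname{diag}(0,1,0),
    \qquad
    T_{\text{prior}} = I + \operatorname{diag}(0,1,0) = \operatorname{diag}(1,2,1)

which matches the two correct matrices the test checks against.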