Commit 5dbab3f5 authored by Davis King

Optimized this code by making it use the new ekm transformation function.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403625
parent 2e1561d5
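
For context on what the commit message calls "the new ekm transformation function": the pattern the diff below adopts is to carry samples that were already projected under the old basis into the enlarged span with `empirical_kernel_map::get_transformation_to()`, so only the newly added basis vectors need fresh kernel evaluations. The following is a minimal, self-contained sketch of that pattern; the function name `update_projections`, the concrete kernel/sample types, and the parameter layout are placeholders for illustration, not part of the commit.

```cpp
#include <dlib/svm.h>
#include <vector>

using namespace dlib;

// Placeholder types for illustration only.
typedef matrix<double,0,1> sample_type;
typedef radial_basis_kernel<sample_type> kernel_type;

// Carry existing projections into the span of an enlarged basis.
// 'ekm' is assumed to be loaded with the OLD basis on entry and is reloaded
// with the enlarged basis held by 'lisf' before returning.
void update_projections (
    const std::vector<sample_type>& x,
    const linearly_independent_subset_finder<kernel_type>& lisf,
    empirical_kernel_map<kernel_type>& ekm,
    std::vector<sample_type>& proj_samples
)
{
    // Keep the old map around and load the new, larger basis.
    empirical_kernel_map<kernel_type> prev_ekm;
    ekm.swap(prev_ekm);
    ekm.load(lisf);

    // Get the linear map from the old span into the new span, plus a
    // projection function that accounts for the newly added basis vectors.
    matrix<double> prev_to_new;
    projection_function<kernel_type> proj_part;
    prev_ekm.get_transformation_to(ekm, prev_to_new, proj_part);

    // Reuse the old projections; only proj_part() has to touch the kernel.
    sample_type temp;
    for (unsigned long i = 0; i < x.size(); ++i)
    {
        temp = prev_to_new*proj_samples[i] + proj_part(x[i]);
        proj_samples[i] = temp;
    }
}
```

The loop added inside `while (true)` in the diff is the in-place version of this idea; the full reprojection kept at the end of the function acts as a numerical safety net.
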
@@ -425,6 +425,15 @@ namespace dlib
             lisf.add(x(rnd.get_random_32bit_number()%x.size()));
         }
+        ekm.load(lisf);
+        // first project all samples into the span of the current basis
+        for (long i = 0; i < x.size(); ++i)
+        {
+            proj_samples[i] = ekm.project(x(i));
+        }
         svm_c_linear_trainer<linear_kernel<matrix<scalar_type,0,1,mem_manager_type> > > trainer(ocas);
         const scalar_type min_epsilon = trainer.get_epsilon();
@@ -435,30 +444,13 @@ namespace dlib
         scalar_type prev_svm_objective = std::numeric_limits<scalar_type>::max();
+        empirical_kernel_map<kernel_type> prev_ekm;
         // This loop is where we try to generate a basis for SVM training. We will
         // do this by repeatedly training the SVM and adding a few points which violate the
         // margin to the basis in each iteration.
         while (true)
         {
-            running_stats<scalar_type> rs;
-            ekm.load(lisf);
-            // first project all samples into the span of the current basis
-            for (long i = 0; i < x.size(); ++i)
-            {
-                if (verbose)
-                {
-                    scalar_type err;
-                    proj_samples[i] = ekm.project(x(i), err);
-                    rs.add(err);
-                }
-                else
-                {
-                    proj_samples[i] = ekm.project(x(i));
-                }
-            }
             // if the basis is already as big as it's going to get then just do the most
             // accurate training right now.
             if (lisf.dictionary_size() == max_basis_size)
@@ -478,7 +470,7 @@ namespace dlib
             {
                 trainer.set_epsilon(std::max(trainer.get_epsilon()*0.5, min_epsilon));
                 if (verbose)
-                    std::cout << "Reducing epsilon to " << trainer.get_epsilon() << std::endl;
+                    std::cout << " *** Reducing epsilon to " << trainer.get_epsilon() << std::endl;
             }
             else
                 break;
@@ -486,8 +478,6 @@ namespace dlib
             if (verbose)
             {
-                std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl;
-                std::cout << "Standard deviaion of EKM projection error: " << rs.stddev() << std::endl;
                 std::cout << "svm objective: " << svm_objective << std::endl;
                 std::cout << "basis size: " << lisf.dictionary_size() << std::endl;
             }
@@ -497,6 +487,7 @@ namespace dlib
                 break;
             prev_svm_objective = svm_objective;
+            std::vector<sample_type> new_basis_elements;
             // now add more elements to the basis
             unsigned long count = 0;
@@ -512,9 +503,12 @@ namespace dlib
                 // Add the sample into the basis set if it is linearly independent of all the
                 // vectors already in the basis set.
                 if (lisf.add(x(idx)))
+                {
+                    new_basis_elements.push_back(x(idx));
                     ++count;
                 }
             }
+            }
             // if we couldn't add any more basis vectors then stop
             if (count == 0)
             {
@@ -522,19 +516,57 @@ namespace dlib
                     std::cout << "Stopping, couldn't add more basis vectors." << std::endl;
                 break;
             }
+            // Project all the samples into the span of our newly enlarged basis. We will do this
+            // using the special transformation in the EKM that lets us project from a smaller
+            // basis set to a larger without needing to reevaluate kernel functions we have already
+            // computed.
+            ekm.swap(prev_ekm);
+            ekm.load(lisf);
+            projection_function<kernel_type> proj_part;
+            matrix<double> prev_to_new;
+            prev_ekm.get_transformation_to(ekm, prev_to_new, proj_part);
+            sample_type temp;
+            for (long i = 0; i < x.size(); ++i)
+            {
+                // assign to temporary to avoid memory allocation that would result if we
+                // assigned this expression straight into proj_samples[i]
+                temp = prev_to_new*proj_samples[i] + proj_part(x(i));
+                proj_samples[i] = temp;
+            }
         }
-        // if we haven't already done so then make sure to run training with the tight epsilon
-        // before we return our results.
-        if (trainer.get_epsilon() > min_epsilon)
+        // Reproject all the data samples using the final basis. We could just use what we
+        // already have but the recursive thing done above to compute the proj_samples
+        // might have accumulated a little numerical error. So lets just be safe.
+        running_stats<scalar_type> rs;
+        for (long i = 0; i < x.size(); ++i)
         {
+            if (verbose)
+            {
+                scalar_type err;
+                proj_samples[i] = ekm.project(x(i),err);
+                rs.add(err);
+            }
+            else
+            {
+                proj_samples[i] = ekm.project(x(i));
+            }
+        }
+        // do the final training
         trainer.set_epsilon(min_epsilon);
         df = trainer.train(proj_samples, y, svm_objective);
-        }
         if (verbose)
         {
+            std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl;
+            std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl;
             std::cout << "Final svm objective: " << svm_objective << std::endl;
         }
...
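
The verbose branch in the final reprojection above reports how faithfully the final basis spans the data. Below is a standalone sketch of that check, assuming dlib's `running_stats` and the two-argument `empirical_kernel_map::project()` overload; the function name `report_projection_error` and the concrete types are illustrative, not part of the commit.

```cpp
#include <dlib/svm.h>
#include <dlib/statistics.h>
#include <iostream>
#include <vector>

using namespace dlib;

// Placeholder types for illustration only.
typedef matrix<double,0,1> sample_type;
typedef radial_basis_kernel<sample_type> kernel_type;

// Print the mean and standard deviation of the EKM projection error over a
// set of samples. A small mean error suggests the basis represents the data well.
void report_projection_error (
    const empirical_kernel_map<kernel_type>& ekm,
    const std::vector<sample_type>& x
)
{
    running_stats<double> rs;
    for (unsigned long i = 0; i < x.size(); ++i)
    {
        double err;
        ekm.project(x[i], err);   // err measures how far x[i] lies from the span of the basis
        rs.add(err);
    }
    std::cout << "Mean EKM projection error: " << rs.mean() << std::endl;
    std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl;
}
```
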