Commit 160337da authored by Davis King

Made the one_vs_one_trainer and one_vs_all_trainer objects multithreaded

so they can run each binary trainer on a different core.
parent 525f2a52
...@@ -32,7 +32,6 @@ ...@@ -32,7 +32,6 @@
#include "svm/svr_trainer.h" #include "svm/svr_trainer.h"
#include "svm/one_vs_one_decision_function.h" #include "svm/one_vs_one_decision_function.h"
#include "svm/one_vs_one_trainer.h"
#include "svm/multiclass_tools.h" #include "svm/multiclass_tools.h"
#include "svm/cross_validate_multiclass_trainer.h" #include "svm/cross_validate_multiclass_trainer.h"
#include "svm/cross_validate_regression_trainer.h" #include "svm/cross_validate_regression_trainer.h"
...@@ -42,7 +41,6 @@ ...@@ -42,7 +41,6 @@
#include "svm/cross_validate_assignment_trainer.h" #include "svm/cross_validate_assignment_trainer.h"
#include "svm/one_vs_all_decision_function.h" #include "svm/one_vs_all_decision_function.h"
#include "svm/one_vs_all_trainer.h"
#include "svm/structural_svm_problem.h" #include "svm/structural_svm_problem.h"
#include "svm/sequence_labeler.h" #include "svm/sequence_labeler.h"
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include <vector> #include <vector>
#include "../matrix.h" #include "../matrix.h"
#include "one_vs_one_trainer.h"
#include "cross_validate_multiclass_trainer_abstract.h" #include "cross_validate_multiclass_trainer_abstract.h"
#include <sstream> #include <sstream>
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "../any.h" #include "../any.h"
#include <map> #include <map>
#include <set> #include <set>
#include "../threads.h"
namespace dlib namespace dlib
{ {
...@@ -39,7 +40,8 @@ namespace dlib ...@@ -39,7 +40,8 @@ namespace dlib
one_vs_all_trainer ( one_vs_all_trainer (
) : ) :
verbose(false) verbose(false),
num_threads(4)
{} {}
void set_trainer ( void set_trainer (
...@@ -70,6 +72,19 @@ namespace dlib ...@@ -70,6 +72,19 @@ namespace dlib
verbose = false; verbose = false;
} }
// Records the thread count that train() passes as the first argument to
// parallel_for() when it runs the per-label binary trainers.
void set_num_threads (unsigned long num)
{ num_threads = num; }
// Reports the thread count most recently stored by set_num_threads(); the
// constructor initialises it to 4.
unsigned long get_num_threads () const
{ return num_threads; }
struct invalid_label : public dlib::error struct invalid_label : public dlib::error
{ {
invalid_label(const std::string& msg, const label_type& l_ invalid_label(const std::string& msg, const label_type& l_
...@@ -96,62 +111,117 @@ namespace dlib ...@@ -96,62 +111,117 @@ namespace dlib
const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels); const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels);
std::vector<scalar_type> labels; // make sure we have a trainer object for each of the label types.
typename trained_function_type::binary_function_table dfs;
for (unsigned long i = 0; i < distinct_labels.size(); ++i) for (unsigned long i = 0; i < distinct_labels.size(); ++i)
{ {
labels.clear();
const label_type l = distinct_labels[i]; const label_type l = distinct_labels[i];
const typename binary_function_table::const_iterator itr = trainers.find(l);
// setup one of the one vs all training sets if (itr == trainers.end() && default_trainer.is_empty())
for (unsigned long k = 0; k < all_samples.size(); ++k)
{ {
if (all_labels[k] == l) std::ostringstream sout;
labels.push_back(+1); sout << "In one_vs_all_trainer, no trainer registered for the " << l << " label.";
else throw invalid_label(sout.str(), l);
labels.push_back(-1);
} }
}
if (verbose) // now do the training
{ parallel_for_helper helper(all_samples,all_labels,default_trainer,trainers,verbose,distinct_labels);
std::cout << "Training classifier for " << l << " vs. all" << std::endl; parallel_for(num_threads, 0, distinct_labels.size(), helper, 500);
}
// now train a binary classifier using the samples we selected if (helper.error_message.size() != 0)
const typename binary_function_table::const_iterator itr = trainers.find(l); {
throw dlib::error("binary trainer threw while training one vs. all classifier. Error was: " + helper.error_message);
}
return trained_function_type(helper.dfs);
}
if (itr != trainers.end()) private:
{
dfs[l] = itr->second.train(all_samples, labels); typedef std::map<label_type, any_trainer> binary_function_table;
} struct parallel_for_helper
else if (default_trainer.is_empty() == false) {
// Builds the functor that train() hands to parallel_for(); operator()(i)
// trains the "distinct_labels_[i] vs. all" binary classifier.
//
// Every argument except verbose_ is kept as a const reference member and is
// NOT copied, so the referenced objects must stay alive (and unmodified) for
// as long as this helper is in use.  verbose_ is stored by value.
parallel_for_helper(
// the full training set handed to every binary trainer
const std::vector<sample_type>& all_samples_,
// all_labels_[k] is compared against the current label to build the +1/-1 targets
const std::vector<label_type>& all_labels_,
// used for any label that has no entry in trainers_
const any_trainer& default_trainer_,
// per-label trainer overrides
const binary_function_table& trainers_,
// when true, a "Training classifier for ..." line is printed per label
const bool verbose_,
// indexed by the loop variable that operator() receives
const std::vector<label_type>& distinct_labels_
) :
all_samples(all_samples_),
all_labels(all_labels_),
default_trainer(default_trainer_),
trainers(trainers_),
verbose(verbose_),
distinct_labels(distinct_labels_)
{}
void operator()(long i) const
{
try
{ {
dfs[l] = default_trainer.train(all_samples, labels); std::vector<scalar_type> labels;
const label_type l = distinct_labels[i];
// setup one of the one vs all training sets
for (unsigned long k = 0; k < all_samples.size(); ++k)
{
if (all_labels[k] == l)
labels.push_back(+1);
else
labels.push_back(-1);
}
if (verbose)
{
auto_mutex lock(class_mutex);
std::cout << "Training classifier for " << l << " vs. all" << std::endl;
}
any_trainer trainer;
// now train a binary classifier using the samples we selected
{ auto_mutex lock(class_mutex);
const typename binary_function_table::const_iterator itr = trainers.find(l);
if (itr != trainers.end())
trainer = itr->second;
else
trainer = default_trainer;
}
any_decision_function<sample_type,scalar_type> binary_df = trainer.train(all_samples, labels);
auto_mutex lock(class_mutex);
dfs[l] = binary_df;
} }
else catch (std::exception& e)
{ {
std::ostringstream sout; auto_mutex lock(class_mutex);
sout << "In one_vs_all_trainer, no trainer registered for the " << l << " label."; error_message = e.what();
throw invalid_label(sout.str(), l);
} }
} }
return trained_function_type(dfs); mutable typename trained_function_type::binary_function_table dfs;
} mutex class_mutex;
mutable std::string error_message;
private: const std::vector<sample_type>& all_samples;
const std::vector<label_type>& all_labels;
const any_trainer& default_trainer;
const binary_function_table& trainers;
const bool verbose;
const std::vector<label_type>& distinct_labels;
};
any_trainer default_trainer; any_trainer default_trainer;
typedef std::map<label_type, any_trainer> binary_function_table;
binary_function_table trainers; binary_function_table trainers;
bool verbose; bool verbose;
unsigned long num_threads;
}; };
......
...@@ -55,10 +55,11 @@ namespace dlib ...@@ -55,10 +55,11 @@ namespace dlib
); );
/*! /*!
ensures ensures
- this object is properly initialized - This object is properly initialized.
- this object will not be verbose unless be_verbose() is called - This object will not be verbose unless be_verbose() is called.
- no binary trainers are associated with *this. I.e. you have to - No binary trainers are associated with *this. I.e. you have to
call set_trainer() before calling train() call set_trainer() before calling train().
- #get_num_threads() == 4
!*/ !*/
void set_trainer ( void set_trainer (
...@@ -96,6 +97,23 @@ namespace dlib ...@@ -96,6 +97,23 @@ namespace dlib
- this object will not print anything to standard out - this object will not print anything to standard out
!*/ !*/
void set_num_threads (
unsigned long num
);
/*!
ensures
- #get_num_threads() == num
!*/
unsigned long get_num_threads (
) const;
/*!
ensures
- returns the number of threads used during training. You should
usually set this equal to the number of processing cores on your
machine.
!*/
struct invalid_label : public dlib::error struct invalid_label : public dlib::error
{ {
/*! /*!
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "../any.h" #include "../any.h"
#include <map> #include <map>
#include <set> #include <set>
#include "../threads.h"
namespace dlib namespace dlib
{ {
...@@ -40,7 +41,8 @@ namespace dlib ...@@ -40,7 +41,8 @@ namespace dlib
one_vs_one_trainer ( one_vs_one_trainer (
) : ) :
verbose(false) verbose(false),
num_threads(4)
{} {}
void set_trainer ( void set_trainer (
...@@ -72,6 +74,19 @@ namespace dlib ...@@ -72,6 +74,19 @@ namespace dlib
verbose = false; verbose = false;
} }
// Records the thread count that train() passes as the first argument to
// parallel_for() when it runs the per-label-pair binary trainers.
void set_num_threads (unsigned long num)
{ num_threads = num; }
// Reports the thread count most recently stored by set_num_threads(); the
// constructor initialises it to 4.
unsigned long get_num_threads () const
{ return num_threads; }
struct invalid_label : public dlib::error struct invalid_label : public dlib::error
{ {
invalid_label(const std::string& msg, const label_type& l1_, const label_type& l2_ invalid_label(const std::string& msg, const label_type& l1_, const label_type& l2_
...@@ -98,20 +113,70 @@ namespace dlib ...@@ -98,20 +113,70 @@ namespace dlib
const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels); const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels);
std::vector<sample_type> samples;
std::vector<scalar_type> labels;
typename trained_function_type::binary_function_table dfs;
// fill pairs with all the pairs of labels.
std::vector<unordered_pair<label_type> > pairs;
for (unsigned long i = 0; i < distinct_labels.size(); ++i) for (unsigned long i = 0; i < distinct_labels.size(); ++i)
{ {
for (unsigned long j = i+1; j < distinct_labels.size(); ++j) for (unsigned long j = i+1; j < distinct_labels.size(); ++j)
{ {
samples.clear(); pairs.push_back(unordered_pair<label_type>(distinct_labels[i], distinct_labels[j]));
labels.clear();
const unordered_pair<label_type> p(distinct_labels[i], distinct_labels[j]); // make sure we have a trainer for this pair
const typename binary_function_table::const_iterator itr = trainers.find(pairs.back());
if (itr == trainers.end() && default_trainer.is_empty())
{
std::ostringstream sout;
sout << "In one_vs_one_trainer, no trainer registered for the ("
<< pairs.back().first << ", " << pairs.back().second << ") label pair.";
throw invalid_label(sout.str(), pairs.back().first, pairs.back().second);
}
}
}
// Now train on all the label pairs.
parallel_for_helper helper(all_samples,all_labels,default_trainer,trainers,verbose,pairs);
parallel_for(num_threads, 0, pairs.size(), helper, 500);
if (helper.error_message.size() != 0)
{
throw dlib::error("binary trainer threw while training one vs. one classifier. Error was: " + helper.error_message);
}
return trained_function_type(helper.dfs);
}
private:
typedef std::map<unordered_pair<label_type>, any_trainer> binary_function_table;
struct parallel_for_helper
{
// Builds the functor that train() hands to parallel_for(); operator()(i)
// trains the binary classifier for the label pair pairs_[i].
//
// Every argument except verbose_ is kept as a const reference member and is
// NOT copied, so the referenced objects must stay alive (and unmodified) for
// as long as this helper is in use.  verbose_ is stored by value.
parallel_for_helper(
// the full training set from which each pair's samples are picked
const std::vector<sample_type>& all_samples_,
// labels that go with all_samples_
const std::vector<label_type>& all_labels_,
// used for any label pair that has no entry in trainers_
const any_trainer& default_trainer_,
// per-pair trainer overrides
const binary_function_table& trainers_,
// when true, a "Training classifier for ..." line is printed per pair
const bool verbose_,
// indexed by the loop variable that operator() receives
const std::vector<unordered_pair<label_type> >& pairs_
) :
all_samples(all_samples_),
all_labels(all_labels_),
default_trainer(default_trainer_),
trainers(trainers_),
verbose(verbose_),
pairs(pairs_)
{}
void operator()(long i) const
{
try
{
std::vector<sample_type> samples;
std::vector<scalar_type> labels;
const unordered_pair<label_type> p = pairs[i];
// pick out the samples corresponding to these two classes // pick out the samples corresponding to these two classes
for (unsigned long k = 0; k < all_samples.size(); ++k) for (unsigned long k = 0; k < all_samples.size(); ++k)
...@@ -128,43 +193,51 @@ namespace dlib ...@@ -128,43 +193,51 @@ namespace dlib
} }
} }
if (verbose) if (verbose)
{ {
auto_mutex lock(class_mutex);
std::cout << "Training classifier for " << p.first << " vs. " << p.second << std::endl; std::cout << "Training classifier for " << p.first << " vs. " << p.second << std::endl;
} }
any_trainer trainer;
// now train a binary classifier using the samples we selected // now train a binary classifier using the samples we selected
{ auto_mutex lock(class_mutex);
const typename binary_function_table::const_iterator itr = trainers.find(p); const typename binary_function_table::const_iterator itr = trainers.find(p);
if (itr != trainers.end()) if (itr != trainers.end())
{ trainer = itr->second;
dfs[p] = itr->second.train(samples, labels); else
} trainer = default_trainer;
else if (default_trainer.is_empty() == false)
{
dfs[p] = default_trainer.train(samples, labels);
}
else
{
std::ostringstream sout;
sout << "In one_vs_one_trainer, no trainer registered for the (" << p.first << ", " << p.second << ") label pair.";
throw invalid_label(sout.str(), p.first, p.second);
} }
any_decision_function<sample_type,scalar_type> binary_df = trainer.train(samples, labels);
auto_mutex lock(class_mutex);
dfs[p] = binary_df;
}
catch (std::exception& e)
{
auto_mutex lock(class_mutex);
error_message = e.what();
} }
} }
return trained_function_type(dfs); mutable typename trained_function_type::binary_function_table dfs;
} mutex class_mutex;
mutable std::string error_message;
private: const std::vector<sample_type>& all_samples;
const std::vector<label_type>& all_labels;
const any_trainer& default_trainer;
const binary_function_table& trainers;
const bool verbose;
const std::vector<unordered_pair<label_type> >& pairs;
};
any_trainer default_trainer; any_trainer default_trainer;
typedef std::map<unordered_pair<label_type>, any_trainer> binary_function_table;
binary_function_table trainers; binary_function_table trainers;
bool verbose; bool verbose;
unsigned long num_threads;
}; };
......
...@@ -55,10 +55,11 @@ namespace dlib ...@@ -55,10 +55,11 @@ namespace dlib
); );
/*! /*!
ensures ensures
- this object is properly initialized - This object is properly initialized
- this object will not be verbose unless be_verbose() is called - This object will not be verbose unless be_verbose() is called.
- no binary trainers are associated with *this. I.e. you have to - No binary trainers are associated with *this. I.e. you have to
call set_trainer() before calling train() call set_trainer() before calling train().
- #get_num_threads() == 4
!*/ !*/
void set_trainer ( void set_trainer (
...@@ -99,6 +100,23 @@ namespace dlib ...@@ -99,6 +100,23 @@ namespace dlib
- this object will not print anything to standard out - this object will not print anything to standard out
!*/ !*/
void set_num_threads (
unsigned long num
);
/*!
ensures
- #get_num_threads() == num
!*/
unsigned long get_num_threads (
) const;
/*!
ensures
- returns the number of threads used during training. You should
usually set this equal to the number of processing cores on your
machine.
!*/
struct invalid_label : public dlib::error struct invalid_label : public dlib::error
{ {
/*! /*!
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#include "svm/structural_graph_labeling_trainer.h" #include "svm/structural_graph_labeling_trainer.h"
#include "svm/cross_validate_graph_labeling_trainer.h" #include "svm/cross_validate_graph_labeling_trainer.h"
#include "svm/svm_multiclass_linear_trainer.h" #include "svm/svm_multiclass_linear_trainer.h"
#include "svm/one_vs_one_trainer.h"
#include "svm/one_vs_all_trainer.h"
#endif // DLIB_SVm_THREADED_HEADER #endif // DLIB_SVm_THREADED_HEADER
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// License: Boost Software License See LICENSE.txt for the full license. // License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h" #include "tester.h"
#include <dlib/svm.h> #include <dlib/svm_threaded.h>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// License: Boost Software License See LICENSE.txt for the full license. // License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h" #include "tester.h"
#include <dlib/svm.h> #include <dlib/svm_threaded.h>
#include <dlib/statistics.h> #include <dlib/statistics.h>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment