Commit 6e06b0bd authored by Davis King's avatar Davis King

Made the distributed structural svm tools use the same improved job/buffering

rules as I recently added to the structural_svm_problem_threaded object.
parent 4f411d5a
......@@ -8,6 +8,8 @@
#include "structural_svm_problem.h"
#include "../bridge.h"
#include "../smart_pointers.h"
#include "../misc_api.h"
#include "../statistics.h"
#include "../threads.h"
......@@ -166,6 +168,11 @@ namespace dlib
tsu_in msg;
tsu_out temp;
timestamper ts;
running_stats<double> with_buffer_time;
running_stats<double> without_buffer_time;
unsigned long num_iterations_executed = 0;
while (in.dequeue(msg))
{
// initialize the cache and compute psi_true.
......@@ -198,6 +205,8 @@ namespace dlib
}
else if (msg.template contains<oracle_request<matrix_type> >())
{
++num_iterations_executed;
const oracle_request<matrix_type>& req = msg.template get<oracle_request<matrix_type> >();
oracle_response<matrix_type>& data = temp.template get<oracle_response<matrix_type> >();
......@@ -207,16 +216,35 @@ namespace dlib
data.num = problem.get_num_samples();
// how many samples to process in a single task (aim for 100 jobs per thread)
const long block_size = std::max<long>(1, data.num / (1+tp.num_threads_in_pool()*100));
// how many samples to process in a single task (aim for 4 jobs per worker)
const long num_workers = std::max(1UL, tp.num_threads_in_pool());
const long block_size = std::max(1L, data.num/(num_workers*4));
const uint64 start_time = ts.get_timestamp();
// pick fastest buffering strategy
bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean();
// every 50 iterations we should try to flip the buffering scheme to see if
// doing it the other way might be better.
if ((num_iterations_executed%50) == 0)
{
buffer_subgradients_locally = !buffer_subgradients_locally;
}
binder b(*this, req, data);
binder b(*this, req, data, buffer_subgradients_locally);
for (long i = 0; i < data.num; i+=block_size)
{
tp.add_task(b, &binder::call_oracle, i, std::min(i + block_size, data.num));
}
tp.wait_for_all_tasks();
const uint64 stop_time = ts.get_timestamp();
if (buffer_subgradients_locally)
with_buffer_time.add(stop_time-start_time);
else
without_buffer_time.add(stop_time-start_time);
out.enqueue(temp);
}
}
......@@ -227,29 +255,39 @@ namespace dlib
binder (
const node_type& self_,
const impl::oracle_request<matrix_type>& req_,
impl::oracle_response<matrix_type>& data_
) : self(self_), req(req_), data(data_) {}
impl::oracle_response<matrix_type>& data_,
bool buffer_subgradients_locally_
) : self(self_), req(req_), data(data_),
buffer_subgradients_locally(buffer_subgradients_locally_) {}
void call_oracle (
long begin,
long end
)
{
// If we are only going to call the separation oracle once then
// don't run the slightly more complex for loop version of this code.
if (end-begin <= 1)
// If we are only going to call the separation oracle once then don't
// run the slightly more complex for loop version of this code. Or if
// we just don't want to run the complex buffering one. The code later
// on decides if we should do the buffering based on how long it takes
// to execute. We do this because, when the subgradient is really high
// dimensional it can take a lot of time to add them together. So we
// might want to avoid doing that.
if (end-begin <= 1 || !buffer_subgradients_locally)
{
scalar_type loss;
feature_vector_type ftemp;
self.cache[begin].separation_oracle_cached(req.skip_cache,
for (long i = begin; i < end; ++i)
{
self.cache[i].separation_oracle_cached(req.skip_cache,
req.cur_risk_lower_bound,
req.current_solution,
loss,
ftemp);
auto_mutex lock(self.accum_mutex);
data.loss += loss;
add_to(data.subgradient, ftemp);
auto_mutex lock(self.accum_mutex);
data.loss += loss;
add_to(data.subgradient, ftemp);
}
}
else
{
......@@ -280,6 +318,7 @@ namespace dlib
const node_type& self;
const impl::oracle_request<matrix_type>& req;
impl::oracle_response<matrix_type>& data;
bool buffer_subgradients_locally;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment