Commit fe70bd12 authored by Davis King

Fixed spelling error in method name. Also optimized and cleaned up the automatic step size reduction code a little.
parent 9485da18
@@ -328,12 +328,14 @@ namespace dlib
             rs.clear();
         }
 
-        void set_setep_size (
+        void set_step_size (
             double ss
         )
         {
             DLIB_CASSERT(ss > 0,"");
             wait_for_thread_to_pause();
+            if (step_size != ss)
+                previous_loss_values.clear();
             step_size = ss;
         }
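The new guard is worth a note: set_step_size() now discards the accumulated loss history whenever the rate actually changes, so the step-size-reduction test starts fresh rather than mixing losses measured at two different rates. A hedged usage sketch (net_type, net, and the training arrays are hypothetical placeholders; dnn_trainer, set_step_size(), and train() are the real entry points):

    // Sketch only: net_type and the data are stand-ins, not part of this commit.
    dlib::dnn_trainer<net_type> trainer(net);
    trainer.set_step_size(0.1);  // differs from the current rate, so the trainer
                                 // also clears previous_loss_values internally
    trainer.train(training_images, training_labels);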
@@ -391,24 +393,33 @@ namespace dlib
             resizable_tensor t;
         };
 
-        template <typename T>
-        void run_update(job_t& next_job, const T&)
+        void record_loss(double loss)
         {
-            double loss = net.update(next_job.t, next_job.labels.begin(), make_sstack(solvers),step_size);
+            // Say that we will check if the gradient is bad 200 times during each
+            // iter_between_step_size_adjust interval of network updates. This kind of
+            // budgeting causes our gradient checking to use a fixed amount of
+            // computational resources, regardless of the size of
+            // iter_between_step_size_adjust.
+            gradient_check_budget += 200;
+
             rs.add(loss);
             previous_loss_values.push_back(loss);
             if (previous_loss_values.size() > iter_between_step_size_adjust)
                 previous_loss_values.pop_front();
         }
 
+        template <typename T>
+        void run_update(job_t& next_job, const T&)
+        {
+            double loss = net.update(next_job.t, next_job.labels.begin(), make_sstack(solvers),step_size);
+            record_loss(loss);
+        }
+
         void run_update(job_t& next_job, const no_label_type&)
         {
             no_label_type pick_wich_run_update;
             double loss = net.update(next_job.t, make_sstack(solvers), step_size);
-            rs.add(loss);
-            previous_loss_values.push_back(loss);
-            if (previous_loss_values.size() > iter_between_step_size_adjust)
-                previous_loss_values.pop_front();
+            record_loss(loss);
         }
 
         void thread() try
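The budgeting in record_loss() is simple amortized accounting: every mini-batch deposits a fixed credit of 200, and the O(n) loss-trend test only runs once the credit exceeds its cost n, the length of the loss window. A minimal standalone sketch of that accounting, with illustrative names rather than dlib's API:

    #include <cstddef>
    #include <deque>

    // Illustrative stand-in mirroring record_loss() and the check in thread().
    struct budgeted_checker
    {
        std::deque<double> losses;   // plays the role of previous_loss_values
        std::size_t window = 2000;   // plays the role of iter_between_step_size_adjust
        std::size_t budget = 0;      // plays the role of gradient_check_budget

        // Returns true when the expensive O(losses.size()) trend test should run.
        bool add(double loss)
        {
            budget += 200;           // fixed deposit per network update
            losses.push_back(loss);
            if (losses.size() > window)
                losses.pop_front();
            if (losses.size() >= window && budget > losses.size())
            {
                budget = 0;          // spend the whole budget on one check
                return true;
            }
            return false;
        }
    };

Each check costs about window operations and resets the budget, so checks fire roughly every window/200 updates. The total checking work therefore stays a fixed fraction of training time no matter how large the window is, which is exactly what the 200-per-interval comment describes.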
@@ -425,9 +436,14 @@ namespace dlib
                 run_update(next_job, pick_wich_run_update);
 
                 // If we have been running for a while then check if the loss is still
-                // dropping. If it isn't then we will reduce the step size.
-                if (previous_loss_values.size() >= iter_between_step_size_adjust)
+                // dropping. If it isn't then we will reduce the step size. Note that we
+                // have a "budget" that prevents us from calling
+                // probability_gradient_greater_than() every iteration. We do this because
+                // it can be expensive to compute when previous_loss_values is large.
+                if (previous_loss_values.size() >= iter_between_step_size_adjust &&
+                    gradient_check_budget > previous_loss_values.size())
                 {
+                    gradient_check_budget = 0;
                     if (probability_gradient_greater_than(previous_loss_values, 0) > 0.49)
                     {
                         step_size = step_size_shrink*step_size;
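probability_gradient_greater_than(previous_loss_values, 0) asks how likely it is that the underlying trend of the recent losses has a positive slope; a result above 0.49 means the loss is no longer clearly decreasing, so the step size gets shrunk. As a hedged illustration of what such a test can look like (a batch least-squares version; dlib's own implementation is based on its running gradient estimator and may differ in detail):

    #include <cmath>
    #include <cstddef>
    #include <deque>

    // Illustrative stand-in: fit y = a + b*x by least squares over the loss
    // window, treat the slope estimate b as Gaussian, and return P(slope > thresh).
    // Assumes y.size() >= 3.
    double prob_slope_greater_than(const std::deque<double>& y, double thresh)
    {
        const double n = y.size();
        double sx = 0, sy = 0, sxx = 0, sxy = 0;
        for (std::size_t i = 0; i < y.size(); ++i)
        {
            sx += i;  sy += y[i];  sxx += double(i)*i;  sxy += i*y[i];
        }
        const double b = (n*sxy - sx*sy)/(n*sxx - sx*sx);  // fitted slope
        const double a = (sy - b*sx)/n;                    // fitted intercept
        double rss = 0;                                    // residual sum of squares
        for (std::size_t i = 0; i < y.size(); ++i)
        {
            const double r = y[i] - (a + b*i);
            rss += r*r;
        }
        const double var_b = rss/(n-2)/(sxx - sx*sx/n);    // variance of the slope
        const double z = (b - thresh)/std::sqrt(var_b);
        return 0.5*std::erfc(-z/std::sqrt(2.0));           // Gaussian CDF at z
    }

While training makes progress the fitted slope is clearly negative and the returned probability is near 0; once the loss plateaus the probability drifts toward 0.5, which is why the threshold sits just below it at 0.49.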
@@ -458,12 +474,13 @@ namespace dlib
             verbose = false;
             cuda_device_id = dlib::cuda::get_device();
             step_size = 1;
-            min_step_size = 1e-4;
+            min_step_size = 1e-3;
             iter_between_step_size_adjust = 2000;
             step_size_shrink = 0.1;
             epoch_iteration = 0;
             epoch_pos = 0;
             train_one_step_calls = 0;
+            gradient_check_budget = 0;
 
             start();
         }
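Raising the default min_step_size from 1e-4 to 1e-3 shortens the default shrink schedule by one stage. With step_size == 1 and step_size_shrink == 0.1, the number of tenfold reductions from the initial rate down to the stopping threshold is log(min_step_size/step_size)/log(step_size_shrink), assuming training halts once the step size reaches the floor; a quick check of that arithmetic:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const double initial = 1, shrink = 0.1;
        for (double min_step : {1e-4, 1e-3})
        {
            // Tenfold reductions from the initial step size down to the floor.
            const long k = std::lround(std::log(min_step/initial)/std::log(shrink));
            std::printf("min_step_size = %g -> %ld reductions\n", min_step, k);
        }
        // Prints 4 reductions for 1e-4 (1 -> 0.1 -> 0.01 -> 1e-3 -> 1e-4)
        // and 3 for the new default of 1e-3.
    }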
@@ -575,7 +592,7 @@ namespace dlib
         std::vector<solver_type> solvers;
         std::atomic<double> step_size;
         double min_step_size;
-        std::atomic<long> iter_between_step_size_adjust;
+        std::atomic<unsigned long> iter_between_step_size_adjust;
         std::atomic<double> step_size_shrink;
         std::chrono::time_point<std::chrono::system_clock> last_sync_time;
         std::string sync_filename;
@@ -584,6 +601,7 @@ namespace dlib
         unsigned long epoch_pos;
         std::chrono::time_point<std::chrono::system_clock> last_time;
         unsigned long long train_one_step_calls;
+        unsigned long gradient_check_budget;
 
         // The job object is not logically part of the state of this object. It is here
         // only to avoid reallocating it over and over.
......
@@ -60,7 +60,7 @@ namespace dlib
                 - #get_max_num_epochs() == 10000
                 - #get_mini_batch_size() == 128
                 - #get_step_size() == 1
-                - #get_min_step_size() == 1e-4
+                - #get_min_step_size() == 1e-3
                 - #get_iterations_between_step_size_adjust() == 2000
                 - #get_step_size_shrink() == 0.1
         !*/
@@ -149,7 +149,7 @@ namespace dlib
                 - #get_max_num_epochs() == num
         !*/
 
-        void set_setep_size (
+        void set_step_size (
            double ss
        );
        /*!
......