Commit 0ad2cb71 authored by Davis King's avatar Davis King

Gave the dnn_trainer a nice verbose mode

parent 758f606d
...@@ -6,6 +6,9 @@ ...@@ -6,6 +6,9 @@
#include "trainer_abstract.h" #include "trainer_abstract.h"
#include "core.h" #include "core.h"
#include "solvers.h" #include "solvers.h"
#include "../statistics.h"
#include "../console_progress_indicator.h"
#include <chrono>
namespace dlib namespace dlib
{ {
...@@ -84,6 +87,18 @@ namespace dlib ...@@ -84,6 +87,18 @@ namespace dlib
num_epochs = num; num_epochs = num;
} }
// Turn on verbose mode: the trainer will print periodic status
// messages (average loss and progress) to std::cout while training.
void be_verbose (
)
{
    verbose = true;
}
// Turn off verbose mode: suppress all status output during training.
void be_quiet (
)
{
    verbose = false;
}
const sstack<solver_type,net_type::num_layers>& get_solvers ( const sstack<solver_type,net_type::num_layers>& get_solvers (
) const { return solvers; } ) const { return solvers; }
...@@ -101,8 +116,12 @@ namespace dlib ...@@ -101,8 +116,12 @@ namespace dlib
resizable_tensor t1, t2; resizable_tensor t1, t2;
console_progress_indicator pbar(num_epochs);
pbar.print_status(0);
for (unsigned long epoch_iteration = 0; epoch_iteration < num_epochs; ++epoch_iteration) for (unsigned long epoch_iteration = 0; epoch_iteration < num_epochs; ++epoch_iteration)
{ {
running_stats<double> rs;
unsigned long j = 0; unsigned long j = 0;
// Load two tensors worth of data at once so we can overlap the computation // Load two tensors worth of data at once so we can overlap the computation
...@@ -121,9 +140,11 @@ namespace dlib ...@@ -121,9 +140,11 @@ namespace dlib
} }
unsigned long i = 0; unsigned long i = 0;
using namespace std::chrono;
auto last_time = system_clock::now();
while (i < data.size()) while (i < data.size())
{ {
net.update(t1, labels.begin()+i, solvers); rs.add(net.update(t1, labels.begin()+i, solvers));
i += mini_batch_size; i += mini_batch_size;
if (j < data.size()) if (j < data.size())
{ {
...@@ -134,7 +155,7 @@ namespace dlib ...@@ -134,7 +155,7 @@ namespace dlib
if (i < data.size()) if (i < data.size())
{ {
net.update(t2, labels.begin()+i, solvers); rs.add(net.update(t2, labels.begin()+i, solvers));
i += mini_batch_size; i += mini_batch_size;
if (j < data.size()) if (j < data.size())
{ {
...@@ -144,6 +165,29 @@ namespace dlib ...@@ -144,6 +165,29 @@ namespace dlib
} }
} }
if (verbose)
{
auto now_time = system_clock::now();
if (now_time-last_time > seconds(20))
{
last_time = now_time;
auto iter = epoch_iteration + i/(double)data.size();
std::cout << "epoch: " << rpad(cast_to_string(iter),string_pad) << " "
<< "average loss: " << rpad(cast_to_string(rs.mean()),string_pad) << " ";
pbar.print_status(iter, true);
std::cout << std::endl;
}
}
}
if (verbose)
{
// Capitalize the E in Epoch so it's easy to grep out the lines that
// are for full epoch status statements.
std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),string_pad) << " "
<< "average loss: " << rpad(cast_to_string(rs.mean()),string_pad) << " ";
pbar.print_status(epoch_iteration+1, true);
std::cout << std::endl;
} }
} }
return net; return net;
...@@ -161,8 +205,11 @@ namespace dlib ...@@ -161,8 +205,11 @@ namespace dlib
resizable_tensor t1, t2; resizable_tensor t1, t2;
console_progress_indicator pbar(num_epochs);
pbar.print_status(0);
for (unsigned long epoch_iteration = 0; epoch_iteration < num_epochs; ++epoch_iteration) for (unsigned long epoch_iteration = 0; epoch_iteration < num_epochs; ++epoch_iteration)
{ {
running_stats<double> rs;
unsigned long j = 0; unsigned long j = 0;
// Load two tensors worth of data at once so we can overlap the computation // Load two tensors worth of data at once so we can overlap the computation
...@@ -181,9 +228,11 @@ namespace dlib ...@@ -181,9 +228,11 @@ namespace dlib
} }
unsigned long i = 0; unsigned long i = 0;
using namespace std::chrono;
auto last_time = system_clock::now();
while (i < data.size()) while (i < data.size())
{ {
net.update(t1, solvers); rs.add(net.update(t1, solvers));
i += mini_batch_size; i += mini_batch_size;
if (j < data.size()) if (j < data.size())
{ {
...@@ -194,7 +243,7 @@ namespace dlib ...@@ -194,7 +243,7 @@ namespace dlib
if (i < data.size()) if (i < data.size())
{ {
net.update(t2, solvers); rs.add(net.update(t2, solvers));
i += mini_batch_size; i += mini_batch_size;
if (j < data.size()) if (j < data.size())
{ {
...@@ -204,6 +253,29 @@ namespace dlib ...@@ -204,6 +253,29 @@ namespace dlib
} }
} }
if (verbose)
{
auto now_time = system_clock::now();
if (now_time-last_time > seconds(20))
{
last_time = now_time;
auto iter = epoch_iteration + i/(double)data.size();
std::cout << "epoch: " << rpad(cast_to_string(iter),string_pad) << " "
<< "average loss: " << rpad(cast_to_string(rs.mean()),string_pad) << " ";
pbar.print_status(iter, true);
std::cout << std::endl;
}
}
}
if (verbose)
{
// Capitalize the E in Epoch so it's easy to grep out the lines that
// are for full epoch status statements.
std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),string_pad) << " "
<< "average loss: " << rpad(cast_to_string(rs.mean()),string_pad) << " ";
pbar.print_status(epoch_iteration+1, true);
std::cout << std::endl;
} }
} }
return net; return net;
...@@ -215,10 +287,13 @@ namespace dlib ...@@ -215,10 +287,13 @@ namespace dlib
{ {
num_epochs = 300; num_epochs = 300;
mini_batch_size = 11; mini_batch_size = 11;
verbose = false;
} }
unsigned long num_epochs; unsigned long num_epochs;
unsigned long mini_batch_size; unsigned long mini_batch_size;
bool verbose;
const static long string_pad = 10;
net_type net; net_type net;
sstack<solver_type,net_type::num_layers> solvers; sstack<solver_type,net_type::num_layers> solvers;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment