Commit 0b235fe5 authored by Davis King's avatar Davis King

Added the repeat layer and generally optimized the code for really deep

networks.  This revolved mostly around removing really deep template recursions
since that upsets the compiler when you make really deep networks.
parent 7991275e
This diff is collapsed.
This diff is collapsed.
......@@ -48,7 +48,7 @@ namespace dlib
dnn_trainer(
const net_type& net_,
const solver_type& solver_
) : job_pipe(0), net(net_), solvers(solver_)
) : job_pipe(0), net(net_), solvers(net_type::num_layers, solver_)
{
init();
}
......@@ -81,7 +81,7 @@ namespace dlib
)
{
wait_for_thread_to_pause();
solvers = solver_;
solvers = std::vector<solver_type>(net_type::num_layers, solver_);
}
unsigned long get_mini_batch_size (
......@@ -119,14 +119,14 @@ namespace dlib
}
const sstack<solver_type,net_type::num_layers>& get_solvers (
const std::vector<solver_type>& get_solvers (
) const
{
wait_for_thread_to_pause();
return solvers;
}
sstack<solver_type,net_type::num_layers>& get_solvers (
std::vector<solver_type>& get_solvers (
)
{
wait_for_thread_to_pause();
......@@ -260,7 +260,7 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out)
{
item.wait_for_thread_to_pause();
int version = 1;
int version = 2;
serialize(version, out);
serialize(item.rs, out);
serialize(item.num_epochs, out);
......@@ -275,7 +275,7 @@ namespace dlib
item.wait_for_thread_to_pause();
int version = 0;
deserialize(version, in);
if (version != 1)
if (version != 2)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
deserialize(item.rs, in);
deserialize(item.num_epochs, in);
......@@ -309,13 +309,13 @@ namespace dlib
template <typename T>
void run_update(job_t& next_job, const T&)
{
rs.add(net.update(next_job.t, next_job.labels.begin(), solvers));
rs.add(net.update(next_job.t, next_job.labels.begin(), make_sstack(solvers)));
}
void run_update(job_t& next_job, const no_label_type&)
{
no_label_type pick_wich_run_update;
rs.add(net.update(next_job.t, solvers));
rs.add(net.update(next_job.t, make_sstack(solvers)));
}
void thread()
......@@ -361,7 +361,7 @@ namespace dlib
int cuda_device_id;
net_type net;
sstack<solver_type,net_type::num_layers> solvers;
std::vector<solver_type> solvers;
};
// ----------------------------------------------------------------------------------------
......
......@@ -93,24 +93,30 @@ namespace dlib
assigned to each element in get_solvers().
!*/
const sstack<solver_type,net_type::num_layers>& get_solvers (
const std::vector<solver_type>& get_solvers (
) const;
/*!
ensures
- returns the solvers used to optimize each layer of the neural network
get_net(). In particular, the first layer's solver is
get_solvers().top(), the second layer's solver is
get_solvers().pop().top(), and so on.
get_solvers()[0], the second layer's solver is
get_solvers()[1], and so on.
!*/
sstack<solver_type,net_type::num_layers>& get_solvers (
std::vector<solver_type>& get_solvers (
);
/*!
ensures
- returns the solvers used to optimize each layer of the neural network
get_net(). In particular, the first layer's solver is
get_solvers().top(), the second layer's solver is
get_solvers().pop().top(), and so on.
get_solvers()[0], the second layer's solver is
get_solvers()[1], and so on.
- It should be noted that you should never change the number of elements in
the vector returned by get_solvers() (i.e. don't do something that
changes get_solvers().size()). It will be set to net_type::num_layers by
this object and you should leave it at that. The non-const version of
get_solvers() is provided only so you can tweak the parameters of a
particular solver.
!*/
unsigned long get_mini_batch_size (
......
......@@ -974,8 +974,8 @@ namespace
rcon_(6)
);
DLIB_TEST(layer<tag1>(net).num_layers == 9);
DLIB_TEST(layer<skip1>(net).num_layers == 9+3+3+1);
DLIB_TEST(layer<tag1>(net).num_layers == 8);
DLIB_TEST(layer<skip1>(net).num_layers == 8+3+3);
DLIB_TEST(&layer<skip1>(net).get_output() == &layer<tag1>(net).get_output());
DLIB_TEST(&layer<skip1>(net).get_output() != &layer<tag1>(net).subnet().subnet().get_output());
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment