Commit 3b75b335 authored by Davis King

Gave dnn_trainer the ability to train on out of core data by adding the
train_one_step() member function.  Also improved how the host to device transfers
are overlapped with kernel computation.

parent adec3eef
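The transfer/compute overlap change itself lives in the (partly collapsed) diffs below.  As a general illustration of the technique, not dlib's actual code, overlapping host to device copies with kernel execution is typically done with pinned host memory, cudaMemcpyAsync(), and CUDA streams, roughly like this hypothetical sketch:

    // Sketch only: double-buffered upload so the copy for batch i+1 can run
    // while the kernel for batch i is still executing on the other stream.
    #include <cuda_runtime.h>

    __global__ void scale(float* d, size_t n, float c)
    {
        size_t i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) d[i] *= c;
    }

    int main()
    {
        const size_t n = 1<<20;

        float* host;
        cudaMallocHost((void**)&host, n*sizeof(float));   // pinned, so async copies can overlap
        for (size_t i = 0; i < n; ++i) host[i] = 1;

        float* dev[2];  cudaStream_t stream[2];
        for (int s = 0; s < 2; ++s)
        {
            cudaMalloc((void**)&dev[s], n*sizeof(float));
            cudaStreamCreate(&stream[s]);
        }

        for (int batch = 0; batch < 8; ++batch)
        {
            const int s = batch%2;   // alternate buffers and streams
            cudaMemcpyAsync(dev[s], host, n*sizeof(float), cudaMemcpyHostToDevice, stream[s]);
            scale<<<(n+255)/256, 256, 0, stream[s]>>>(dev[s], n, 2);
        }
        cudaDeviceSynchronize();

        for (int s = 0; s < 2; ++s)
        {
            cudaFree(dev[s]);
            cudaStreamDestroy(stream[s]);
        }
        cudaFreeHost(host);
        return 0;
    }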
@@ -1284,8 +1284,9 @@ namespace dlib
        // "no label".  So here we make the constructor private with the exception that
        // add_loss_layer objects can make it (again, just to simplify add_loss_layer's
        // implementation).
-        no_label_type()=default;
+        no_label_type(){};
        template <typename LOSS_DETAILS, typename SUBNET> friend class add_loss_layer;
+        template < typename net_type, typename solver_type > friend class dnn_trainer;
    };

// ----------------------------------------------------------------------------------------

...
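For readers unfamiliar with the pattern that comment describes, here is a tiny stand-alone sketch (secret_token and holder are made-up names for illustration, not dlib types): a private constructor plus friend declarations means only the named templates can create instances.  Writing the constructor with an empty body rather than = default also keeps the type from being an aggregate under pre-C++20 rules, so brace initialization cannot sneak past the access check either, which may be the reason for the constructor change above.

    template <typename T> class holder;   // forward declaration of the only permitted creator

    struct secret_token
    {
    private:
        // User-provided (not defaulted), so secret_token{} aggregate initialization
        // cannot bypass the private access (pre-C++20 rules).
        secret_token() {}
        template <typename T> friend class holder;
    };

    template <typename T>
    class holder
    {
    public:
        secret_token make() const { return secret_token(); }   // allowed: holder is a friend
    };

    int main()
    {
        holder<int> h;
        secret_token t = h.make();   // ok
        // secret_token t2;          // error: the constructor is private
        (void)t;
        return 0;
    }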
@@ -10,6 +10,23 @@ namespace dlib
    namespace cuda
    {

+    // -----------------------------------------------------------------------------------
+
+        void set_device (
+            int dev
+        )
+        {
+            CHECK_CUDA(cudaSetDevice(dev));
+        }
+
+        int get_device (
+        )
+        {
+            int dev = 0;
+            CHECK_CUDA(cudaGetDevice(&dev));
+            return dev;
+        }
+
    // -----------------------------------------------------------------------------------

        __global__ void _cuda_multiply(float* d, const float* s, size_t n)
...
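A quick sketch of how these new helpers might be called from user code.  The include path is an assumption about how the header is exposed in a given build; only dlib::cuda::set_device() and dlib::cuda::get_device() come from this commit.

    #include <dlib/dnn/cuda_dlib.h>   // assumed include path; adjust to how your build exposes it
    #include <iostream>

    int main()
    {
        // Select which GPU subsequent dlib CUDA work should run on.  In a build
        // without DLIB_USE_CUDA the inline stubs added below in cuda_dlib.h make
        // this a no-op that reports device 0.
        dlib::cuda::set_device(0);
        std::cout << "using CUDA device " << dlib::cuda::get_device() << std::endl;
        return 0;
    }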
@@ -3,7 +3,6 @@
#ifndef DLIB_DNN_CuDA_H_
#define DLIB_DNN_CuDA_H_

-#ifdef DLIB_USE_CUDA

#include "tensor.h"
@@ -12,6 +11,17 @@ namespace dlib
    namespace cuda
    {

+#ifdef DLIB_USE_CUDA
+
+    // ----------------------------------------------------------------------------------------
+
+        void set_device (
+            int dev
+        );
+
+        int get_device (
+        );
+
    // -----------------------------------------------------------------------------------

        void multiply (
@@ -120,11 +130,24 @@
        );

    // ------------------------------------------------------------------------------------

+    // ------------------------------------------------------------------------------------
+
+    // ------------------------------------------------------------------------------------
+
+    // ------------------------------------------------------------------------------------
+#else // if DLIB_USE_CUDA NOT DEFINED
+
+        inline void set_device (
+            int dev
+        ){}
+
+        inline int get_device (
+        ){ return 0; }
+#endif // DLIB_USE_CUDA

    }
}

-#endif // DLIB_USE_CUDA

#endif // DLIB_DNN_CuDA_H_
This diff is collapsed.
@@ -192,6 +192,8 @@ namespace dlib
and having it execute 10 epochs of training during each call.  This also
means you can serialize a trainer to disk and then, at a later date,
deserialize it and resume training your network where you left off.
- You can obtain the average loss value during the final training epoch by
  calling get_average_loss().
!*/

const net_type& train (
@@ -218,6 +220,67 @@
and having it execute 10 epochs of training during each call.  This also
means you can serialize a trainer to disk and then, at a later date,
deserialize it and resume training your network where you left off.
- You can obtain the average loss value during the final training epoch by
calling get_average_loss().
!*/
void train_one_step (
const std::vector<input_type>& data,
const std::vector<label_type>& labels
);
/*!
requires
- data.size() == labels.size()
- net_type uses a supervised loss.
i.e. net_type::label_type != no_label_type.
ensures
- Performs one stochastic gradient update step based on the mini-batch of
data and labels supplied to this function. In particular, calling
train_one_step() in a loop is equivalent to calling the train() method
defined above. However, train_one_step() allows you to stream data from
disk into the training process while train() requires you to first load
all the training data into RAM. Otherwise, these training methods are
equivalent.
- You can observe the current average loss value by calling get_average_loss().
!*/
void train_one_step (
const std::vector<input_type>& data
);
/*!
requires
- net_type uses an unsupervised loss.
i.e. net_type::label_type == no_label_type.
ensures
- Performs one stochastic gradient update step based on the mini-batch of
data supplied to this function. In particular, calling train_one_step()
in a loop is equivalent to calling the train() method defined above.
However, train_one_step() allows you to stream data from disk into the
training process while train() requires you to first load all the
training data into RAM. Otherwise, these training methods are
equivalent.
- You can observe the current average loss value by calling get_average_loss().
!*/
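As a rough usage sketch of the streaming workflow the two train_one_step() overloads above describe: everything here other than train_one_step() itself is a placeholder; the trainer type, the element types, and the mini-batch loader are supplied by the caller and are not dlib API.

    #include <functional>
    #include <vector>

    // Sketch: drive training one mini-batch at a time so only the current
    // mini-batch has to be resident in RAM, unlike train() which takes the
    // entire dataset up front.
    template <typename trainer_type, typename input_type, typename label_type>
    void stream_training(
        trainer_type& trainer,
        std::function<bool(std::vector<input_type>&, std::vector<label_type>&)> load_next_minibatch
    )
    {
        std::vector<input_type> data;
        std::vector<label_type> labels;
        // Each call performs one stochastic gradient update on just this
        // mini-batch; looping until the data source is exhausted is equivalent
        // to calling train() on the whole dataset.
        while (load_next_minibatch(data, labels))
            trainer.train_one_step(data, labels);
    }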
double get_average_loss (
) const;
/*!
ensures
- returns the average loss value observed during previous calls to
train_one_step() or train(). That is, the average output of
net_type::update() during the previous mini-batch updates.
!*/
void clear_average_loss (
);
/*!
ensures
- #get_average_loss() == 0
- get_average_loss() uses a dlib::running_stats object to keep a running
average of the loss values seen during the previous mini-batch updates
applied during training. Calling clear_average_loss() resets the
running_stats object so it forgets about all previous loss values
observed.
!*/

};
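And a companion sketch of how get_average_loss() and clear_average_loss() might be used to report progress during such a loop.  The batch containers and the reporting interval are arbitrary choices for illustration, not part of dlib.

    #include <iostream>
    #include <vector>

    // Sketch: print and reset the trainer's running average loss every 100 steps.
    template <typename trainer_type, typename input_type, typename label_type>
    void train_with_progress(
        trainer_type& trainer,
        const std::vector<std::vector<input_type>>& batches,
        const std::vector<std::vector<label_type>>& batch_labels
    )
    {
        for (size_t i = 0; i < batches.size(); ++i)
        {
            trainer.train_one_step(batches[i], batch_labels[i]);
            if ((i+1) % 100 == 0)
            {
                // Average of the loss values seen since the last reset.
                std::cout << "step " << i+1 << "  average loss: "
                          << trainer.get_average_loss() << std::endl;
                trainer.clear_average_loss();   // start a fresh running average
            }
        }
    }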
...