Commit f9cb3150 authored by Davis King

Upgraded to cuDNN v5. Also changed the affine_ layer to no longer be templated;

it now automatically selects the right mode. The serialization format for bn_
layers has also changed, but the code will still be able to deserialize older
bn_ objects.
parent c81825c3
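
Note: with this change affine_ takes its mode at runtime instead of as a template argument. A minimal usage sketch (assumes dlib with this commit applied):

    #include <dlib/dnn.h>
    using namespace dlib;

    int main()
    {
        affine_ a;             // defaults to FC_MODE, like the old affine_<FC_MODE>
        affine_ b(CONV_MODE);  // replaces the old affine_<CONV_MODE>
        return 0;
    }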
@@ -445,7 +445,7 @@ if (NOT TARGET dlib)
    if (DLIB_USE_CUDA)
-      find_package(CUDA 7.0)
+      find_package(CUDA 7.5)
       if (CUDA_FOUND AND COMPILER_CAN_DO_CPP_11)
@@ -505,7 +505,7 @@ if (NOT TARGET dlib)
          set(DLIB_USE_CUDA OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE )
          toggle_preprocessor_switch(DLIB_USE_CUDA)
          if (NOT cudnn OR NOT cudnn_include OR NOT cudnn_test_compile_worked)
-            message(STATUS "*** cuDNN V4.0 OR GREATER NOT FOUND. DLIB WILL NOT USE CUDA. ***")
+            message(STATUS "*** cuDNN V5.0 OR GREATER NOT FOUND. DLIB WILL NOT USE CUDA. ***")
             message(STATUS "*** If you have cuDNN then set CMAKE_PREFIX_PATH to include cuDNN's folder.")
          endif()
          if (NOT COMPILER_CAN_DO_CPP_11)
@@ -6,6 +6,7 @@
 // This file contains CPU implementations of the GPU based functions in cuda_dlib.h

 #include "cpu_dlib.h"
+#include "tensor_tools.h"

 namespace dlib
 {
@@ -510,7 +511,7 @@ namespace dlib
         {
             for (long k = 0; k < num; ++k)
             {
-                *d = g[k]*(*s - m[k])*i[k] + b[k];
+                *d = g[k]*(*s - m[k])/std::sqrt(i[k]+dlib::tt::BATCH_NORM_EPS) + b[k];
                 ++d;
                 ++s;
             }
@@ -579,10 +580,18 @@ namespace dlib
             invstds.host(); means.host();

             // compute variances
+            running_invstds.copy_size(invstds);
+            auto rvar = running_invstds.host();
+            const double scale = (src.num_samples())/(src.num_samples()-1.0);
             for (long i = 0; i < num; ++i)
             {
                 auto actual_var = p_invstds[i] - p_means[i]*p_means[i];
-                p_invstds[i] = 1.0f/std::sqrt(actual_var+BATCH_NORM_EPS);
+                if (averaging_factor == 1)
+                    rvar[i] = scale*actual_var;
+                else
+                    rvar[i] = (1-averaging_factor)*rvar[i] + scale*averaging_factor*actual_var;
+                p_invstds[i] = 1.0f/std::sqrt(actual_var + dlib::tt::BATCH_NORM_EPS);
             }

             p_src = src.host();
@@ -600,19 +609,12 @@ namespace dlib
                 }
             }

-            // now keep track of the running means and invstds
+            // now keep track of the running means
             running_means.copy_size(means);
-            running_invstds.copy_size(invstds);
             if (averaging_factor != 1)
-            {
                 running_means = (1-averaging_factor)*mat(running_means) + averaging_factor*mat(means);
-                running_invstds = (1-averaging_factor)*mat(running_invstds) + averaging_factor*mat(invstds);
-            }
             else
-            {
                 running_means = means;
-                running_invstds = invstds;
-            }
         }

         void batch_normalize_gradient (
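
Note: the hunks above change what the running statistics store: running_invstds now holds a running (unbiased) variance rather than an inverse standard deviation. A standalone sketch of the update rule, with illustrative names that are not dlib's:

    #include <cassert>
    #include <cmath>

    // One update of the running variance, as in the loop above.  The n/(n-1)
    // factor (Bessel's correction) turns the biased batch variance into an
    // unbiased estimate before it enters the moving average.
    double update_running_var(double running_var, double batch_var_biased,
                              double averaging_factor, long n)
    {
        const double scale = n/(n - 1.0);
        if (averaging_factor == 1)
            return scale*batch_var_biased;              // first batch: take it as-is
        return (1 - averaging_factor)*running_var
             + averaging_factor*scale*batch_var_biased; // exponential moving average
    }

    int main()
    {
        double rv = update_running_var(0, 2.0, 1, 4);   // first batch: 2*4/3
        rv = update_running_var(rv, 1.5, 0.1, 4);
        assert(std::fabs(rv - (0.9*(2.0*4/3.0) + 0.1*(1.5*4/3.0))) < 1e-12);
        return 0;
    }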
@@ -761,9 +763,10 @@ namespace dlib
         {
             for (long k = 0; k < src.k(); ++k)
             {
+                const float invstd = 1.0f/std::sqrt(i[k] + dlib::tt::BATCH_NORM_EPS);
                 for (long j = 0; j < num; ++j)
                 {
-                    *d = g[k]*(*s - m[k])*i[k] + b[k];
+                    *d = g[k]*(*s - m[k])*invstd + b[k];
                     ++d;
                     ++s;
                 }
@@ -841,10 +844,18 @@ namespace dlib
             p_src = src.host();

             // compute variances
+            running_invstds.copy_size(invstds);
+            auto rvar = running_invstds.host();
+            const double scale = (src.num_samples()*num)/(src.num_samples()*num-1.0);
             for (long k = 0; k < src.k(); ++k)
             {
                 float actual_var = p_invstds[k] - p_means[k]*p_means[k];
-                p_invstds[k] = 1.0f/std::sqrt(actual_var + BATCH_NORM_EPS);
+                if (averaging_factor == 1)
+                    rvar[k] = scale*actual_var;
+                else
+                    rvar[k] = (1-averaging_factor)*rvar[k] + scale*averaging_factor*actual_var;
+                p_invstds[k] = 1.0f/std::sqrt(actual_var + dlib::tt::BATCH_NORM_EPS);
             }

             p_src = src.host();
@@ -863,19 +874,12 @@ namespace dlib
                 }
             }

-            // now keep track of the running means and invstds
+            // now keep track of the running means
             running_means.copy_size(means);
-            running_invstds.copy_size(invstds);
             if (averaging_factor != 1)
-            {
                 running_means = (1-averaging_factor)*mat(running_means) + averaging_factor*mat(means);
-                running_invstds = (1-averaging_factor)*mat(running_invstds) + averaging_factor*mat(invstds);
-            }
             else
-            {
                 running_means = means;
-                running_invstds = invstds;
-            }
         }

         void batch_normalize_conv_gradient(
@@ -13,10 +13,6 @@ namespace dlib
     namespace cpu
     {

-    // ----------------------------------------------------------------------------------------
-
-        const double BATCH_NORM_EPS = 0.00001;
-
     // -----------------------------------------------------------------------------------

         void multiply (
@@ -112,6 +112,57 @@ namespace dlib
             return c.get_handle();
         }

+    // ------------------------------------------------------------------------------------
+
+        class cudnn_activation_descriptor
+        {
+        public:
+            // not copyable
+            cudnn_activation_descriptor(const cudnn_activation_descriptor&) = delete;
+            cudnn_activation_descriptor& operator=(const cudnn_activation_descriptor&) = delete;
+
+            cudnn_activation_descriptor(
+                cudnnActivationMode_t mode,
+                cudnnNanPropagation_t reluNanOpt,
+                double reluCeiling
+            )
+            {
+                CHECK_CUDNN(cudnnCreateActivationDescriptor(&handle));
+                CHECK_CUDNN(cudnnSetActivationDescriptor(handle, mode, reluNanOpt, reluCeiling));
+            }
+
+            ~cudnn_activation_descriptor()
+            {
+                cudnnDestroyActivationDescriptor(handle);
+            }
+
+            cudnnActivationDescriptor_t get_handle (
+            )
+            {
+                return handle;
+            }
+        private:
+            cudnnActivationDescriptor_t handle;
+        };
+
+        static cudnnActivationDescriptor_t relu_activation_descriptor()
+        {
+            thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN,0);
+            return des.get_handle();
+        }
+
+        static cudnnActivationDescriptor_t sigmoid_activation_descriptor()
+        {
+            thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN,0);
+            return des.get_handle();
+        }
+
+        static cudnnActivationDescriptor_t tanh_activation_descriptor()
+        {
+            thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN,0);
+            return des.get_handle();
+        }
+
     // ------------------------------------------------------------------------------------

         tensor_descriptor::
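
Note: the descriptor cache above combines RAII with thread_local, so each thread lazily creates one descriptor on first use and frees it at thread exit, with no locking. A self-contained sketch of the same pattern, using a fake handle API as a stand-in for cuDNN (the create/destroy functions here are illustrative assumptions):

    #include <cstdio>

    struct fake_desc { int mode; };                    // stands in for cudnnActivationDescriptor_t
    void create_desc(fake_desc** d, int mode) { *d = new fake_desc{mode}; }
    void destroy_desc(fake_desc* d) { delete d; }

    class desc_wrapper
    {
    public:
        desc_wrapper(const desc_wrapper&) = delete;    // not copyable, as above
        desc_wrapper& operator=(const desc_wrapper&) = delete;
        explicit desc_wrapper(int mode) { create_desc(&handle, mode); }
        ~desc_wrapper() { destroy_desc(handle); }
        fake_desc* get_handle() { return handle; }
    private:
        fake_desc* handle;
    };

    static fake_desc* relu_desc()
    {
        // created once per thread on first call, destroyed at thread exit
        thread_local desc_wrapper des(0);
        return des.get_handle();
    }

    int main()
    {
        std::printf("mode=%d\n", relu_desc()->mode);
        return 0;
    }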
@@ -223,7 +274,7 @@ namespace dlib
                 return;
             }

-            CHECK_CUDNN(cudnnAddTensor_v3(context(),
+            CHECK_CUDNN(cudnnAddTensor(context(),
                     &alpha,
                     descriptor(src),
                     src.device(),
@@ -342,7 +393,7 @@ namespace dlib
                 beta.device(),
                 running_means.device(),
                 running_invstds.device(),
-                dlib::cpu::BATCH_NORM_EPS));
+                dlib::tt::BATCH_NORM_EPS));
         }

         void batch_normalize (
@@ -404,7 +455,7 @@ namespace dlib
                 averaging_factor,
                 running_means.device(),
                 running_invstds.device(),
-                dlib::cpu::BATCH_NORM_EPS,
+                dlib::tt::BATCH_NORM_EPS,
                 means.device(),
                 invstds.device()));
         }
@@ -452,7 +503,7 @@ namespace dlib
                 gamma.device(),
                 gamma_grad.device(),
                 beta_grad.device(),
-                dlib::cpu::BATCH_NORM_EPS,
+                dlib::tt::BATCH_NORM_EPS,
                 means.device(),
                 invstds.device()));
         }
@@ -515,7 +566,7 @@ namespace dlib
                 beta.device(),
                 running_means.device(),
                 running_invstds.device(),
-                dlib::cpu::BATCH_NORM_EPS));
+                dlib::tt::BATCH_NORM_EPS));
         }

         void batch_normalize_conv (
@@ -578,7 +629,7 @@ namespace dlib
                 averaging_factor,
                 running_means.device(),
                 running_invstds.device(),
-                dlib::cpu::BATCH_NORM_EPS,
+                dlib::tt::BATCH_NORM_EPS,
                 means.device(),
                 invstds.device()));
         }
@@ -625,7 +676,7 @@ namespace dlib
                 gamma.device(),
                 gamma_grad.device(),
                 beta_grad.device(),
-                dlib::cpu::BATCH_NORM_EPS,
+                dlib::tt::BATCH_NORM_EPS,
                 means.device(),
                 invstds.device()));
         }
@@ -739,6 +790,7 @@ namespace dlib
             CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
             CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
                     CUDNN_DATA_FLOAT,
+                    CUDNN_TENSOR_NCHW,
                     filters.num_samples(),
                     filters.k(),
                     filters.nr(),
@@ -900,7 +952,7 @@ namespace dlib
             const float beta = 1;

-            CHECK_CUDNN(cudnnConvolutionBackwardData_v3(context(),
+            CHECK_CUDNN(cudnnConvolutionBackwardData(context(),
                     &alpha,
                     (const cudnnFilterDescriptor_t)filter_handle,
                     filters.device(),
@@ -924,7 +976,7 @@ namespace dlib
         {
             const float alpha = 1;
             const float beta = 0;
-            CHECK_CUDNN(cudnnConvolutionBackwardFilter_v3(context(),
+            CHECK_CUDNN(cudnnConvolutionBackwardFilter(context(),
                     &alpha,
                     descriptor(data),
                     data.device(),
@@ -1020,6 +1072,7 @@ namespace dlib
             CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc,
                     (cudnnPoolingMode_t)pooling_mode,
+                    CUDNN_PROPAGATE_NAN,
                     window_height,
                     window_width,
                     window_height/2,
@@ -1176,7 +1229,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
             CHECK_CUDNN(cudnnActivationForward(context(),
-                    CUDNN_ACTIVATION_SIGMOID,
+                    sigmoid_activation_descriptor(),
                     &alpha,
                     descriptor(src),
                     src.device(),
@@ -1200,7 +1253,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
             CHECK_CUDNN(cudnnActivationBackward(context(),
-                    CUDNN_ACTIVATION_SIGMOID,
+                    sigmoid_activation_descriptor(),
                     &alpha,
                     descriptor(dest),
                     dest.device(),
@@ -1227,7 +1280,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
             CHECK_CUDNN(cudnnActivationForward(context(),
-                    CUDNN_ACTIVATION_RELU,
+                    relu_activation_descriptor(),
                     &alpha,
                     descriptor(src),
                     src.device(),
@@ -1251,7 +1304,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
             CHECK_CUDNN(cudnnActivationBackward(context(),
-                    CUDNN_ACTIVATION_RELU,
+                    relu_activation_descriptor(),
                     &alpha,
                     descriptor(dest),
                     dest.device(),
@@ -1278,7 +1331,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
             CHECK_CUDNN(cudnnActivationForward(context(),
-                    CUDNN_ACTIVATION_TANH,
+                    tanh_activation_descriptor(),
                     &alpha,
                     descriptor(src),
                     src.device(),
@@ -1302,7 +1355,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
             CHECK_CUDNN(cudnnActivationBackward(context(),
-                    CUDNN_ACTIVATION_TANH,
+                    tanh_activation_descriptor(),
                     &alpha,
                     descriptor(dest),
                     dest.device(),
@@ -133,7 +133,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "con_")
-                throw serialization_error("Unexpected version found while deserializing dlib::con_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");

             deserialize(item.params, in);
@@ -258,7 +258,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "max_pool_")
-                throw serialization_error("Unexpected version found while deserializing dlib::max_pool_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_.");

             item.mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x);
@@ -374,7 +374,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "avg_pool_")
-                throw serialization_error("Unexpected version found while deserializing dlib::avg_pool_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_.");

             item.ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x);
@@ -500,7 +500,10 @@ namespace dlib
         friend void serialize(const bn_& item, std::ostream& out)
         {
-            serialize("bn_", out);
+            if (mode == CONV_MODE)
+                serialize("bn_con", out);
+            else // if FC_MODE
+                serialize("bn_fc", out);
             serialize(item.params, out);
             serialize(item.gamma, out);
             serialize(item.beta, out);
@@ -510,7 +513,6 @@ namespace dlib
             serialize(item.running_invstds, out);
             serialize(item.num_updates, out);
             serialize(item.running_stats_window_size, out);
-            serialize((int)mode, out);
         }

         friend void deserialize(bn_& item, std::istream& in)
@@ -518,7 +520,19 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "bn_")
-                throw serialization_error("Unexpected version found while deserializing dlib::bn_.");
+            {
+                if (mode == CONV_MODE)
+                {
+                    if (version != "bn_con")
+                        throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
+                }
+                else // must be in FC_MODE
+                {
+                    if (version != "bn_fc")
+                        throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
+                }
+            }
             deserialize(item.params, in);
             deserialize(item.gamma, in);
             deserialize(item.beta, in);
@@ -528,14 +542,23 @@ namespace dlib
             deserialize(item.running_invstds, in);
             deserialize(item.num_updates, in);
             deserialize(item.running_stats_window_size, in);
-            int _mode;
-            deserialize(_mode, in);
-            if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::bn_");
+
+            // if this is the older "bn_" version then check its saved mode value and make
+            // sure it is the one we are really using.
+            if (version == "bn_")
+            {
+                int _mode;
+                deserialize(_mode, in);
+                if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::bn_");
+
+                // We also need to flip the running_invstds around since the previous
+                // format saved the inverse standard deviations instead of variances.
+                item.running_invstds = 1.0f/squared(mat(item.running_invstds)) - tt::BATCH_NORM_EPS;
+            }
         }

     private:

+        template < layer_mode Mode >
         friend class affine_;

         resizable_tensor params;
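
Note: the old "bn_" format stored inverse standard deviations, invstd = 1/sqrt(var + eps), while the new format stores variances, so the conversion above recovers var = 1/invstd^2 - eps. A quick standalone check of that algebra:

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float eps = 0.00001f;                          // BATCH_NORM_EPS
        const float var = 0.25f;                             // some variance
        const float invstd = 1.0f/std::sqrt(var + eps);      // what the old format saved
        const float recovered = 1.0f/(invstd*invstd) - eps;  // the conversion above
        assert(std::fabs(recovered - var) < 1e-6f);
        return 0;
    }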
@@ -660,7 +683,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "fc_")
-                throw serialization_error("Unexpected version found while deserializing dlib::fc_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");

             deserialize(item.num_outputs, in);
             deserialize(item.num_inputs, in);
@@ -760,7 +783,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "dropout_")
-                throw serialization_error("Unexpected version found while deserializing dlib::dropout_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::dropout_.");
             deserialize(item.drop_rate, in);
             deserialize(item.mask, in);
         }
@@ -840,7 +863,7 @@ namespace dlib
             }
             if (version != "multiply_")
-                throw serialization_error("Unexpected version found while deserializing dlib::multiply_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::multiply_.");
             deserialize(item.val, in);
         }
@@ -854,22 +877,30 @@ namespace dlib

    // ----------------------------------------------------------------------------------------

-    template <
-        layer_mode mode
-        >
     class affine_
     {
     public:
         affine_(
-        )
-        {}
+        ) : mode(FC_MODE)
+        {
+        }

         affine_(
-            const bn_<mode>& item
+            layer_mode mode_
+        ) : mode(mode_)
+        {
+        }
+
+        template <
+            layer_mode bnmode
+            >
+        affine_(
+            const bn_<bnmode>& item
         )
         {
             gamma = item.gamma;
             beta = item.beta;
+            mode = bnmode;

             params.copy_size(item.params);
@@ -880,7 +911,7 @@ namespace dlib
             auto sg = gamma(temp,0);
             auto sb = beta(temp,gamma.size());

-            g = pointwise_multiply(mat(sg), mat(item.running_invstds));
+            g = pointwise_multiply(mat(sg), 1.0f/sqrt(mat(item.running_invstds)+tt::BATCH_NORM_EPS));
             b = mat(sb) - pointwise_multiply(mat(g), mat(item.running_means));
         }
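
Note: this constructor folds a trained bn_ into a fixed affine transform. At inference batch norm computes gamma*(x - mean)/sqrt(var + eps) + beta, which equals g*x + b with g = gamma/sqrt(var + eps) and b = beta - g*mean; those are exactly the two lines above. A scalar sanity check:

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float eps = 0.00001f;  // BATCH_NORM_EPS
        const float gamma = 2.0f, beta = 0.5f, mean = 1.0f, var = 4.0f, x = 3.0f;

        const float bn = gamma*(x - mean)/std::sqrt(var + eps) + beta;  // batch norm at inference

        const float g = gamma/std::sqrt(var + eps);  // folded affine form, as above
        const float b = beta - g*mean;

        assert(std::fabs((g*x + b) - bn) < 1e-6f);
        return 0;
    }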
@@ -954,36 +985,45 @@ namespace dlib
         {
             std::string version;
             deserialize(version, in);
-            if (version == "bn_")
+            if (version == "bn_con")
+            {
+                // Since we can build an affine_ from a bn_ we check if that's what is in
+                // the stream and if so then just convert it right here.
+                unserialize sin(version, in);
+                bn_<CONV_MODE> temp;
+                deserialize(temp, sin);
+                item = temp;
+                return;
+            }
+            else if (version == "bn_fc")
             {
                 // Since we can build an affine_ from a bn_ we check if that's what is in
                 // the stream and if so then just convert it right here.
                 unserialize sin(version, in);
-                bn_<mode> temp;
+                bn_<FC_MODE> temp;
                 deserialize(temp, sin);
                 item = temp;
                 return;
             }

             if (version != "affine_")
-                throw serialization_error("Unexpected version found while deserializing dlib::affine_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::affine_.");

             deserialize(item.params, in);
             deserialize(item.gamma, in);
             deserialize(item.beta, in);
-            int _mode;
-            deserialize(_mode, in);
-            if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::affine_");
+            int mode;
+            deserialize(mode, in);
+            item.mode = (layer_mode)mode;
         }

     private:
         resizable_tensor params, empty_params;
         alias_tensor gamma, beta;
+        layer_mode mode;
     };

     template <typename SUBNET>
-    using affine_con = add_layer<affine_<CONV_MODE>, SUBNET>;
-    template <typename SUBNET>
-    using affine_fc = add_layer<affine_<FC_MODE>, SUBNET>;
+    using affine = add_layer<affine_, SUBNET>;
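
Note: the unserialize helper used above effectively puts the already-consumed version tag back in front of the remaining stream so that bn_'s own deserialize() can re-read the object from the start, after which it is converted to an affine_. A rough standalone sketch of the stream-pushback idea (this reassembly function is an illustration, not dlib's implementation, and it ignores that dlib length-prefixes serialized strings):

    #include <iostream>
    #include <sstream>
    #include <string>

    // Rebuild a stream that begins with the bytes we already consumed.
    std::istringstream put_back(const std::string& consumed, std::istream& rest)
    {
        std::ostringstream buf;
        buf << consumed << rest.rdbuf();
        return std::istringstream(buf.str());
    }

    int main()
    {
        std::istringstream in("<rest of serialized bn_>");
        std::istringstream whole = put_back("bn_con", in);
        std::cout << whole.str() << "\n";   // "bn_con<rest of serialized bn_>"
        return 0;
    }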
    // ----------------------------------------------------------------------------------------
@@ -1031,7 +1071,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "add_prev_")
-                throw serialization_error("Unexpected version found while deserializing dlib::add_prev_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::add_prev_.");
         }

     private:
@@ -1108,7 +1148,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "relu_")
-                throw serialization_error("Unexpected version found while deserializing dlib::relu_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_.");
         }

     private:
@@ -1176,7 +1216,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "prelu_")
-                throw serialization_error("Unexpected version found while deserializing dlib::prelu_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::prelu_.");
             deserialize(item.params, in);
             deserialize(item.initial_param_value, in);
         }
@@ -1231,7 +1271,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "sig_")
-                throw serialization_error("Unexpected version found while deserializing dlib::sig_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::sig_.");
         }

     private:
@@ -1284,7 +1324,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "htan_")
-                throw serialization_error("Unexpected version found while deserializing dlib::htan_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::htan_.");
         }

     private:
@@ -1337,7 +1377,7 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "softmax_")
-                throw serialization_error("Unexpected version found while deserializing dlib::softmax_.");
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_.");
         }

     private:
@@ -736,9 +736,6 @@ namespace dlib

    // ----------------------------------------------------------------------------------------

-    template <
-        layer_mode mode
-        >
     class affine_
     {
         /*!
@@ -777,11 +774,22 @@ namespace dlib
         affine_(
         );
+        /*!
+            ensures
+                - #get_mode() == FC_MODE
+        !*/
+
+        affine_(
+            layer_mode mode
+        );
         /*!
             ensures
                 - #get_mode() == mode
         !*/

+        template <
+            layer_mode mode
+            >
         affine_(
             const bn_<mode>& layer
         );
@@ -812,17 +820,16 @@ namespace dlib
             are no learnable parameters in this object.
         !*/

-        friend void serialize(const affine_& item, std::ostream& out);
-        friend void deserialize(affine_& item, std::istream& in);
-        /*!
-            provides serialization support
-        !*/
     };

+    void serialize(const affine_& item, std::ostream& out);
+    void deserialize(affine_& item, std::istream& in);
+    /*!
+        provides serialization support
+    !*/
+
     template <typename SUBNET>
-    using affine_con = add_layer<affine_<CONV_MODE>, SUBNET>;
-    template <typename SUBNET>
-    using affine_fc = add_layer<affine_<FC_MODE>, SUBNET>;
+    using affine = add_layer<affine_, SUBNET>;

    // ----------------------------------------------------------------------------------------
@@ -286,6 +286,8 @@ namespace dlib { namespace tt

    // ----------------------------------------------------------------------------------------

+    const double BATCH_NORM_EPS = 0.00001;
+
     void batch_normalize_inference (
         resizable_tensor& dest,
         const tensor& src,
@@ -2,7 +2,7 @@
 cmake_minimum_required(VERSION 2.8.4)

 project(cuda_test)

-find_package(CUDA 7.0 REQUIRED)
+find_package(CUDA 7.5 REQUIRED)
 set(CUDA_HOST_COMPILATION_CPP ON)
 list(APPEND CUDA_NVCC_FLAGS "-arch=sm_30;-std=c++11;-D__STRICT_ANSI__")
@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8.4)
 project(cudnn_test)
 include(../../use_cpp_11.cmake)

-find_package(CUDA 7.0 REQUIRED)
+find_package(CUDA 7.5 REQUIRED)
 set(CUDA_HOST_COMPILATION_CPP ON)
 list(APPEND CUDA_NVCC_FLAGS "-arch=sm_30;-std=c++11;-D__STRICT_ANSI__")
 add_definitions(-DDLIB_USE_CUDA)
@@ -166,8 +166,11 @@ namespace
         resizable_tensor running_means;
         resizable_tensor running_invstds;
         batch_normalize(dest, means, vars, 1, running_means, running_invstds, src, gamma, beta);
+        const double scale = (src.num_samples())/(src.num_samples()-1.0);
+        // Turn back into biased variance estimate because that's how batch_normalize() works, so if we want to match it this is necessary.
+        running_invstds = mat(running_invstds)/scale;
         batch_normalize_inference(dest2, src, gamma, beta, running_means, running_invstds);
-        DLIB_TEST(max(abs(mat(dest2)-mat(dest))) < 1e-5);
+        DLIB_TEST_MSG(max(abs(mat(dest2)-mat(dest))) < 1e-5, max(abs(mat(dest2)-mat(dest))));

         auto grad_src = [&](long idx) {
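
Note: batch_normalize() normalizes with the biased (divide-by-n) batch variance, while the running estimate now stores the unbiased (divide-by-(n-1)) variance, so the test divides by scale = n/(n-1) before comparing the two code paths. With n = 4: a biased variance of 2.0 is stored as 2.0*4/3 ≈ 2.667, and 2.667/(4/3) = 2.0 again:

    #include <cassert>
    #include <cmath>

    int main()
    {
        const double n = 4;                      // samples in the batch
        const double biased_var = 2.0;           // (1/n)*sum((x - mean)^2)
        const double scale = n/(n - 1.0);        // Bessel's correction, 4/3 here
        const double stored = scale*biased_var;  // what running_invstds now holds
        assert(std::fabs(stored/scale - biased_var) < 1e-12);  // what the test undoes
        return 0;
    }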
@@ -246,6 +249,10 @@ namespace
         resizable_tensor running_means;
         resizable_tensor running_invstds;
         batch_normalize_conv(dest, means, vars, 1, running_means, running_invstds, src, gamma, beta);
+        const double scale = (src.num_samples()*src.nr()*src.nc())/(src.num_samples()*src.nr()*src.nc()-1.0);
+        // Turn back into biased variance estimate because that's how
+        // batch_normalize_conv() works, so if we want to match it this is necessary.
+        running_invstds = mat(running_invstds)/scale;
         batch_normalize_conv_inference(dest2, src, gamma, beta, running_means, running_invstds);
         DLIB_TEST(max(abs(mat(dest2)-mat(dest))) < 1e-5);
@@ -1086,12 +1093,12 @@ namespace
         }
         {
             print_spinner();
-            affine_<CONV_MODE> l;
+            affine_ l(CONV_MODE);
             DLIB_TEST_MSG(test_layer(l), test_layer(l));
         }
         {
             print_spinner();
-            affine_<FC_MODE> l;
+            affine_ l(FC_MODE);
             DLIB_TEST_MSG(test_layer(l), test_layer(l));
         }
         {
@@ -13,7 +13,7 @@ template <int stride, typename SUBNET>
 using base_res = relu<add_prev1< bn_con<con<8,3,3,1,1,relu< bn_con<con<8,3,3,stride,stride,tag1<SUBNET>>>>>>>>;

 template <int stride, typename SUBNET>
-using base_ares = relu<add_prev1<affine_con<con<8,3,3,1,1,relu<affine_con<con<8,3,3,stride,stride,tag1<SUBNET>>>>>>>>;
+using base_ares = relu<add_prev1<affine<con<8,3,3,1,1,relu<affine<con<8,3,3,stride,stride,tag1<SUBNET>>>>>>>>;

 template <typename SUBNET> using res = base_res<1,SUBNET>;
 template <typename SUBNET> using res_down = base_res<2,SUBNET>;