Commit f9cb3150 authored by Davis King

upgraded to cuDNN v5. Also changed the affine_ layer to not be templated; it now
automatically selects the right mode. The serialization format for bn_ layers
has also changed, but the code will still be able to deserialize older bn_
objects.
parent c81825c3
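In practical terms, the affine_ change moves the mode from a template parameter to a constructor argument (and the mode is picked up automatically when an affine_ is built from a bn_). A minimal before/after usage sketch, based on the layer-test and example updates further down in this diff (the function name is only for illustration):

```cpp
#include <dlib/dnn.h>
using namespace dlib;

void affine_usage_sketch()
{
    // Before this commit the mode was baked into the type:
    //   affine_<CONV_MODE> a;                        // no longer compiles
    //   ... affine_con<con<8,3,3,1,1,SUBNET>> ...    // old network aliases

    // After this commit the mode is a runtime value:
    affine_ a1(CONV_MODE);
    affine_ a2(FC_MODE);
    affine_ a3;            // default constructed, uses FC_MODE

    // and a single network alias replaces affine_con/affine_fc, e.g.
    //   template <typename SUBNET>
    //   using block = relu<affine<con<8,3,3,1,1,SUBNET>>>;
}
```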
......@@ -445,7 +445,7 @@ if (NOT TARGET dlib)
if (DLIB_USE_CUDA)
find_package(CUDA 7.0)
find_package(CUDA 7.5)
if (CUDA_FOUND AND COMPILER_CAN_DO_CPP_11)
......@@ -505,7 +505,7 @@ if (NOT TARGET dlib)
set(DLIB_USE_CUDA OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE )
toggle_preprocessor_switch(DLIB_USE_CUDA)
if (NOT cudnn OR NOT cudnn_include OR NOT cudnn_test_compile_worked)
message(STATUS "*** cuDNN V4.0 OR GREATER NOT FOUND. DLIB WILL NOT USE CUDA. ***")
message(STATUS "*** cuDNN V5.0 OR GREATER NOT FOUND. DLIB WILL NOT USE CUDA. ***")
message(STATUS "*** If you have cuDNN then set CMAKE_PREFIX_PATH to include cuDNN's folder.")
endif()
if (NOT COMPILER_CAN_DO_CPP_11)
......
......@@ -6,6 +6,7 @@
// This file contains CPU implementations of the GPU based functions in cuda_dlib.h
#include "cpu_dlib.h"
#include "tensor_tools.h"
namespace dlib
{
......@@ -510,7 +511,7 @@ namespace dlib
{
for (long k = 0; k < num; ++k)
{
*d = g[k]*(*s - m[k])*i[k] + b[k];
*d = g[k]*(*s - m[k])/std::sqrt(i[k]+dlib::tt::BATCH_NORM_EPS) + b[k];
++d;
++s;
}
......@@ -579,10 +580,18 @@ namespace dlib
invstds.host(); means.host();
// compute variances
running_invstds.copy_size(invstds);
auto rvar = running_invstds.host();
const double scale = (src.num_samples())/(src.num_samples()-1.0);
for (long i = 0; i < num; ++i)
{
auto actual_var = p_invstds[i] - p_means[i]*p_means[i];
p_invstds[i] = 1.0f/std::sqrt(actual_var+BATCH_NORM_EPS);
if (averaging_factor == 1)
rvar[i] = scale*actual_var;
else
rvar[i] = (1-averaging_factor)*rvar[i] + scale*averaging_factor*actual_var;
p_invstds[i] = 1.0f/std::sqrt(actual_var + dlib::tt::BATCH_NORM_EPS);
}
p_src = src.host();
......@@ -600,19 +609,12 @@ namespace dlib
}
}
// now keep track of the running means and invstds
// now keep track of the running means
running_means.copy_size(means);
running_invstds.copy_size(invstds);
if (averaging_factor != 1)
{
running_means = (1-averaging_factor)*mat(running_means) + averaging_factor*mat(means);
running_invstds = (1-averaging_factor)*mat(running_invstds) + averaging_factor*mat(invstds);
}
else
{
running_means = means;
running_invstds = invstds;
}
}
void batch_normalize_gradient (
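For reference, the two hunks above change the CPU batch_normalize() path to track running variances (with an unbiased correction) instead of running inverse standard deviations, while the running means keep the same exponential moving average. A small sketch of the per-channel arithmetic, not dlib code; in the convolutional version N is src.num_samples()*src.nr()*src.nc() rather than src.num_samples():

```cpp
// Illustrative helper matching the running-variance update performed above.
inline double updated_running_variance(
    double running_var,       // previous running estimate
    double batch_var,         // biased batch variance: E[x^2] - E[x]^2
    double averaging_factor,  // 1 on the first batch, smaller afterwards
    long   N                  // number of values averaged per channel
)
{
    const double scale    = N/(N-1.0);        // unbiased correction factor
    const double unbiased = scale*batch_var;
    if (averaging_factor == 1)
        return unbiased;                      // first batch: reset the estimate
    return (1-averaging_factor)*running_var + averaging_factor*unbiased;
}
```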
......@@ -761,9 +763,10 @@ namespace dlib
{
for (long k = 0; k < src.k(); ++k)
{
const float invstd = 1.0f/std::sqrt(i[k] + dlib::tt::BATCH_NORM_EPS);
for (long j = 0; j < num; ++j)
{
*d = g[k]*(*s - m[k])*i[k] + b[k];
*d = g[k]*(*s - m[k])*invstd + b[k];
++d;
++s;
}
......@@ -841,10 +844,18 @@ namespace dlib
p_src = src.host();
// compute variances
running_invstds.copy_size(invstds);
auto rvar = running_invstds.host();
const double scale = (src.num_samples()*num)/(src.num_samples()*num-1.0);
for (long k = 0; k < src.k(); ++k)
{
float actual_var = p_invstds[k] - p_means[k]*p_means[k];
p_invstds[k] = 1.0f/std::sqrt(actual_var + BATCH_NORM_EPS);
if (averaging_factor == 1)
rvar[k] = scale*actual_var;
else
rvar[k] = (1-averaging_factor)*rvar[k] + scale*averaging_factor*actual_var;
p_invstds[k] = 1.0f/std::sqrt(actual_var + dlib::tt::BATCH_NORM_EPS);
}
p_src = src.host();
......@@ -863,19 +874,12 @@ namespace dlib
}
}
// now keep track of the running means and invstds
// now keep track of the running means
running_means.copy_size(means);
running_invstds.copy_size(invstds);
if (averaging_factor != 1)
{
running_means = (1-averaging_factor)*mat(running_means) + averaging_factor*mat(means);
running_invstds = (1-averaging_factor)*mat(running_invstds) + averaging_factor*mat(invstds);
}
else
{
running_means = means;
running_invstds = invstds;
}
}
void batch_normalize_conv_gradient(
......
......@@ -13,10 +13,6 @@ namespace dlib
namespace cpu
{
// ----------------------------------------------------------------------------------------
const double BATCH_NORM_EPS = 0.00001;
// -----------------------------------------------------------------------------------
void multiply (
......
......@@ -112,6 +112,57 @@ namespace dlib
return c.get_handle();
}
// ------------------------------------------------------------------------------------
class cudnn_activation_descriptor
{
public:
// not copyable
cudnn_activation_descriptor(const cudnn_activation_descriptor&) = delete;
cudnn_activation_descriptor& operator=(const cudnn_activation_descriptor&) = delete;
cudnn_activation_descriptor(
cudnnActivationMode_t mode,
cudnnNanPropagation_t reluNanOpt,
double reluCeiling
)
{
CHECK_CUDNN(cudnnCreateActivationDescriptor(&handle));
CHECK_CUDNN(cudnnSetActivationDescriptor(handle, mode, reluNanOpt, reluCeiling));
}
~cudnn_activation_descriptor()
{
cudnnDestroyActivationDescriptor(handle);
}
cudnnActivationDescriptor_t get_handle (
)
{
return handle;
}
private:
cudnnActivationDescriptor_t handle;
};
static cudnnActivationDescriptor_t relu_activation_descriptor()
{
thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN,0);
return des.get_handle();
}
static cudnnActivationDescriptor_t sigmoid_activation_descriptor()
{
thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN,0);
return des.get_handle();
}
static cudnnActivationDescriptor_t tanh_activation_descriptor()
{
thread_local cudnn_activation_descriptor des(CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN,0);
return des.get_handle();
}
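cuDNN v5 replaced the plain cudnnActivationMode_t argument of cudnnActivationForward/Backward with a cudnnActivationDescriptor_t, which is what the RAII wrapper and the thread_local accessors above provide. A minimal standalone sketch of the v5 pattern they encapsulate (illustrative only, error checking omitted; the real dlib calls appear in the hunks further down):

```cpp
#include <cudnn.h>

// Run a ReLU forward pass with the cuDNN v5 API: the activation mode is
// bound into a descriptor instead of being passed directly to
// cudnnActivationForward().
void relu_forward_v5_sketch(
    cudnnHandle_t ctx,
    cudnnTensorDescriptor_t desc,   // same shape assumed for src and dest
    const float* src,
    float* dest
)
{
    const float alpha = 1, beta = 0;
    cudnnActivationDescriptor_t act;
    cudnnCreateActivationDescriptor(&act);
    cudnnSetActivationDescriptor(act, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0);
    cudnnActivationForward(ctx, act, &alpha, desc, src, &beta, desc, dest);
    cudnnDestroyActivationDescriptor(act);
}
```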
// ------------------------------------------------------------------------------------
tensor_descriptor::
......@@ -223,7 +274,7 @@ namespace dlib
return;
}
CHECK_CUDNN(cudnnAddTensor_v3(context(),
CHECK_CUDNN(cudnnAddTensor(context(),
&alpha,
descriptor(src),
src.device(),
......@@ -342,7 +393,7 @@ namespace dlib
beta.device(),
running_means.device(),
running_invstds.device(),
dlib::cpu::BATCH_NORM_EPS));
dlib::tt::BATCH_NORM_EPS));
}
void batch_normalize (
......@@ -404,7 +455,7 @@ namespace dlib
averaging_factor,
running_means.device(),
running_invstds.device(),
dlib::cpu::BATCH_NORM_EPS,
dlib::tt::BATCH_NORM_EPS,
means.device(),
invstds.device()));
}
......@@ -452,7 +503,7 @@ namespace dlib
gamma.device(),
gamma_grad.device(),
beta_grad.device(),
dlib::cpu::BATCH_NORM_EPS,
dlib::tt::BATCH_NORM_EPS,
means.device(),
invstds.device()));
}
......@@ -515,7 +566,7 @@ namespace dlib
beta.device(),
running_means.device(),
running_invstds.device(),
dlib::cpu::BATCH_NORM_EPS));
dlib::tt::BATCH_NORM_EPS));
}
void batch_normalize_conv (
......@@ -578,7 +629,7 @@ namespace dlib
averaging_factor,
running_means.device(),
running_invstds.device(),
dlib::cpu::BATCH_NORM_EPS,
dlib::tt::BATCH_NORM_EPS,
means.device(),
invstds.device()));
}
......@@ -625,7 +676,7 @@ namespace dlib
gamma.device(),
gamma_grad.device(),
beta_grad.device(),
dlib::cpu::BATCH_NORM_EPS,
dlib::tt::BATCH_NORM_EPS,
means.device(),
invstds.device()));
}
......@@ -739,6 +790,7 @@ namespace dlib
CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
filters.num_samples(),
filters.k(),
filters.nr(),
......@@ -900,7 +952,7 @@ namespace dlib
const float beta = 1;
CHECK_CUDNN(cudnnConvolutionBackwardData_v3(context(),
CHECK_CUDNN(cudnnConvolutionBackwardData(context(),
&alpha,
(const cudnnFilterDescriptor_t)filter_handle,
filters.device(),
......@@ -924,7 +976,7 @@ namespace dlib
{
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN(cudnnConvolutionBackwardFilter_v3(context(),
CHECK_CUDNN(cudnnConvolutionBackwardFilter(context(),
&alpha,
descriptor(data),
data.device(),
......@@ -1020,6 +1072,7 @@ namespace dlib
CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc,
(cudnnPoolingMode_t)pooling_mode,
CUDNN_PROPAGATE_NAN,
window_height,
window_width,
window_height/2,
......@@ -1176,7 +1229,7 @@ namespace dlib
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN(cudnnActivationForward(context(),
CUDNN_ACTIVATION_SIGMOID,
sigmoid_activation_descriptor(),
&alpha,
descriptor(src),
src.device(),
......@@ -1200,7 +1253,7 @@ namespace dlib
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN(cudnnActivationBackward(context(),
CUDNN_ACTIVATION_SIGMOID,
sigmoid_activation_descriptor(),
&alpha,
descriptor(dest),
dest.device(),
......@@ -1227,7 +1280,7 @@ namespace dlib
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN(cudnnActivationForward(context(),
CUDNN_ACTIVATION_RELU,
relu_activation_descriptor(),
&alpha,
descriptor(src),
src.device(),
......@@ -1251,7 +1304,7 @@ namespace dlib
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN(cudnnActivationBackward(context(),
CUDNN_ACTIVATION_RELU,
relu_activation_descriptor(),
&alpha,
descriptor(dest),
dest.device(),
......@@ -1278,7 +1331,7 @@ namespace dlib
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN(cudnnActivationForward(context(),
CUDNN_ACTIVATION_TANH,
tanh_activation_descriptor(),
&alpha,
descriptor(src),
src.device(),
......@@ -1302,7 +1355,7 @@ namespace dlib
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN(cudnnActivationBackward(context(),
CUDNN_ACTIVATION_TANH,
tanh_activation_descriptor(),
&alpha,
descriptor(dest),
dest.device(),
......
......@@ -133,7 +133,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "con_")
throw serialization_error("Unexpected version found while deserializing dlib::con_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
deserialize(item.params, in);
......@@ -258,7 +258,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "max_pool_")
throw serialization_error("Unexpected version found while deserializing dlib::max_pool_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_.");
item.mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x);
......@@ -374,7 +374,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "avg_pool_")
throw serialization_error("Unexpected version found while deserializing dlib::avg_pool_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_.");
item.ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x);
......@@ -500,7 +500,10 @@ namespace dlib
friend void serialize(const bn_& item, std::ostream& out)
{
serialize("bn_", out);
if (mode == CONV_MODE)
serialize("bn_con", out);
else // if FC_MODE
serialize("bn_fc", out);
serialize(item.params, out);
serialize(item.gamma, out);
serialize(item.beta, out);
......@@ -510,7 +513,6 @@ namespace dlib
serialize(item.running_invstds, out);
serialize(item.num_updates, out);
serialize(item.running_stats_window_size, out);
serialize((int)mode, out);
}
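For reference, the resulting change to the bn_ stream layout, summarizing this hunk and the deserialization hunk below (fields elided in the diff are shown as ...):

```
old:  "bn_"                params, gamma, beta, ..., running_invstds (inverse std-devs),
                           num_updates, running_stats_window_size, (int)mode
new:  "bn_con" / "bn_fc"   params, gamma, beta, ..., running_invstds (now variances),
                           num_updates, running_stats_window_size
```

Old-format streams are still accepted: the "bn_" version string triggers the compatibility path in deserialize() below.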
friend void deserialize(bn_& item, std::istream& in)
......@@ -518,7 +520,19 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "bn_")
throw serialization_error("Unexpected version found while deserializing dlib::bn_.");
{
if (mode == CONV_MODE)
{
if (version != "bn_con")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
}
else // must be in FC_MODE
{
if (version != "bn_fc")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
}
}
deserialize(item.params, in);
deserialize(item.gamma, in);
deserialize(item.beta, in);
......@@ -528,14 +542,23 @@ namespace dlib
deserialize(item.running_invstds, in);
deserialize(item.num_updates, in);
deserialize(item.running_stats_window_size, in);
// if this is the older "bn_" version then check its saved mode value and make
// sure it is the one we are really using.
if (version == "bn_")
{
int _mode;
deserialize(_mode, in);
if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::bn_");
// We also need to flip the running_invstds around since the previous
// format saved the inverse standard deviations instead of variances.
item.running_invstds = 1.0f/squared(mat(item.running_invstds)) - tt::BATCH_NORM_EPS;
}
}
private:
template < layer_mode Mode >
friend class affine_;
resizable_tensor params;
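The conversion applied to old-format objects above follows directly from what the two formats store; a worked restatement of the squared()/BATCH_NORM_EPS line, with i the old inverse standard deviation, v the variance now kept in running_invstds, and eps = tt::BATCH_NORM_EPS:

```latex
i = \frac{1}{\sqrt{v + \epsilon}}
\quad\Longrightarrow\quad
v = \frac{1}{i^{2}} - \epsilon
```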
......@@ -660,7 +683,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "fc_")
throw serialization_error("Unexpected version found while deserializing dlib::fc_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");
deserialize(item.num_outputs, in);
deserialize(item.num_inputs, in);
......@@ -760,7 +783,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "dropout_")
throw serialization_error("Unexpected version found while deserializing dlib::dropout_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::dropout_.");
deserialize(item.drop_rate, in);
deserialize(item.mask, in);
}
......@@ -840,7 +863,7 @@ namespace dlib
}
if (version != "multiply_")
throw serialization_error("Unexpected version found while deserializing dlib::multiply_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::multiply_.");
deserialize(item.val, in);
}
......@@ -854,22 +877,30 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
layer_mode mode
>
class affine_
{
public:
affine_(
)
{}
) : mode(FC_MODE)
{
}
affine_(
const bn_<mode>& item
layer_mode mode_
) : mode(mode_)
{
}
template <
layer_mode bnmode
>
affine_(
const bn_<bnmode>& item
)
{
gamma = item.gamma;
beta = item.beta;
mode = bnmode;
params.copy_size(item.params);
......@@ -880,7 +911,7 @@ namespace dlib
auto sg = gamma(temp,0);
auto sb = beta(temp,gamma.size());
g = pointwise_multiply(mat(sg), mat(item.running_invstds));
g = pointwise_multiply(mat(sg), 1.0f/sqrt(mat(item.running_invstds)+tt::BATCH_NORM_EPS));
b = mat(sb) - pointwise_multiply(mat(g), mat(item.running_means));
}
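For context, the g and b computed above are the usual folding of inference-time batch normalization into a single affine transform; restated with mu = running_means, sigma^2 = the running variances stored in running_invstds, and eps = tt::BATCH_NORM_EPS:

```latex
y = \gamma\,\frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta
  = g\,x + b,
\qquad
g = \frac{\gamma}{\sqrt{\sigma^{2} + \epsilon}},
\qquad
b = \beta - g\,\mu
```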
......@@ -954,36 +985,45 @@ namespace dlib
{
std::string version;
deserialize(version, in);
if (version == "bn_")
if (version == "bn_con")
{
// Since we can build an affine_ from a bn_ we check if that's what is in
// the stream and if so then just convert it right here.
unserialize sin(version, in);
bn_<CONV_MODE> temp;
deserialize(temp, sin);
item = temp;
return;
}
else if (version == "bn_fc")
{
// Since we can build an affine_ from a bn_ we check if that's what is in
// the stream and if so then just convert it right here.
unserialize sin(version, in);
bn_<mode> temp;
bn_<FC_MODE> temp;
deserialize(temp, sin);
item = temp;
return;
}
if (version != "affine_")
throw serialization_error("Unexpected version found while deserializing dlib::affine_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::affine_.");
deserialize(item.params, in);
deserialize(item.gamma, in);
deserialize(item.beta, in);
int _mode;
deserialize(_mode, in);
if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::affine_");
int mode;
deserialize(mode, in);
item.mode = (layer_mode)mode;
}
private:
resizable_tensor params, empty_params;
alias_tensor gamma, beta;
layer_mode mode;
};
template <typename SUBNET>
using affine_con = add_layer<affine_<CONV_MODE>, SUBNET>;
template <typename SUBNET>
using affine_fc = add_layer<affine_<FC_MODE>, SUBNET>;
using affine = add_layer<affine_, SUBNET>;
// ----------------------------------------------------------------------------------------
......@@ -1031,7 +1071,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "add_prev_")
throw serialization_error("Unexpected version found while deserializing dlib::add_prev_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::add_prev_.");
}
private:
......@@ -1108,7 +1148,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "relu_")
throw serialization_error("Unexpected version found while deserializing dlib::relu_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_.");
}
private:
......@@ -1176,7 +1216,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "prelu_")
throw serialization_error("Unexpected version found while deserializing dlib::prelu_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::prelu_.");
deserialize(item.params, in);
deserialize(item.initial_param_value, in);
}
......@@ -1231,7 +1271,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "sig_")
throw serialization_error("Unexpected version found while deserializing dlib::sig_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::sig_.");
}
private:
......@@ -1284,7 +1324,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "htan_")
throw serialization_error("Unexpected version found while deserializing dlib::htan_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::htan_.");
}
private:
......@@ -1337,7 +1377,7 @@ namespace dlib
std::string version;
deserialize(version, in);
if (version != "softmax_")
throw serialization_error("Unexpected version found while deserializing dlib::softmax_.");
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_.");
}
private:
......
......@@ -736,9 +736,6 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
layer_mode mode
>
class affine_
{
/*!
......@@ -777,11 +774,22 @@ namespace dlib
affine_(
);
/*!
ensures
- #get_mode() == FC_MODE
!*/
affine_(
layer_mode mode
);
/*!
ensures
- #get_mode() == mode
!*/
template <
layer_mode mode
>
affine_(
const bn_<mode>& layer
);
......@@ -812,17 +820,16 @@ namespace dlib
are no learnable parameters in this object.
!*/
friend void serialize(const affine_& item, std::ostream& out);
friend void deserialize(affine_& item, std::istream& in);
};
void serialize(const affine_& item, std::ostream& out);
void deserialize(affine_& item, std::istream& in);
/*!
provides serialization support
!*/
};
template <typename SUBNET>
using affine_con = add_layer<affine_<CONV_MODE>, SUBNET>;
template <typename SUBNET>
using affine_fc = add_layer<affine_<FC_MODE>, SUBNET>;
using affine = add_layer<affine_, SUBNET>;
// ----------------------------------------------------------------------------------------
......
......@@ -286,6 +286,8 @@ namespace dlib { namespace tt
// ----------------------------------------------------------------------------------------
const double BATCH_NORM_EPS = 0.00001;
void batch_normalize_inference (
resizable_tensor& dest,
const tensor& src,
......
......@@ -2,7 +2,7 @@
cmake_minimum_required(VERSION 2.8.4)
project(cuda_test)
find_package(CUDA 7.0 REQUIRED)
find_package(CUDA 7.5 REQUIRED)
set(CUDA_HOST_COMPILATION_CPP ON)
list(APPEND CUDA_NVCC_FLAGS "-arch=sm_30;-std=c++11;-D__STRICT_ANSI__")
......
......@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8.4)
project(cudnn_test)
include(../../use_cpp_11.cmake)
find_package(CUDA 7.0 REQUIRED)
find_package(CUDA 7.5 REQUIRED)
set(CUDA_HOST_COMPILATION_CPP ON)
list(APPEND CUDA_NVCC_FLAGS "-arch=sm_30;-std=c++11;-D__STRICT_ANSI__")
add_definitions(-DDLIB_USE_CUDA)
......
......@@ -166,8 +166,11 @@ namespace
resizable_tensor running_means;
resizable_tensor running_invstds;
batch_normalize(dest, means, vars, 1, running_means, running_invstds, src, gamma, beta);
const double scale = (src.num_samples())/(src.num_samples()-1.0);
// Turn back into biased variance estimate because that's how batch_normalize() works, so if we want to match it this is necessary.
running_invstds = mat(running_invstds)/scale;
batch_normalize_inference(dest2, src, gamma, beta, running_means, running_invstds);
DLIB_TEST(max(abs(mat(dest2)-mat(dest))) < 1e-5);
DLIB_TEST_MSG(max(abs(mat(dest2)-mat(dest))) < 1e-5, max(abs(mat(dest2)-mat(dest))));
auto grad_src = [&](long idx) {
......@@ -246,6 +249,10 @@ namespace
resizable_tensor running_means;
resizable_tensor running_invstds;
batch_normalize_conv(dest, means, vars, 1, running_means, running_invstds, src, gamma, beta);
const double scale = (src.num_samples()*src.nr()*src.nc())/(src.num_samples()*src.nr()*src.nc()-1.0);
// Turn back into biased variance estimate because that's how
// batch_normalize_conv() works, so if we want to match it this is necessary.
running_invstds = mat(running_invstds)/scale;
batch_normalize_conv_inference(dest2, src, gamma, beta, running_means, running_invstds);
DLIB_TEST(max(abs(mat(dest2)-mat(dest))) < 1e-5);
......@@ -1086,12 +1093,12 @@ namespace
}
{
print_spinner();
affine_<CONV_MODE> l;
affine_ l(CONV_MODE);
DLIB_TEST_MSG(test_layer(l), test_layer(l));
}
{
print_spinner();
affine_<FC_MODE> l;
affine_ l(FC_MODE);
DLIB_TEST_MSG(test_layer(l), test_layer(l));
}
{
......
......@@ -13,7 +13,7 @@ template <int stride, typename SUBNET>
using base_res = relu<add_prev1< bn_con<con<8,3,3,1,1,relu< bn_con<con<8,3,3,stride,stride,tag1<SUBNET>>>>>>>>;
template <int stride, typename SUBNET>
using base_ares = relu<add_prev1<affine_con<con<8,3,3,1,1,relu<affine_con<con<8,3,3,stride,stride,tag1<SUBNET>>>>>>>>;
using base_ares = relu<add_prev1<affine<con<8,3,3,1,1,relu<affine<con<8,3,3,stride,stride,tag1<SUBNET>>>>>>>>;
template <typename SUBNET> using res = base_res<1,SUBNET>;
template <typename SUBNET> using res_down = base_res<2,SUBNET>;
......