Commit 8e6d8ae0 authored by Davis King's avatar Davis King

Changed conv layer to use cross-correlation rather than convolution.

parent 595f0128
......@@ -1631,9 +1631,11 @@ namespace dlib
// now fill in the Toeplitz output matrix for the n-th sample in data.
size_t cnt = 0;
for (long r = filter_nr-1-padding_y; r-padding_y < data.nr(); r+=stride_y)
const long max_r = data.nr() + padding_y-(filter_nr-1);
const long max_c = data.nc() + padding_x-(filter_nc-1);
for (long r = -padding_y; r < max_r; r+=stride_y)
{
for (long c = filter_nc-1-padding_x; c-padding_x < data.nc(); c+=stride_x)
for (long c = -padding_x; c < max_c; c+=stride_x)
{
for (long k = 0; k < data.k(); ++k)
{
......@@ -1642,8 +1644,8 @@ namespace dlib
for (long x = 0; x < filter_nc; ++x)
{
DLIB_ASSERT(cnt < output.size(),"");
long xx = c-x;
long yy = r-y;
long xx = c+x;
long yy = r+y;
if (boundary.contains(xx,yy))
*t = d[(k*data.nr() + yy)*data.nc() + xx];
else
......@@ -1676,9 +1678,11 @@ namespace dlib
const float* t = &output(0,0);
// now fill in the Toeplitz output matrix for the n-th sample in data.
for (long r = filter_nr-1-padding_y; r-padding_y < data.nr(); r+=stride_y)
const long max_r = data.nr() + padding_y-(filter_nr-1);
const long max_c = data.nc() + padding_x-(filter_nc-1);
for (long r = -padding_y; r < max_r; r+=stride_y)
{
for (long c = filter_nc-1-padding_x; c-padding_x < data.nc(); c+=stride_x)
for (long c = -padding_x; c < max_c; c+=stride_x)
{
for (long k = 0; k < data.k(); ++k)
{
......@@ -1686,8 +1690,8 @@ namespace dlib
{
for (long x = 0; x < filter_nc; ++x)
{
long xx = c-x;
long yy = r-y;
long xx = c+x;
long yy = r+y;
if (boundary.contains(xx,yy))
d[(k*data.nr() + yy)*data.nc() + xx] += *t;
++t;
......
......@@ -827,7 +827,7 @@ namespace dlib
stride_y,
stride_x,
1, 1, // must be 1,1
CUDNN_CONVOLUTION)); // could also be CUDNN_CROSS_CORRELATION
CUDNN_CROSS_CORRELATION)); // could also be CUDNN_CONVOLUTION
CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(
(const cudnnConvolutionDescriptor_t)conv_handle,
......
......@@ -160,7 +160,7 @@ namespace dlib
friend void serialize(const con_& item, std::ostream& out)
{
serialize("con_3", out);
serialize("con_4", out);
serialize(item.params, out);
serialize(_num_filters, out);
serialize(_nr, out);
......@@ -186,6 +186,33 @@ namespace dlib
long nc;
int stride_y;
int stride_x;
if (version == "con_4")
{
deserialize(item.params, in);
deserialize(num_filters, in);
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
deserialize(item.padding_y_, in);
deserialize(item.padding_x_, in);
deserialize(item.filters, in);
deserialize(item.biases, in);
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
return;
}
if (version == "con_")
{
deserialize(item.params, in);
......@@ -237,6 +264,20 @@ namespace dlib
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
}
// now flip all the filters
alias_tensor at(_nr, _nc);
size_t off = 0;
for (int i = 0; i < item.filters.num_samples(); ++i)
{
for (int j = 0; j < item.filters.k(); ++j)
{
auto temp = at(item.params,off);
off += _nr*_nc;
temp = flipud(fliplr(mat(temp)));
}
}
if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment