Commit 038c73d8 authored by Davis King

Cleaned up the tensor_conv interface a little. Also fixed an error in the spec for this object.
parent 83ecf1d9
...@@ -265,10 +265,6 @@ namespace dlib ...@@ -265,10 +265,6 @@ namespace dlib
) : ) :
filter_handle(nullptr), filter_handle(nullptr),
conv_handle(nullptr), conv_handle(nullptr),
out_num_samples(0),
out_k(0),
out_nr(0),
out_nc(0),
forward_algo(0), forward_algo(0),
forward_workspace_size_in_bytes(0), forward_workspace_size_in_bytes(0),
forward_workspace(nullptr), forward_workspace(nullptr),
...@@ -279,6 +275,7 @@ namespace dlib ...@@ -279,6 +275,7 @@ namespace dlib
backward_filters_workspace_size_in_bytes(0), backward_filters_workspace_size_in_bytes(0),
backward_filters_workspace(nullptr) backward_filters_workspace(nullptr)
{ {
clear();
} }
void tensor_conv:: void tensor_conv::
...@@ -313,6 +310,17 @@ namespace dlib ...@@ -313,6 +310,17 @@ namespace dlib
backward_filters_workspace = nullptr; backward_filters_workspace = nullptr;
backward_filters_algo = 0; backward_filters_algo = 0;
backward_filters_workspace_size_in_bytes = 0; backward_filters_workspace_size_in_bytes = 0;
stride_y = 0;
stride_x = 0;
data_num_samples = 0;
data_k = 0;
data_nr = 0;
data_nc = 0;
filters_num_samples = 0;
filters_k = 0;
filters_nr = 0;
filters_nc = 0;
} }
void tensor_conv:: void tensor_conv::
...@@ -324,11 +332,36 @@ namespace dlib ...@@ -324,11 +332,36 @@ namespace dlib
) )
{ {
DLIB_CASSERT(data.k() == filters.k(),""); DLIB_CASSERT(data.k() == filters.k(),"");
// if the last call to setup gave the same exact settings then don't do
// anything.
if (stride_y_ == stride_y &&
stride_x_ == stride_x &&
data_num_samples == data.num_samples() &&
data_k == data.k() &&
data_nr == data.nr() &&
data_nc == data.nc() &&
filters_num_samples == filters.num_samples() &&
filters_k == filters.k() &&
filters_nr == filters.nr() &&
filters_nc == filters.nc())
{
return;
}
clear(); clear();
try try
{ {
stride_y = stride_y_; stride_y = stride_y_;
stride_x = stride_x_; stride_x = stride_x_;
data_num_samples = data.num_samples();
data_k = data.k();
data_nr = data.nr();
data_nc = data.nc();
filters_num_samples = filters.num_samples();
filters_k = filters.k();
filters_nr = filters.nr();
filters_nc = filters.nc();
CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle)); CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle, CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
...@@ -446,18 +479,22 @@ namespace dlib ...@@ -446,18 +479,22 @@ namespace dlib
void tensor_conv::operator() ( void tensor_conv::operator() (
resizable_tensor& output, resizable_tensor& output,
const tensor& data, const tensor& data,
const tensor& filters const tensor& filters,
int stride_y,
int stride_x
) )
{ {
DLIB_ASSERT(is_same_object(output,data) == false,""); DLIB_ASSERT(is_same_object(output,data) == false,"");
DLIB_ASSERT(is_same_object(output,filters) == false,""); DLIB_ASSERT(is_same_object(output,filters) == false,"");
setup(data,filters,stride_y,stride_x);
output.set_size(out_num_samples, out_k, out_nr, out_nc); output.set_size(out_num_samples, out_k, out_nr, out_nc);
DLIB_ASSERT(output.num_samples() == data.num_samples(),out_num_samples << " " << data.num_samples()); DLIB_ASSERT(output.num_samples() == data.num_samples(),out_num_samples << " " << data.num_samples());
DLIB_ASSERT(output.k() == filters.num_samples(),""); DLIB_ASSERT(output.k() == filters.num_samples(),"");
DLIB_ASSERT(output.nr() == 1+(data.nr()-1)/stride_y,""); DLIB_ASSERT(output.nr() == 1+(data.nr()-filters.nr()%2)/stride_y,"");
DLIB_ASSERT(output.nc() == 1+(data.nc()-1)/stride_x,""); DLIB_ASSERT(output.nc() == 1+(data.nc()-filters.nc()%2)/stride_x,output.nc() << " " <<1+(data.nc()-1)/stride_x << " : " << data.nc() << " " << stride_x);
const float alpha = 1; const float alpha = 1;
const float beta = 0; const float beta = 0;
......
...@@ -144,31 +144,20 @@ namespace dlib ...@@ -144,31 +144,20 @@ namespace dlib
void clear( void clear(
); );
void setup(
const tensor& data,
const tensor& filters,
int stride_y,
int stride_x
);
/*!
requires
- filters.k() == data.k()
- stride_y > 0
- stride_x > 0
!*/
~tensor_conv ( ~tensor_conv (
); );
void operator() ( void operator() (
resizable_tensor& output, resizable_tensor& output,
const tensor& data, const tensor& data,
const tensor& filters const tensor& filters,
int stride_y,
int stride_x
); );
/*! /*!
requires requires
- The dimensions of data and filters are the same as the ones given - stride_y > 0
to the last call to setup(). - stride_x > 0
- is_same_object(output,data) == false - is_same_object(output,data) == false
- is_same_object(output,filters) == false - is_same_object(output,filters) == false
ensures ensures
...@@ -176,8 +165,8 @@ namespace dlib ...@@ -176,8 +165,8 @@ namespace dlib
- filters contains filters.num_samples() filters. - filters contains filters.num_samples() filters.
- #output.num_samples() == data.num_samples() - #output.num_samples() == data.num_samples()
- #output.k() == filters.num_samples() - #output.k() == filters.num_samples()
- #output.nr() == 1+(data.nr()-1)/stride_y - #output.nr() == 1+(data.nr()-filters.nr()%2)/stride_y
- #output.nc() == 1+(data.nc()-1)/stride_x - #output.nc() == 1+(data.nc()-filters.nc()%2)/stride_x
!*/ !*/
void get_gradient_for_data ( void get_gradient_for_data (
...@@ -188,9 +177,9 @@ namespace dlib ...@@ -188,9 +177,9 @@ namespace dlib
/*! /*!
requires requires
- filters has the same dimensions as the filters object given to the - filters has the same dimensions as the filters object given to the
last call to setup(). last call to operator().
- data_gradient has the same dimensions as the data object given to the - data_gradient has the same dimensions as the data object given to the
last call to setup(). last call to operator().
- gradient_input has the same dimensions as the output of operator(). - gradient_input has the same dimensions as the output of operator().
- is_same_object(data_gradient,filters) == false - is_same_object(data_gradient,filters) == false
- is_same_object(data_gradient,gradient_input) == false - is_same_object(data_gradient,gradient_input) == false
...@@ -209,9 +198,9 @@ namespace dlib ...@@ -209,9 +198,9 @@ namespace dlib
/*! /*!
requires requires
- filters_gradient has the same dimensions as the filters object given - filters_gradient has the same dimensions as the filters object given
to the last call to setup(). to the last call to operator().
- data has the same dimensions as the data object given to the last call - data has the same dimensions as the data object given to the last call
to setup(). to operator().
- gradient_input has the same dimensions as the output of operator(). - gradient_input has the same dimensions as the output of operator().
- is_same_object(filters_gradient,data) == false - is_same_object(filters_gradient,data) == false
- is_same_object(filters_gradient,gradient_input) == false - is_same_object(filters_gradient,gradient_input) == false
...@@ -223,10 +212,29 @@ namespace dlib ...@@ -223,10 +212,29 @@ namespace dlib
!*/ !*/
private: private:
void* filter_handle;
void* conv_handle; void setup(
const tensor& data,
const tensor& filters,
int stride_y,
int stride_x
);
/*!
requires
- filters.k() == data.k()
- stride_y > 0
- stride_x > 0
!*/
// These variables record the type of data given to the last call to setup().
int stride_y; int stride_y;
int stride_x; int stride_x;
long data_num_samples, data_k, data_nr, data_nc;
long filters_num_samples, filters_k, filters_nr, filters_nc;
void* filter_handle;
void* conv_handle;
// dimensions of the output tensor from operator() // dimensions of the output tensor from operator()
int out_num_samples; int out_num_samples;
......
...@@ -277,7 +277,8 @@ namespace dlib { namespace tt ...@@ -277,7 +277,8 @@ namespace dlib { namespace tt
} }
void tensor_conv:: void tensor_conv::
setup( operator() (
resizable_tensor& output,
const tensor& data, const tensor& data,
const tensor& filters, const tensor& filters,
int stride_y, int stride_y,
...@@ -285,22 +286,7 @@ namespace dlib { namespace tt ...@@ -285,22 +286,7 @@ namespace dlib { namespace tt
) )
{ {
#ifdef DLIB_USE_CUDA #ifdef DLIB_USE_CUDA
impl.setup(data, filters, stride_y, stride_x); impl(output, data, filters, stride_y, stride_x);
#else
// TODO
DLIB_CASSERT(false,"");
#endif
}
void tensor_conv::
operator() (
resizable_tensor& output,
const tensor& data,
const tensor& filters
)
{
#ifdef DLIB_USE_CUDA
impl(output, data, filters);
#else #else
// TODO // TODO
DLIB_CASSERT(false,""); DLIB_CASSERT(false,"");
......
...@@ -397,7 +397,8 @@ namespace dlib { namespace tt ...@@ -397,7 +397,8 @@ namespace dlib { namespace tt
void clear( void clear(
); );
void setup( void operator() (
resizable_tensor& output,
const tensor& data, const tensor& data,
const tensor& filters, const tensor& filters,
int stride_y, int stride_y,
...@@ -405,20 +406,8 @@ namespace dlib { namespace tt ...@@ -405,20 +406,8 @@ namespace dlib { namespace tt
); );
/*! /*!
requires requires
- filters.k() == data.k()
- stride_y > 0 - stride_y > 0
- stride_x > 0 - stride_x > 0
!*/
void operator() (
resizable_tensor& output,
const tensor& data,
const tensor& filters
);
/*!
requires
- The dimensions of data and filters are the same as the ones given
to the last call to setup().
- is_same_object(output,data) == false - is_same_object(output,data) == false
- is_same_object(output,filters) == false - is_same_object(output,filters) == false
ensures ensures
...@@ -426,8 +415,8 @@ namespace dlib { namespace tt ...@@ -426,8 +415,8 @@ namespace dlib { namespace tt
- filters contains filters.num_samples() filters. - filters contains filters.num_samples() filters.
- #output.num_samples() == data.num_samples() - #output.num_samples() == data.num_samples()
- #output.k() == filters.num_samples() - #output.k() == filters.num_samples()
- #output.nr() == 1+(data.nr()-1)/stride_y - #output.nr() == 1+(data.nr()-filters.nr()%2)/stride_y
- #output.nc() == 1+(data.nc()-1)/stride_x - #output.nc() == 1+(data.nc()-filters.nc()%2)/stride_x
!*/ !*/
void get_gradient_for_data ( void get_gradient_for_data (
...@@ -438,9 +427,9 @@ namespace dlib { namespace tt ...@@ -438,9 +427,9 @@ namespace dlib { namespace tt
/*! /*!
requires requires
- filters has the same dimensions as the filters object given to the last - filters has the same dimensions as the filters object given to the last
call to setup(). call to operator().
- data_gradient has the same dimensions as the data object given to the last - data_gradient has the same dimensions as the data object given to the last
call to setup(). call to operator().
- gradient_input has the same dimensions as the output of operator(). - gradient_input has the same dimensions as the output of operator().
- is_same_object(data_gradient,filters) == false - is_same_object(data_gradient,filters) == false
- is_same_object(data_gradient,gradient_input) == false - is_same_object(data_gradient,gradient_input) == false
...@@ -459,9 +448,9 @@ namespace dlib { namespace tt ...@@ -459,9 +448,9 @@ namespace dlib { namespace tt
/*! /*!
requires requires
- filters_gradient has the same dimensions as the filters object given to - filters_gradient has the same dimensions as the filters object given to
the last call to setup(). the last call to operator().
- data has the same dimensions as the data object given to the last call to - data has the same dimensions as the data object given to the last call to
setup(). operator().
- gradient_input has the same dimensions as the output of operator(). - gradient_input has the same dimensions as the output of operator().
- is_same_object(filters_gradient,data) == false - is_same_object(filters_gradient,data) == false
- is_same_object(filters_gradient,gradient_input) == false - is_same_object(filters_gradient,gradient_input) == false
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment