Commit ee6e54b4 authored by Davis King's avatar Davis King

Figured out the *undocumented* requirements for calling cuDNN's

cudnnAddTensor() function and updated the specs and asserts accordingly.
parent 0babe27a
...@@ -76,11 +76,21 @@ namespace dlib ...@@ -76,11 +76,21 @@ namespace dlib
) )
{ {
DLIB_CASSERT( DLIB_CASSERT(
(dest.num_samples()==src.num_samples() || src.num_samples()==1) && (have_same_dimensions(src, dest) ||
(dest.nr()==src.nr() || src.nr()==1) && (src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1) ||
(dest.nc()==src.nc() || src.nc()==1) && (src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()) ||
(dest.k()==src.k() || src.k()==1) && (src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc())) &&
is_same_object(src,dest) == false , ""); is_same_object(src,dest) == false ,
"\n\t dest.num_samples(): " << dest.num_samples()
<<"\n\t dest.k(): " << dest.k()
<<"\n\t dest.nr(): " << dest.nr()
<<"\n\t dest.nc(): " << dest.nc()
<<"\n\t src.num_samples(): " << src.num_samples()
<<"\n\t src.k(): " << src.k()
<<"\n\t src.nr(): " << src.nr()
<<"\n\t src.nc(): " << src.nc()
);
if (beta == 0 && alpha == 0) if (beta == 0 && alpha == 0)
{ {
......
...@@ -197,11 +197,20 @@ namespace dlib ...@@ -197,11 +197,20 @@ namespace dlib
) )
{ {
DLIB_CASSERT( DLIB_CASSERT(
(dest.num_samples()==src.num_samples() || src.num_samples()==1) && (have_same_dimensions(src, dest) ||
(dest.nr()==src.nr() || src.nr()==1) && (src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1) ||
(dest.nc()==src.nc() || src.nc()==1) && (src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()) ||
(dest.k()==src.k() || src.k()==1) && (src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc())) &&
is_same_object(src,dest) == false , ""); is_same_object(src,dest) == false ,
"\n\t dest.num_samples(): " << dest.num_samples()
<<"\n\t dest.k(): " << dest.k()
<<"\n\t dest.nr(): " << dest.nr()
<<"\n\t dest.nc(): " << dest.nc()
<<"\n\t src.num_samples(): " << src.num_samples()
<<"\n\t src.k(): " << src.k()
<<"\n\t src.nr(): " << src.nr()
<<"\n\t src.nc(): " << src.nc()
);
CHECK_CUDNN(cudnnAddTensor_v3(context(), CHECK_CUDNN(cudnnAddTensor_v3(context(),
&alpha, &alpha,
......
...@@ -73,10 +73,11 @@ namespace dlib ...@@ -73,10 +73,11 @@ namespace dlib
); );
/*! /*!
requires requires
- dest.num_samples()==src.num_samples() || src.num_samples()==1 - One of the following is true:
- dest.nr()==src.nr() || src.nr()==1 - have_same_dimensions(src, dest)
- dest.nc()==src.nc() || src.nc()==1 - src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1
- dest.k()==src.k() || src.k()==1 - src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()
- src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()
- is_same_object(src,dest) == false - is_same_object(src,dest) == false
ensures ensures
- performs: dest = beta*dest + alpha*src - performs: dest = beta*dest + alpha*src
......
...@@ -375,10 +375,11 @@ namespace dlib { namespace tt ...@@ -375,10 +375,11 @@ namespace dlib { namespace tt
); );
/*! /*!
requires requires
- dest.num_samples()==src.num_samples() || src.num_samples()==1 - One of the following is true:
- dest.nr()==src.nr() || src.nr()==1 - have_same_dimensions(src, dest)
- dest.nc()==src.nc() || src.nc()==1 - src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1
- dest.k()==src.k() || src.k()==1 - src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()
- src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()
- is_same_object(src,dest) == false - is_same_object(src,dest) == false
ensures ensures
- performs: dest = beta*dest + alpha*src - performs: dest = beta*dest + alpha*src
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment