Commit ee6e54b4 authored by Davis King

Figured out the *undocumented* requirements for calling cuDNN's cudnnAddTensor() function and updated the specs and asserts accordingly.
parent 0babe27a
......@@ -76,11 +76,21 @@ namespace dlib
)
{
DLIB_CASSERT(
(dest.num_samples()==src.num_samples() || src.num_samples()==1) &&
(dest.nr()==src.nr() || src.nr()==1) &&
(dest.nc()==src.nc() || src.nc()==1) &&
(dest.k()==src.k() || src.k()==1) &&
is_same_object(src,dest) == false , "");
(have_same_dimensions(src, dest) ||
(src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1) ||
(src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()) ||
(src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc())) &&
is_same_object(src,dest) == false ,
"\n\t dest.num_samples(): " << dest.num_samples()
<<"\n\t dest.k(): " << dest.k()
<<"\n\t dest.nr(): " << dest.nr()
<<"\n\t dest.nc(): " << dest.nc()
<<"\n\t src.num_samples(): " << src.num_samples()
<<"\n\t src.k(): " << src.k()
<<"\n\t src.nr(): " << src.nr()
<<"\n\t src.nc(): " << src.nc()
);
if (beta == 0 && alpha == 0)
{
......
......@@ -197,11 +197,20 @@ namespace dlib
)
{
DLIB_CASSERT(
(dest.num_samples()==src.num_samples() || src.num_samples()==1) &&
(dest.nr()==src.nr() || src.nr()==1) &&
(dest.nc()==src.nc() || src.nc()==1) &&
(dest.k()==src.k() || src.k()==1) &&
is_same_object(src,dest) == false , "");
(have_same_dimensions(src, dest) ||
(src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1) ||
(src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()) ||
(src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc())) &&
is_same_object(src,dest) == false ,
"\n\t dest.num_samples(): " << dest.num_samples()
<<"\n\t dest.k(): " << dest.k()
<<"\n\t dest.nr(): " << dest.nr()
<<"\n\t dest.nc(): " << dest.nc()
<<"\n\t src.num_samples(): " << src.num_samples()
<<"\n\t src.k(): " << src.k()
<<"\n\t src.nr(): " << src.nr()
<<"\n\t src.nc(): " << src.nc()
);
CHECK_CUDNN(cudnnAddTensor_v3(context(),
&alpha,
......
......@@ -73,10 +73,11 @@ namespace dlib
);
/*!
requires
- dest.num_samples()==src.num_samples() || src.num_samples()==1
- dest.nr()==src.nr() || src.nr()==1
- dest.nc()==src.nc() || src.nc()==1
- dest.k()==src.k() || src.k()==1
- One of the following is true:
- have_same_dimensions(src, dest)
- src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1
- src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()
- src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()
- is_same_object(src,dest) == false
ensures
- performs: dest = beta*dest + alpha*src
......
......@@ -375,10 +375,11 @@ namespace dlib { namespace tt
);
/*!
requires
- dest.num_samples()==src.num_samples() || src.num_samples()==1
- dest.nr()==src.nr() || src.nr()==1
- dest.nc()==src.nc() || src.nc()==1
- dest.k()==src.k() || src.k()==1
- One of the following is true:
- have_same_dimensions(src, dest)
- src.num_samples()==1 && src.k()==dest.k() && src.nr()==1 && src.nc()==1
- src.num_samples()==1 && src.k()==dest.k() && src.nr()==dest.nr() && src.nc()==dest.nc()
- src.num_samples()==1 && src.k()==1 && src.nr()==dest.nr() && src.nc()==dest.nc()
- is_same_object(src,dest) == false
ensures
- performs: dest = beta*dest + alpha*src
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment