Commit fb2fa0f7 authored by Davis King's avatar Davis King

Added another add() function for adding tensors. This one lets you add

tensors with different sizes and it will zero pad them as needed.
parent ca776404
......@@ -108,6 +108,67 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
void add (
tensor& dest,
const tensor& src1,
const tensor& src2
)
{
auto d = dest.host();
auto s1 = src1.host();
auto s2 = src2.host();
// Do the simple and fast version if everything has the same dimensions
if (have_same_dimensions(dest, src1) &&
have_same_dimensions(dest, src2))
{
for (size_t i = 0; i < dest.size(); ++i)
d[i] = s1[i] + s2[i];
return;
}
// Otherwise, do the more complex version with bounds checking.
for (long n = 0; n < dest.num_samples(); ++n)
{
for (long k = 0; k < dest.k(); ++k)
{
for (long r = 0; r < dest.nr(); ++r)
{
for (long c = 0; c < dest.nc(); ++c)
{
float v1 = 0;
float v2 = 0;
// if this index is inside src1
if (n < src1.num_samples() &&
k < src1.k() &&
r < src1.nr() &&
c < src1.nc() )
{
const auto s_idx = ((n*src1.k() + k)*src1.nr() + r)*src1.nc() + c;
v1 = s1[s_idx];
}
// if this index is inside src2
if (n < src2.num_samples() &&
k < src2.k() &&
r < src2.nr() &&
c < src2.nc() )
{
const auto s_idx = ((n*src2.k() + k)*src2.nr() + r)*src2.nc() + c;
v2 = s2[s_idx];
}
*d = v1 + v2;
++d;
}
}
}
}
}
// ----------------------------------------------------------------------------------------
void assign_bias_gradient (
......
......@@ -37,6 +37,12 @@ namespace dlib
const tensor& gradient_input
);
void add (
tensor& dest,
const tensor& src1,
const tensor& src2
);
// -----------------------------------------------------------------------------------
void affine_transform(
......
......@@ -94,7 +94,77 @@ namespace dlib
}
}
// -----------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
__global__ void _cuda_add1(float* d, const float* s1, const float* s2, size_t n)
{
for (auto i : grid_stride_range(0, n))
{
d[i] = s1[i]+s2[i];
}
}
__global__ void _cuda_add2(float* d, const float* s1, const float* s2,
size_t dn, size_t dk, size_t dr, size_t dc,
size_t s1n, size_t s1k, size_t s1r, size_t s1c,
size_t s2n, size_t s2k, size_t s2r, size_t s2c)
{
for (auto i : grid_stride_range(0, dn*dk*dr*dc))
{
size_t n,k,r,c;
unpack_idx(i, dk,dr,dc, n,k,r,c);
float v1 = 0;
float v2 = 0;
if (n < s1n &&
k < s1k &&
r < s1r &&
c < s1c )
{
v1 = s1[pack_idx(s1k,s1r,s1c, n,k,r,c)];
}
if (n < s2n &&
k < s2k &&
r < s2r &&
c < s2c )
{
v2 = s2[pack_idx(s2k,s2r,s2c, n,k,r,c)];
}
d[i] = v1+v2;
}
}
void add (
tensor& dest,
const tensor& src1,
const tensor& src2
)
{
if (dest.size() == 0)
return;
// Do the simple and fast version if everything has the same dimensions
if (have_same_dimensions(dest, src1) &&
have_same_dimensions(dest, src2))
{
_cuda_add1<<<512,512>>>(dest.device(), src1.device(), src2.device(), dest.size());
}
else
{
// Otherwise, do the more complex version with bounds checking.
_cuda_add2<<<512,512>>>(dest.device(), src1.device(), src2.device(),
dest.num_samples(), dest.k(), dest.nr(), dest.nc(),
src1.num_samples(), src1.k(), src1.nr(), src1.nc(),
src2.num_samples(), src2.k(), src2.nr(), src2.nc()
);
}
}
// ------------------------------------------------------------------------------------
__global__ void _cuda_affine_transform(float* d, const float* s, size_t n, float A, float B)
{
......
......@@ -30,6 +30,12 @@ namespace dlib
const tensor& src2
);
void add (
tensor& dest,
const tensor& src1,
const tensor& src2
);
// -----------------------------------------------------------------------------------
void affine_transform(
......
......@@ -323,6 +323,21 @@ namespace dlib { namespace tt
#endif
}
// ----------------------------------------------------------------------------------------
void add (
tensor& dest,
const tensor& src1,
const tensor& src2
)
{
#ifdef DLIB_USE_CUDA
cuda::add(dest, src1, src2);
#else
cpu::add(dest, src1, src2);
#endif
}
// ----------------------------------------------------------------------------------------
void assign_conv_bias_gradient (
......
......@@ -424,6 +424,20 @@ namespace dlib { namespace tt
tensor.
!*/
// ----------------------------------------------------------------------------------------
void add (
tensor& dest,
const tensor& src1,
const tensor& src2
);
/*!
ensures
- performs: dest = src1 + src2
The addition happens pointwise according to 4D tensor arithmetic. If the
dimensions don't match then missing elements are presumed to be equal to 0.
!*/
// ----------------------------------------------------------------------------------------
void assign_conv_bias_gradient (
......
......@@ -493,6 +493,60 @@ namespace
// ----------------------------------------------------------------------------------------
#ifdef DLIB_USE_CUDA
void test_add()
{
print_spinner();
dlib::rand rnd;
tt::tensor_rand trnd;
for (int iter = 0; iter < 300; ++iter)
{
resizable_tensor dest1(rnd.get_random_32bit_number()%4+1,
rnd.get_random_32bit_number()%4+1,
rnd.get_random_32bit_number()%4+1,
rnd.get_random_32bit_number()%4+1);
resizable_tensor dest2;
dest2.copy_size(dest1);
resizable_tensor src1(rnd.get_random_32bit_number()%4+1,
rnd.get_random_32bit_number()%4+1,
rnd.get_random_32bit_number()%4+1,
rnd.get_random_32bit_number()%4+1);
resizable_tensor src2(rnd.get_random_32bit_number()%4+1,
rnd.get_random_32bit_number()%4+1,
rnd.get_random_32bit_number()%4+1,
rnd.get_random_32bit_number()%4+1);
trnd.fill_uniform(dest1);
trnd.fill_uniform(dest2);
trnd.fill_uniform(src1);
trnd.fill_uniform(src2);
cpu::add(dest1, src1, src2);
cuda::add(dest2, src1, src2);
DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5);
}
// make sure we have a test for the case where all tensors have the same
// dimensions.
resizable_tensor dest1(3,4,5,6);
resizable_tensor dest2;
resizable_tensor src1;
resizable_tensor src2;
dest2.copy_size(dest1);
src1.copy_size(dest1);
src2.copy_size(dest1);
trnd.fill_uniform(dest1);
trnd.fill_uniform(dest2);
trnd.fill_uniform(src1);
trnd.fill_uniform(src2);
cpu::add(dest1, src1, src2);
cuda::add(dest2, src1, src2);
DLIB_TEST(max(abs(mat(dest1) - mat(dest2))) < 1e-5);
}
void test_more_ops(const long nr, const long nc)
{
using namespace dlib::tt;
......@@ -950,6 +1004,7 @@ namespace
test_more_ops(10000,4);
compare_bn_gpu_and_cpu();
compare_bn_conv_gpu_and_cpu();
test_add();
#endif
test_max_pool(1,1,2,3);
test_max_pool(3,3,1,1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment