Commit cbdeb160 authored by Davis King

Made add() faster by calling my own version for the simple pointwise add case.

parent 30005b7e
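
The change below adds a hand-written CUDA kernel, _cuda_add_scaled, and makes cuda::add() call it whenever the operation reduces to a plain pointwise dest += alpha*src (equal tensor sizes and beta == 1), skipping cudnnAddTensor_v3 for that case. For orientation, here is a minimal standalone sketch of the same grid-stride pattern, written without dlib's grid_stride_range and launch_kernel helpers; the file name, kernel name, and launch sizes are illustrative assumptions, not part of this commit:

    // sketch.cu -- standalone illustration of the pointwise add kernel below.
    #include <cuda_runtime.h>

    __global__ void add_scaled_sketch(float* d, const float* s, size_t n, float scale)
    {
        // Grid-stride loop: each thread starts at its global index and hops by
        // the total thread count, so any launch configuration covers all n
        // elements.  This is roughly what dlib's grid_stride_range expands to.
        for (size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x; i < n;
             i += (size_t)gridDim.x * blockDim.x)
            d[i] += scale * s[i];
    }

    int main()
    {
        const size_t n = 1 << 20;
        float *d, *s;
        cudaMalloc(&d, n * sizeof(float));
        cudaMalloc(&s, n * sizeof(float));
        cudaMemset(d, 0, n * sizeof(float));
        cudaMemset(s, 0, n * sizeof(float));

        add_scaled_sketch<<<128, 256>>>(d, s, n, 0.5f);  // d[i] += 0.5f*s[i]
        cudaDeviceSynchronize();

        cudaFree(d);
        cudaFree(s);
        return 0;
    }
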
@@ -210,6 +210,26 @@ namespace dlib
    launch_kernel(_cuda_affine_transform4,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B, C);
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_add_scaled(float* d, const float* s, size_t n, float scale)
{
    for (auto i : grid_stride_range(0, n))
    {
        d[i] += scale*s[i];
    }
}

void add_scaled(
    tensor& dest,
    const float scale,
    const tensor& src
)
{
    DLIB_CASSERT(dest.size()==src.size(),"");
    launch_kernel(_cuda_add_scaled,max_jobs(dest.size()),dest.device(), src.device(), dest.size(), scale);
}
// ----------------------------------------------------------------------------------------

__global__ void _cuda_affine_transform5(
...
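
Callers never launch the kernel directly; they go through the add_scaled() wrapper above. A hedged usage sketch, assuming dest and src are already-allocated dlib GPU tensors of equal size:

    // Computes dest += 0.5f*src on the device; the wrapper asserts
    // dest.size() == src.size() before launching _cuda_add_scaled.
    cuda::add_scaled(dest, 0.5f, src);
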
@@ -65,6 +65,13 @@ namespace dlib
    const float D
);
// Note that this function isn't in the tt:: namespace because add_scaled() is
// called by cuda::add(), so we don't need a tt:: version of add_scaled().
void add_scaled(
    tensor& dest,
    const float scale,
    const tensor& src
);
// -----------------------------------------------------------------------------------
...
@@ -12,6 +12,7 @@
#include <string>
#include "cuda_utils.h"
#include "cpu_dlib.h"
#include "cuda_dlib.h"
static const char* cudnn_get_error_string(cudnnStatus_t s)
{
@@ -213,6 +214,14 @@ namespace dlib
<<"\n\t src.nc(): " << src.nc() <<"\n\t src.nc(): " << src.nc()
); );
    if (dest.size() == src.size() && beta == 1)
    {
        // Call the dlib function in this case since it's faster than the one that
        // comes with cuDNN (at least as of cuDNN v4).
        add_scaled(dest, alpha, src);
        return;
    }
    CHECK_CUDNN(cudnnAddTensor_v3(context(),
                                  &alpha,
                                  descriptor(src),
...
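
To see why the fast path is valid: cudnnAddTensor_v3 computes dest = alpha*src + beta*dest (broadcasting src across dest when their shapes differ), so when the two tensors have the same size and beta == 1 it collapses to dest += alpha*src, which is exactly what add_scaled() computes. A hypothetical CPU reference for the flat-data semantics, written here only to make the dispatch condition concrete, not taken from dlib:

    #include <cstddef>

    // Reference semantics of the general path over flat float data
    // (illustrative only): dest = alpha*src + beta*dest, element by element.
    void add_reference(float beta, float* dest, float alpha, const float* src, std::size_t n)
    {
        for (std::size_t i = 0; i < n; ++i)
            dest[i] = alpha * src[i] + beta * dest[i];
    }
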