Commit d248a225 authored by Davis King

Added the launch_kernel() function that launches a kernel by smartly picking
the number of threads and blocks rather than using the hard-coded numbers I had
in there.  This makes some functions noticeably faster.

Also added a dot() function that is fully asynchronous.
parent 8466d332
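
For illustration only (not part of the commit): a minimal sketch of the new calling convention, replacing a hard-coded <<<512,512>>> launch with launch_kernel() and max_jobs().  The scale_kernel/scale names below are placeholders; launch_kernel, max_jobs, grid_stride_range, and tensor::device()/size() are the pieces introduced or used by this change.  It assumes a CUDA-enabled dlib build and a .cu translation unit that includes the internal dlib/dnn CUDA headers.

    // Hypothetical kernel written in the same grid-stride style as the ones in this commit.
    __global__ void scale_kernel(float* data, size_t n, float alpha)
    {
        for (auto i : dlib::cuda::grid_stride_range(0, n))
            data[i] *= alpha;
    }

    void scale(dlib::tensor& t, float alpha)
    {
        // Old style:  scale_kernel<<<512,512>>>(t.device(), t.size(), alpha);
        // New style:  let launch_kernel() ask the occupancy API for a block size,
        // and cap the grid so at most one thread per element is launched.
        dlib::cuda::launch_kernel(scale_kernel, dlib::cuda::max_jobs(t.size()),
                                  t.device(), t.size(), alpha);
    }
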
@@ -80,16 +80,16 @@ namespace dlib
             const auto s2 = src2.host();
             if (dest.size() == src1.size() && src1.size() == src2.size())
             {
-                _cuda_multiply1<<<512,512>>>(dest.device(), src1.device(), src2.device(), src1.size());
+                launch_kernel(_cuda_multiply1,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), src1.size());
             }
             else if (dest.num_samples() == 1)
             {
-                _cuda_multiply2<<<512,512>>>(dest.device(), src1.device(), src2.device(),
+                launch_kernel(_cuda_multiply2,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(),
                                              dest.size(), src1.size(), src2.size(), max_size);
             }
             else
             {
-                _cuda_multiply3<<<512,512>>>(dest.device(), src1.device(), src2.device(),
+                launch_kernel(_cuda_multiply3,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(),
                                              dest.size(), src1.size(), src2.size());
             }
         }
@@ -150,12 +150,13 @@ namespace dlib
             if (have_same_dimensions(dest, src1) &&
                 have_same_dimensions(dest, src2))
             {
-                _cuda_add1<<<512,512>>>(dest.device(), src1.device(), src2.device(), dest.size());
+                launch_kernel(_cuda_add1,max_jobs(dest.size()), dest.device(), src1.device(), src2.device(), dest.size());
             }
             else
             {
                 // Otherwise, do the more complex version with bounds checking.
-                _cuda_add2<<<512,512>>>(dest.device(), src1.device(), src2.device(),
+                launch_kernel(_cuda_add2,max_jobs(dest.size()),
+                              dest.device(), src1.device(), src2.device(),
                               dest.num_samples(), dest.k(), dest.nr(), dest.nc(),
                               src1.num_samples(), src1.k(), src1.nr(), src1.nc(),
                               src2.num_samples(), src2.k(), src2.nr(), src2.nc()
@@ -166,7 +167,7 @@ namespace dlib
    // ------------------------------------------------------------------------------------
-        __global__ void _cuda_affine_transform(float* d, const float* s, size_t n, float A, float B)
+        __global__ void _cuda_affine_transform1(float* d, const float* s, size_t n, float A, float B)
         {
             for (auto i : grid_stride_range(0, n))
             {
@@ -182,12 +183,12 @@ namespace dlib
         )
         {
             DLIB_CASSERT(dest.size()==src.size(),"");
-            _cuda_affine_transform<<<512,512>>>(dest.device(), src.device(), src.size(), A, B);
+            launch_kernel(_cuda_affine_transform1,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A, B);
         }
    // ----------------------------------------------------------------------------------------
-        __global__ void _cuda_affine_transform(float* d, const float* s1, const float* s2, size_t n, float A, float B, float C)
+        __global__ void _cuda_affine_transform4(float* d, const float* s1, const float* s2, size_t n, float A, float B, float C)
         {
             for (auto i : grid_stride_range(0, n))
             {
@@ -206,12 +207,12 @@ namespace dlib
         {
             DLIB_CASSERT(dest.size()==src1.size(),"");
             DLIB_CASSERT(dest.size()==src2.size(),"");
-            _cuda_affine_transform<<<512,512>>>(dest.device(), src1.device(), src2.device(), dest.size(), A, B, C);
+            launch_kernel(_cuda_affine_transform4,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B, C);
         }
    // ----------------------------------------------------------------------------------------
-        __global__ void _cuda_affine_transform(
+        __global__ void _cuda_affine_transform5(
             float* d, const float* s1, const float* s2, const float* s3, size_t n, float A, float B, float C, float D
         )
         {
@@ -235,7 +236,7 @@ namespace dlib
             DLIB_CASSERT(dest.size()==src1.size(),"");
             DLIB_CASSERT(dest.size()==src2.size(),"");
             DLIB_CASSERT(dest.size()==src3.size(),"");
-            _cuda_affine_transform<<<512,512>>>(dest.device(), src1.device(),
+            launch_kernel(_cuda_affine_transform5,max_jobs(dest.size()),dest.device(), src1.device(),
                                                 src2.device(), src3.device(), dest.size(), A, B, C, D);
         }
@@ -273,11 +274,11 @@ namespace dlib
             if (A.num_samples() == 1)
             {
-                _cuda_affine_transform3<<<512,512>>>(dest.device(), src.device(), src.size(), A.device(), B.device(), A.size());
+                launch_kernel(_cuda_affine_transform3,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A.device(), B.device(), A.size());
             }
             else
             {
-                _cuda_affine_transform2<<<512,512>>>(dest.device(), src.device(), src.size(), A.device(), B.device());
+                launch_kernel(_cuda_affine_transform2,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A.device(), B.device());
             }
         }
@@ -305,7 +306,7 @@ namespace dlib
                          gradient_input.nc() == grad.nc() &&
                          gradient_input.size() > 0,"");
-            _add_bias_gradient<<<512,512>>>(grad.device(), gradient_input.device(), grad.size(), gradient_input.size());
+            launch_kernel(_add_bias_gradient,max_jobs(grad.size()),grad.device(), gradient_input.device(), grad.size(), gradient_input.size());
         }
    // -----------------------------------------------------------------------------------
@@ -324,11 +325,37 @@ namespace dlib
             float thresh
         )
         {
-            _cuda_threshold<<<512,512>>>(data.device(), data.size(), thresh);
+            launch_kernel(_cuda_threshold,max_jobs(data.size()),data.device(), data.size(), thresh);
         }
    // ------------------------------------------------------------------------------------
+        __global__ void _cuda_dot(const float* a, const float* b, size_t n, float* result)
+        {
+            // Parallel sum everything into local temp variables.
+            float temp = 0;
+            for(auto i : grid_stride_range(0, n))
+                temp += a[i]*b[i];
+
+            // Then do the warp reduce add thing to merge into one output value.
+            warp_reduce_atomic_add(*result, temp);
+        }
+
+        void dot (
+            const tensor& a,
+            const tensor& b,
+            tensor& result,
+            size_t idx
+        )
+        {
+            DLIB_CASSERT(a.size() == b.size(), "");
+            DLIB_CASSERT(idx < result.size(), "");
+
+            launch_kernel(_cuda_dot, max_jobs(a.size()), a.device(), b.device(), a.size(), result.device()+idx);
+        }
+
+    // ----------------------------------------------------------------------------------------
     }
 }
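
A usage sketch (not part of the commit) of the new dot() routine, assuming dlib's resizable_tensor.  Note that _cuda_dot accumulates into result.device()[idx] with an atomic add, so the output element needs to start at zero, and since the call is asynchronous the value is only safe to read after the device-to-host copy that host() performs.

    dlib::resizable_tensor a, b, result;
    a.set_size(1, 1000);
    b.set_size(1, 1000);
    result.set_size(1, 1);
    float* pa = a.host();
    float* pb = b.host();
    for (size_t i = 0; i < a.size(); ++i) { pa[i] = 1; pb[i] = 2; }
    result.host()[0] = 0;              // the kernel adds into the output, so start from zero
    dlib::cuda::dot(a, b, result, 0);  // asynchronous; this just queues the kernel
    float value = result.host()[0];    // reading back forces the copy; value should be 2000
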
@@ -89,6 +89,15 @@ namespace dlib
             float thresh
         );
+
+    // ----------------------------------------------------------------------------------------
+
+        void dot (
+            const tensor& a,
+            const tensor& b,
+            tensor& result,
+            size_t idx
+        );
+
    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------
......
@@ -127,6 +127,56 @@ namespace dlib
             atomicAdd(&out, val);
         }
+
+    // ------------------------------------------------------------------------------------
+
+        struct max_jobs
+        {
+            max_jobs(size_t n) : num(n) {}
+            size_t num;
+        };
+
+        template <typename Kernel, typename... T>
+        void launch_kernel (
+            Kernel K,
+            T ...args
+        )
+        /*!
+            ensures
+                - launches the given kernel K(args...).  The point of this function is to
+                  automatically set the kernel launch parameters to something reasonable
+                  based on the properties of the kernel and the current GPU card.
+        !*/
+        {
+            int num_blocks, num_threads;
+            CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&num_blocks,&num_threads,K));
+            K<<<num_blocks,num_threads>>>(args...);
+        }
+
+        template <typename Kernel, typename... T>
+        void launch_kernel (
+            Kernel K,
+            max_jobs m,
+            T ...args
+        )
+        /*!
+            ensures
+                - This function is just like launch_kernel(K,args...) except that you can
+                  additionally supply a max_jobs number that tells it how many possible
+                  total threads could be used.  This is useful when launching potentially
+                  small jobs that might not need the number of threads suggested by
+                  launch_kernel().
+        !*/
+        {
+            int num_blocks, num_threads;
+            CHECK_CUDA(cudaOccupancyMaxPotentialBlockSize(&num_blocks,&num_threads,K));
+            // Check if the job is really small and we don't really need to launch a kernel
+            // with this many blocks and threads.
+            if (num_blocks*num_threads > m.num)
+                num_blocks = (m.num+num_threads-1)/num_threads;
+
+            K<<<num_blocks,num_threads>>>(args...);
+        }
+
    // ------------------------------------------------------------------------------------
         class grid_stride_range
......
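
A worked example of the clamping done by the max_jobs overload of launch_kernel().  The occupancy numbers below are made up; the real ones come from cudaOccupancyMaxPotentialBlockSize and depend on the kernel and the GPU.

    int num_blocks  = 40;    // hypothetical grid size suggested for full occupancy
    int num_threads = 1024;  // hypothetical block size suggested by the occupancy API
    size_t n = 3000;         // max_jobs(3000): no point launching more than 3000 threads
    if (num_blocks*num_threads > n)                    // 40960 > 3000
        num_blocks = (n+num_threads-1)/num_threads;    // ceil(3000/1024) == 3
    // The kernel is then launched as K<<<3,1024>>>(args...); grid_stride_range()
    // inside the kernel still walks all n elements, so the result is unchanged,
    // but small jobs no longer pay for a needlessly large grid.
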