Commit 863702f0 authored by Evgeniy Fominov's avatar Evgeniy Fominov Committed by Davis E. King

cuDNN convolution algorithms shared workspace (#695)

* added shared workspace

* rewrite shared workspace code

* rename and device-based buffer allocation

* fix cudnn_device_buffer constructors
parent 3e471ade
...@@ -149,6 +149,32 @@ namespace dlib ...@@ -149,6 +149,32 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
class resizable_cuda_buffer
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This is a block of memory on a CUDA device that grows on
            demand: asking for more bytes than are currently allocated
            replaces the buffer with a larger one.
    !*/
public:

    cuda_data_void_ptr get(size_t size)
    /*!
        ensures
            - returns a device buffer whose capacity is at least the
              requested number of bytes
            - buffer.size() >= size
    !*/
    {
        if (size > buffer.size())
        {
            // Release the current allocation before making the bigger one so
            // we never hold both blocks of device memory at the same time.
            buffer.reset();
            buffer = cuda_data_void_ptr(size);
        }
        return buffer;
    }

private:
    cuda_data_void_ptr buffer;
};
} }
} }
......
...@@ -117,7 +117,55 @@ namespace dlib ...@@ -117,7 +117,55 @@ namespace dlib
thread_local cudnn_context c; thread_local cudnn_context c;
return c.get_handle(); return c.get_handle();
} }
// ------------------------------------------------------------------------------------
class cudnn_device_buffer
{
public:
// not copyable
cudnn_device_buffer(const cudnn_device_buffer&) = delete;
cudnn_device_buffer& operator=(const cudnn_device_buffer&) = delete;
cudnn_device_buffer()
{
buffers.resize(16);
}
~cudnn_device_buffer()
{
}
std::shared_ptr<resizable_cuda_buffer> get_buffer (
)
{
int new_device_id;
CHECK_CUDA(cudaGetDevice(&new_device_id));
// make room for more devices if needed
if (new_device_id >= (long)buffers.size())
buffers.resize(new_device_id+16);
// If we don't have a buffer already for this device then make one
std::shared_ptr<resizable_cuda_buffer> buff = buffers[new_device_id].lock();
if (!buff)
{
buff = std::make_shared<resizable_cuda_buffer>();
buffers[new_device_id] = buff;
}
// Finally, return the buffer for the current device
return buff;
}
private:
std::vector<std::weak_ptr<resizable_cuda_buffer>> buffers;
};
static std::shared_ptr<resizable_cuda_buffer> device_global_buffer()
{
thread_local cudnn_device_buffer buffer;
return buffer.get_buffer();
}
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
class cudnn_activation_descriptor class cudnn_activation_descriptor
...@@ -705,14 +753,8 @@ namespace dlib ...@@ -705,14 +753,8 @@ namespace dlib
filter_handle(nullptr), filter_handle(nullptr),
conv_handle(nullptr), conv_handle(nullptr),
forward_algo(0), forward_algo(0),
forward_workspace_size_in_bytes(0),
forward_workspace(nullptr),
backward_data_algo(0), backward_data_algo(0),
backward_data_workspace_size_in_bytes(0), backward_filters_algo(0)
backward_data_workspace(nullptr),
backward_filters_algo(0),
backward_filters_workspace_size_in_bytes(0),
backward_filters_workspace(nullptr)
{ {
clear(); clear();
} }
...@@ -732,24 +774,6 @@ namespace dlib ...@@ -732,24 +774,6 @@ namespace dlib
out_nr = 0; out_nr = 0;
out_nc = 0; out_nc = 0;
if (forward_workspace)
cudaFree(forward_workspace);
forward_workspace = nullptr;
forward_algo = 0;
forward_workspace_size_in_bytes = 0;
if (backward_data_workspace)
cudaFree(backward_data_workspace);
backward_data_workspace = nullptr;
backward_data_algo = 0;
backward_data_workspace_size_in_bytes = 0;
if (backward_filters_workspace)
cudaFree(backward_filters_workspace);
backward_filters_workspace = nullptr;
backward_filters_algo = 0;
backward_filters_workspace_size_in_bytes = 0;
stride_y = 0; stride_y = 0;
stride_x = 0; stride_x = 0;
padding_y = 0; padding_y = 0;
...@@ -762,6 +786,16 @@ namespace dlib ...@@ -762,6 +786,16 @@ namespace dlib
filters_k = 0; filters_k = 0;
filters_nr = 0; filters_nr = 0;
filters_nc = 0; filters_nc = 0;
forward_algo = 0;
backward_data_algo = 0;
backward_filters_algo = 0;
forward_workspace_size_in_bytes = 0;
backward_data_workspace_size_in_bytes = 0;
backward_filters_workspace_size_in_bytes = 0;
workspace.reset();
} }
void tensor_conv:: void tensor_conv::
...@@ -872,8 +906,6 @@ namespace dlib ...@@ -872,8 +906,6 @@ namespace dlib
descriptor(dest_desc), descriptor(dest_desc),
forward_best_algo, forward_best_algo,
&forward_workspace_size_in_bytes)); &forward_workspace_size_in_bytes));
CHECK_CUDA(cudaMalloc(&forward_workspace, forward_workspace_size_in_bytes));
// Pick which backward data algorithm we will use and allocate the // Pick which backward data algorithm we will use and allocate the
// necessary workspace buffer. // necessary workspace buffer.
...@@ -888,6 +920,7 @@ namespace dlib ...@@ -888,6 +920,7 @@ namespace dlib
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
&backward_data_best_algo)); &backward_data_best_algo));
backward_data_algo = backward_data_best_algo; backward_data_algo = backward_data_best_algo;
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize( CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(
context(), context(),
(const cudnnFilterDescriptor_t)filter_handle, (const cudnnFilterDescriptor_t)filter_handle,
...@@ -896,8 +929,6 @@ namespace dlib ...@@ -896,8 +929,6 @@ namespace dlib
descriptor(data), descriptor(data),
backward_data_best_algo, backward_data_best_algo,
&backward_data_workspace_size_in_bytes)); &backward_data_workspace_size_in_bytes));
CHECK_CUDA(cudaMalloc(&backward_data_workspace, backward_data_workspace_size_in_bytes));
// Pick which backward filters algorithm we will use and allocate the // Pick which backward filters algorithm we will use and allocate the
// necessary workspace buffer. // necessary workspace buffer.
...@@ -934,7 +965,8 @@ namespace dlib ...@@ -934,7 +965,8 @@ namespace dlib
(const cudnnFilterDescriptor_t)filter_handle, (const cudnnFilterDescriptor_t)filter_handle,
backward_filters_best_algo, backward_filters_best_algo,
&backward_filters_workspace_size_in_bytes)); &backward_filters_workspace_size_in_bytes));
CHECK_CUDA(cudaMalloc(&backward_filters_workspace, backward_filters_workspace_size_in_bytes));
workspace = device_global_buffer();
} }
catch(...) catch(...)
{ {
...@@ -997,6 +1029,7 @@ namespace dlib ...@@ -997,6 +1029,7 @@ namespace dlib
const float alpha = 1; const float alpha = 1;
const float beta = add_to_output ? 1 : 0; const float beta = add_to_output ? 1 : 0;
CHECK_CUDNN(cudnnConvolutionForward( CHECK_CUDNN(cudnnConvolutionForward(
context(), context(),
&alpha, &alpha,
...@@ -1006,7 +1039,7 @@ namespace dlib ...@@ -1006,7 +1039,7 @@ namespace dlib
filters.device(), filters.device(),
(const cudnnConvolutionDescriptor_t)conv_handle, (const cudnnConvolutionDescriptor_t)conv_handle,
(cudnnConvolutionFwdAlgo_t)forward_algo, (cudnnConvolutionFwdAlgo_t)forward_algo,
forward_workspace, workspace->get(forward_workspace_size_in_bytes),
forward_workspace_size_in_bytes, forward_workspace_size_in_bytes,
&beta, &beta,
descriptor(output), descriptor(output),
...@@ -1032,7 +1065,7 @@ namespace dlib ...@@ -1032,7 +1065,7 @@ namespace dlib
gradient_input.device(), gradient_input.device(),
(const cudnnConvolutionDescriptor_t)conv_handle, (const cudnnConvolutionDescriptor_t)conv_handle,
(cudnnConvolutionBwdDataAlgo_t)backward_data_algo, (cudnnConvolutionBwdDataAlgo_t)backward_data_algo,
backward_data_workspace, workspace->get(backward_data_workspace_size_in_bytes),
backward_data_workspace_size_in_bytes, backward_data_workspace_size_in_bytes,
&beta, &beta,
descriptor(data_gradient), descriptor(data_gradient),
...@@ -1057,7 +1090,7 @@ namespace dlib ...@@ -1057,7 +1090,7 @@ namespace dlib
gradient_input.device(), gradient_input.device(),
(const cudnnConvolutionDescriptor_t)conv_handle, (const cudnnConvolutionDescriptor_t)conv_handle,
(cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo, (cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo,
backward_filters_workspace, workspace->get(backward_filters_workspace_size_in_bytes),
backward_filters_workspace_size_in_bytes, backward_filters_workspace_size_in_bytes,
&beta, &beta,
(const cudnnFilterDescriptor_t)filter_handle, (const cudnnFilterDescriptor_t)filter_handle,
...@@ -1482,7 +1515,6 @@ namespace dlib ...@@ -1482,7 +1515,6 @@ namespace dlib
} }
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
} }
} }
......
// Copyright (C) 2015 Davis E. King (davis@dlib.net) // Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license. // License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuDNN_H_ #ifndef DLIB_DNN_CuDNN_H_
#define DLIB_DNN_CuDNN_H_ #define DLIB_DNN_CuDNN_H_
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#ifdef DLIB_USE_CUDA #ifdef DLIB_USE_CUDA
#include "cuda_errors.h" #include "cuda_errors.h"
#include <memory>
#include "cuda_data_ptr.h"
namespace dlib namespace dlib
{ {
...@@ -260,16 +262,13 @@ namespace dlib ...@@ -260,16 +262,13 @@ namespace dlib
int out_nc; int out_nc;
int forward_algo; int forward_algo;
size_t forward_workspace_size_in_bytes;
void* forward_workspace;
int backward_data_algo; int backward_data_algo;
size_t backward_data_workspace_size_in_bytes;
void* backward_data_workspace;
int backward_filters_algo; int backward_filters_algo;
size_t forward_workspace_size_in_bytes;
size_t backward_data_workspace_size_in_bytes;
size_t backward_filters_workspace_size_in_bytes; size_t backward_filters_workspace_size_in_bytes;
void* backward_filters_workspace; std::shared_ptr<resizable_cuda_buffer> workspace;
}; };
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
...@@ -490,6 +489,8 @@ namespace dlib ...@@ -490,6 +489,8 @@ namespace dlib
is_same_object(grad, gradient_input)==true is_same_object(grad, gradient_input)==true
!*/ !*/
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment