Commit ccb148b4 authored by Davis King

Cleaned up cuda error handling code

parent dbbce825
dlib/dnn/cublas_dlibapi.cpp
@@ -9,27 +9,36 @@
 #include <cublas_v2.h>
 
-namespace dlib
-{
-    namespace cuda
-    {
-
-    // ----------------------------------------------------------------------------------------
-
-        // TODO, make into a macro that prints more information like the line number, etc.
-        static void check(cublasStatus_t s)
-        {
-            switch(s)
-            {
-                case CUBLAS_STATUS_SUCCESS: return;
-                case CUBLAS_STATUS_NOT_INITIALIZED:
-                    throw cublas_error("CUDA Runtime API initialization failed.");
-                case CUBLAS_STATUS_ALLOC_FAILED:
-                    throw cublas_error("CUDA Resources could not be allocated.");
-                default:
-                    throw cublas_error("A call to cuBLAS failed");
-            }
-        }
-
+static const char* cublas_get_error_string(cublasStatus_t s)
+{
+    switch(s)
+    {
+        case CUBLAS_STATUS_NOT_INITIALIZED:
+            return "CUDA Runtime API initialization failed.";
+        case CUBLAS_STATUS_ALLOC_FAILED:
+            return "CUDA Resources could not be allocated.";
+        default:
+            return "A call to cuBLAS failed";
+    }
+}
+
+// Check the return value of a call to the cuBLAS runtime for an error condition.
+#define CHECK_CUBLAS(call)                                                     \
+{                                                                              \
+    const cublasStatus_t error = call;                                         \
+    if (error != CUBLAS_STATUS_SUCCESS)                                        \
+    {                                                                          \
+        std::ostringstream sout;                                               \
+        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+        sout << "code: " << error << ", reason: " << cublas_get_error_string(error);\
+        throw dlib::cublas_error(sout.str());                                  \
+    }                                                                          \
+}
+
+namespace dlib
+{
+    namespace cuda
+    {
+
     // -----------------------------------------------------------------------------------
@@ -42,7 +51,7 @@ namespace dlib
 
         cublas_context()
         {
-            check(cublasCreate(&handle));
+            CHECK_CUBLAS(cublasCreate(&handle));
         }
 
         ~cublas_context()
         {
@@ -117,7 +126,7 @@ namespace dlib
         }
 
         const int k = trans_rhs ? rhs_nc : rhs_nr;
-        check(cublasSgemm(context(),
+        CHECK_CUBLAS(cublasSgemm(context(),
                 transb,
                 transa,
                 dest_nc, dest_nr, k,
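The motivation for replacing the old check() function with a macro is visible above: because the macro body is expanded at the call site, #call, __FILE__, and __LINE__ resolve to the exact cuBLAS call, source file, and line that failed, none of which the old function could report. Below is a minimal self-contained sketch of the same pattern; the status enum, lookup function, and CHECK_FAKE name are stand-ins invented for illustration so the snippet compiles without the CUDA toolkit.

#include <iostream>
#include <sstream>
#include <stdexcept>

// Stand-ins for illustration only; the real code uses cublasStatus_t,
// cublas_get_error_string(), and dlib::cublas_error.
enum fake_status_t { FAKE_SUCCESS = 0, FAKE_ALLOC_FAILED = 1 };

static const char* fake_get_error_string(fake_status_t s)
{
    switch(s)
    {
        case FAKE_ALLOC_FAILED: return "Resources could not be allocated.";
        default:                return "A call failed";
    }
}

// Same shape as CHECK_CUBLAS: stringize the call with #call, capture
// __FILE__/__LINE__ at the call site, and throw on any non-success status.
#define CHECK_FAKE(call)                                                   \
{                                                                          \
    const fake_status_t error = call;                                      \
    if (error != FAKE_SUCCESS)                                             \
    {                                                                      \
        std::ostringstream sout;                                           \
        sout << "Error while calling " << #call << " in file "             \
             << __FILE__ << ":" << __LINE__ << ". ";                       \
        sout << "code: " << error << ", reason: "                          \
             << fake_get_error_string(error);                              \
        throw std::runtime_error(sout.str());                              \
    }                                                                      \
}

static fake_status_t fake_api_call() { return FAKE_ALLOC_FAILED; }

int main()
{
    try { CHECK_FAKE(fake_api_call()); }
    catch (const std::exception& e)
    {
        std::cerr << e.what() << "\n"; // names the call, file, and line
    }
}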
dlib/dnn/cublas_dlibapi.h
@@ -6,20 +6,13 @@
 #ifdef DLIB_USE_CUDA
 
 #include "tensor.h"
-#include "../error.h"
+#include "cuda_errors.h"
 
 namespace dlib
 {
     namespace cuda
     {
 
-    // -----------------------------------------------------------------------------------
-
-        struct cublas_error : public error
-        {
-            cublas_error(const std::string& message): error(message) {}
-        };
-
     // -----------------------------------------------------------------------------------
 
         void gemm (
dlib/dnn/cuda_errors.h
@@ -30,6 +30,28 @@ namespace dlib
         cudnn_error(const std::string& message): cuda_error(message) {}
     };
 
+    struct curand_error : public cuda_error
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This is the exception thrown if any call to the NVIDIA cuRAND library
+                returns an error.
+        !*/
+        curand_error(const std::string& message): cuda_error(message) {}
+    };
+
+    struct cublas_error : public cuda_error
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This is the exception thrown if any call to the NVIDIA cuBLAS library
+                returns an error.
+        !*/
+        cublas_error(const std::string& message): cuda_error(message) {}
+    };
+
 }
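With cublas_error and curand_error now rooted at dlib's common cuda_error type (alongside cudnn_error), callers can handle all CUDA-related failures with one handler or discriminate by library. A hedged usage sketch follows; the header path and the function f() are assumptions for illustration.

#include <iostream>
#include <dlib/dnn/cuda_errors.h> // assumed path to the header shown above

void f(); // hypothetical routine that calls into the cuBLAS/cuDNN/cuRAND wrappers

void call_f_and_report()
{
    try
    {
        f();
    }
    catch (const dlib::cublas_error& e)
    {
        std::cerr << "cuBLAS failure: " << e.what() << "\n"; // most specific first
    }
    catch (const dlib::cuda_error& e)
    {
        std::cerr << "other CUDA failure: " << e.what() << "\n"; // cudnn, curand, ...
    }
}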
dlib/dnn/cudnn_dlibapi.cpp
@@ -12,6 +12,34 @@
 #include <string>
 #include "cuda_utils.h"
 
+static const char* cudnn_get_error_string(cudnnStatus_t s)
+{
+    switch(s)
+    {
+        case CUDNN_STATUS_NOT_INITIALIZED:
+            return "CUDA Runtime API initialization failed.";
+        case CUDNN_STATUS_ALLOC_FAILED:
+            return "CUDA Resources could not be allocated.";
+        case CUDNN_STATUS_BAD_PARAM:
+            return "CUDNN_STATUS_BAD_PARAM";
+        default:
+            return "A call to cuDNN failed";
+    }
+}
+
+// Check the return value of a call to the cuDNN runtime for an error condition.
+#define CHECK_CUDNN(call)                                                      \
+{                                                                              \
+    const cudnnStatus_t error = call;                                          \
+    if (error != CUDNN_STATUS_SUCCESS)                                         \
+    {                                                                          \
+        std::ostringstream sout;                                               \
+        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+        sout << "code: " << error << ", reason: " << cudnn_get_error_string(error);\
+        throw dlib::cudnn_error(sout.str());                                   \
+    }                                                                          \
+}
+
 namespace dlib
 {
@@ -19,23 +47,6 @@ namespace dlib
     namespace cuda
     {
 
-        // TODO, make into a macro that prints more information like the line number, etc.
-        static void check(cudnnStatus_t s)
-        {
-            switch(s)
-            {
-                case CUDNN_STATUS_SUCCESS: return;
-                case CUDNN_STATUS_NOT_INITIALIZED:
-                    throw cudnn_error("CUDA Runtime API initialization failed.");
-                case CUDNN_STATUS_ALLOC_FAILED:
-                    throw cudnn_error("CUDA Resources could not be allocated.");
-                case CUDNN_STATUS_BAD_PARAM:
-                    throw cudnn_error("CUDNN_STATUS_BAD_PARAM");
-                default:
-                    throw cudnn_error("A call to cuDNN failed: " + std::string(cudnnGetErrorString(s)));
-            }
-        }
-
     // ------------------------------------------------------------------------------------
         static cudnnTensorDescriptor_t descriptor(const tensor& t)
@@ -58,7 +69,7 @@ namespace dlib
         cudnn_context()
         {
-            check(cudnnCreate(&handle));
+            CHECK_CUDNN(cudnnCreate(&handle));
         }
 
         ~cudnn_context()
@@ -112,10 +123,10 @@ namespace dlib
             else
             {
                 cudnnTensorDescriptor_t h;
-                check(cudnnCreateTensorDescriptor(&h));
+                CHECK_CUDNN(cudnnCreateTensorDescriptor(&h));
                 handle = h;
 
-                check(cudnnSetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
+                CHECK_CUDNN(cudnnSetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
                     CUDNN_TENSOR_NCHW,
                     CUDNN_DATA_FLOAT,
                     n,
@@ -137,7 +148,7 @@ namespace dlib
         {
             int nStride, cStride, hStride, wStride;
             cudnnDataType_t datatype;
-            check(cudnnGetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
+            CHECK_CUDNN(cudnnGetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
                 &datatype,
                 &n,
                 &k,
@@ -172,7 +183,7 @@ namespace dlib
                 (dest.nc()==src.nc() || src.nc()==1) &&
                 (dest.k()==src.k() || src.k()==1), "");
 
-            check(cudnnAddTensor_v3(context(),
+            CHECK_CUDNN(cudnnAddTensor_v3(context(),
                 &alpha,
                 descriptor(src),
                 src.device(),
@@ -188,7 +199,7 @@ namespace dlib
         {
             if (t.size() == 0)
                 return;
-            check(cudnnSetTensor(context(),
+            CHECK_CUDNN(cudnnSetTensor(context(),
                 descriptor(t),
                 t.device(),
                 &value));
@@ -201,7 +212,7 @@ namespace dlib
         {
             if (t.size() == 0)
                 return;
-            check(cudnnScaleTensor(context(),
+            CHECK_CUDNN(cudnnScaleTensor(context(),
                 descriptor(t),
                 t.device(),
                 &value));
@@ -222,7 +233,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnConvolutionBackwardBias(context(),
+            CHECK_CUDNN(cudnnConvolutionBackwardBias(context(),
                 &alpha,
                 descriptor(gradient_input),
                 gradient_input.device(),
@@ -304,16 +315,16 @@ namespace dlib
             stride_y = stride_y_;
             stride_x = stride_x_;
 
-            check(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
-            check(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
+            CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
+            CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
                 CUDNN_DATA_FLOAT,
                 filters.num_samples(),
                 filters.k(),
                 filters.nr(),
                 filters.nc()));
 
-            check(cudnnCreateConvolutionDescriptor((cudnnConvolutionDescriptor_t*)&conv_handle));
-            check(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
+            CHECK_CUDNN(cudnnCreateConvolutionDescriptor((cudnnConvolutionDescriptor_t*)&conv_handle));
+            CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
                 filters.nr()/2, // vertical padding
                 filters.nc()/2, // horizontal padding
                 stride_y,
@@ -321,7 +332,7 @@ namespace dlib
                 1, 1, // must be 1,1
                 CUDNN_CONVOLUTION)); // could also be CUDNN_CROSS_CORRELATION
 
-            check(cudnnGetConvolution2dForwardOutputDim(
+            CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(
                 (const cudnnConvolutionDescriptor_t)conv_handle,
                 descriptor(data),
                 (const cudnnFilterDescriptor_t)filter_handle,
@@ -336,7 +347,7 @@ namespace dlib
             // Pick which forward algorithm we will use and allocate the necessary
             // workspace buffer.
             cudnnConvolutionFwdAlgo_t forward_best_algo;
-            check(cudnnGetConvolutionForwardAlgorithm(
+            CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
                 context(),
                 descriptor(data),
                 (const cudnnFilterDescriptor_t)filter_handle,
@@ -346,7 +357,7 @@ namespace dlib
                 std::numeric_limits<size_t>::max(),
                 &forward_best_algo));
             forward_algo = forward_best_algo;
-            check(cudnnGetConvolutionForwardWorkspaceSize(
+            CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
                 context(),
                 descriptor(data),
                 (const cudnnFilterDescriptor_t)filter_handle,
@@ -360,7 +371,7 @@ namespace dlib
             // Pick which backward data algorithm we will use and allocate the
             // necessary workspace buffer.
             cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
-            check(cudnnGetConvolutionBackwardDataAlgorithm(
+            CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
                 context(),
                 (const cudnnFilterDescriptor_t)filter_handle,
                 descriptor(dest_desc),
@@ -370,7 +381,7 @@ namespace dlib
                 std::numeric_limits<size_t>::max(),
                 &backward_data_best_algo));
             backward_data_algo = backward_data_best_algo;
-            check(cudnnGetConvolutionBackwardDataWorkspaceSize(
+            CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(
                 context(),
                 (const cudnnFilterDescriptor_t)filter_handle,
                 descriptor(dest_desc),
@@ -384,7 +395,7 @@ namespace dlib
             // Pick which backward filters algorithm we will use and allocate the
             // necessary workspace buffer.
             cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
-            check(cudnnGetConvolutionBackwardFilterAlgorithm(
+            CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(
                 context(),
                 descriptor(data),
                 descriptor(dest_desc),
@@ -394,7 +405,7 @@ namespace dlib
                 std::numeric_limits<size_t>::max(),
                 &backward_filters_best_algo));
             backward_filters_algo = backward_filters_best_algo;
-            check(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+            CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(
                 context(),
                 descriptor(data),
                 descriptor(dest_desc),
@@ -434,7 +445,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnConvolutionForward(
+            CHECK_CUDNN(cudnnConvolutionForward(
                 context(),
                 &alpha,
                 descriptor(data),
@@ -460,7 +471,7 @@ namespace dlib
             const float beta = 1;
 
-            check(cudnnConvolutionBackwardData_v3(context(),
+            CHECK_CUDNN(cudnnConvolutionBackwardData_v3(context(),
                 &alpha,
                 (const cudnnFilterDescriptor_t)filter_handle,
                 filters.device(),
@@ -484,7 +495,7 @@ namespace dlib
         {
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnConvolutionBackwardFilter_v3(context(),
+            CHECK_CUDNN(cudnnConvolutionBackwardFilter_v3(context(),
                 &alpha,
                 descriptor(data),
                 data.device(),
@@ -535,10 +546,10 @@ namespace dlib
             stride_x = stride_x_;
             stride_y = stride_y_;
             cudnnPoolingDescriptor_t poolingDesc;
-            check(cudnnCreatePoolingDescriptor(&poolingDesc));
+            CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc));
             handle = poolingDesc;
 
-            check(cudnnSetPooling2dDescriptor(poolingDesc,
+            CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc,
                 CUDNN_POOLING_MAX,
                 window_height,
                 window_width,
@@ -559,7 +570,7 @@ namespace dlib
             int outC;
             int outH;
             int outW;
-            check(cudnnGetPooling2dForwardOutputDim((const cudnnPoolingDescriptor_t)handle,
+            CHECK_CUDNN(cudnnGetPooling2dForwardOutputDim((const cudnnPoolingDescriptor_t)handle,
                 descriptor(src),
                 &outN,
                 &outC,
@@ -574,7 +585,7 @@ namespace dlib
             DLIB_CASSERT(dest.nr() == src.nr()/stride_y,"");
             DLIB_CASSERT(dest.nc() == src.nc()/stride_x,"");
 
-            check(cudnnPoolingForward(context(),
+            CHECK_CUDNN(cudnnPoolingForward(context(),
                 (const cudnnPoolingDescriptor_t)handle,
                 &alpha,
                 descriptor(src),
@@ -596,7 +607,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnPoolingBackward(context(),
+            CHECK_CUDNN(cudnnPoolingBackward(context(),
                 (const cudnnPoolingDescriptor_t)handle,
                 &alpha,
                 descriptor(dest),
@@ -625,7 +636,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnSoftmaxForward(context(),
+            CHECK_CUDNN(cudnnSoftmaxForward(context(),
                 CUDNN_SOFTMAX_ACCURATE,
                 CUDNN_SOFTMAX_MODE_CHANNEL,
                 &alpha,
@@ -651,7 +662,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnSoftmaxBackward(context(),
+            CHECK_CUDNN(cudnnSoftmaxBackward(context(),
                 CUDNN_SOFTMAX_ACCURATE,
                 CUDNN_SOFTMAX_MODE_CHANNEL,
                 &alpha,
@@ -678,7 +689,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationForward(context(),
+            CHECK_CUDNN(cudnnActivationForward(context(),
                 CUDNN_ACTIVATION_SIGMOID,
                 &alpha,
                 descriptor(src),
@@ -702,7 +713,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationBackward(context(),
+            CHECK_CUDNN(cudnnActivationBackward(context(),
                 CUDNN_ACTIVATION_SIGMOID,
                 &alpha,
                 descriptor(dest),
@@ -729,7 +740,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationForward(context(),
+            CHECK_CUDNN(cudnnActivationForward(context(),
                 CUDNN_ACTIVATION_RELU,
                 &alpha,
                 descriptor(src),
@@ -753,7 +764,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationBackward(context(),
+            CHECK_CUDNN(cudnnActivationBackward(context(),
                 CUDNN_ACTIVATION_RELU,
                 &alpha,
                 descriptor(dest),
@@ -780,7 +791,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationForward(context(),
+            CHECK_CUDNN(cudnnActivationForward(context(),
                 CUDNN_ACTIVATION_TANH,
                 &alpha,
                 descriptor(src),
@@ -804,7 +815,7 @@ namespace dlib
 
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationBackward(context(),
+            CHECK_CUDNN(cudnnActivationBackward(context(),
                 CUDNN_ACTIVATION_TANH,
                 &alpha,
                 descriptor(dest),
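One sharp edge worth noting about statement macros written as bare brace blocks, like the CHECK_* macros in this commit: the trailing semicolon the caller naturally writes after the macro ends an unbraced if statement early and detaches any following else. A common hardening, shown here only as a hypothetical variant and not something this commit does, wraps the body in do { } while(0) so the expansion is a single statement. This fragment assumes the same file context as above (cudnn.h, <sstream>, and the dlib error types included).

// Hypothetical do/while(0) variant; with it, code such as
//   if (use_cudnn) CHECK_CUDNN_SAFE(cudnnCreate(&handle)); else fallback();
// parses as intended even without braces around the if branch.
#define CHECK_CUDNN_SAFE(call)                                                 \
do                                                                             \
{                                                                              \
    const cudnnStatus_t error = call;                                          \
    if (error != CUDNN_STATUS_SUCCESS)                                         \
    {                                                                          \
        std::ostringstream sout;                                               \
        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
        sout << "code: " << error << ", reason: " << cudnn_get_error_string(error);\
        throw dlib::cudnn_error(sout.str());                                   \
    }                                                                          \
} while(0)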
dlib/dnn/curand_dlibapi.cpp
@@ -9,27 +9,36 @@
 #include <curand.h>
 #include "../string.h"
 
-namespace dlib
-{
-    namespace cuda
-    {
-
-    // ----------------------------------------------------------------------------------------
-
-        // TODO, make into a macro that prints more information like the line number, etc.
-        static void check(curandStatus_t s)
-        {
-            switch(s)
-            {
-                case CURAND_STATUS_SUCCESS: return;
-                case CURAND_STATUS_NOT_INITIALIZED:
-                    throw curand_error("CUDA Runtime API initialization failed.");
-                case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
-                    throw curand_error("The requested length must be a multiple of two.");
-                default:
-                    throw curand_error("A call to cuRAND failed: " + cast_to_string(s));
-            }
-        }
-
+static const char* curand_get_error_string(curandStatus_t s)
+{
+    switch(s)
+    {
+        case CURAND_STATUS_NOT_INITIALIZED:
+            return "CUDA Runtime API initialization failed.";
+        case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
+            return "The requested length must be a multiple of two.";
+        default:
+            return "A call to cuRAND failed";
+    }
+}
+
+// Check the return value of a call to the cuRAND runtime for an error condition.
+#define CHECK_CURAND(call)                                                     \
+{                                                                              \
+    const curandStatus_t error = call;                                         \
+    if (error != CURAND_STATUS_SUCCESS)                                        \
+    {                                                                          \
+        std::ostringstream sout;                                               \
+        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+        sout << "code: " << error << ", reason: " << curand_get_error_string(error);\
+        throw dlib::curand_error(sout.str());                                  \
+    }                                                                          \
+}
+
+namespace dlib
+{
+    namespace cuda
+    {
+
 // ----------------------------------------------------------------------------------------
@@ -39,10 +48,10 @@ namespace dlib
         ) : handle(nullptr)
         {
             curandGenerator_t gen;
-            check(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
+            CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
             handle = gen;
 
-            check(curandSetPseudoRandomGeneratorSeed(gen, seed));
+            CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, seed));
         }
 
         curand_generator::
@@ -64,7 +73,7 @@ namespace dlib
             if (data.size() == 0)
                 return;
 
-            check(curandGenerateNormal((curandGenerator_t)handle,
+            CHECK_CURAND(curandGenerateNormal((curandGenerator_t)handle,
                 data.device(),
                 data.size(),
                 mean,
@@ -79,7 +88,7 @@ namespace dlib
             if (data.size() == 0)
                 return;
 
-            check(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size()));
+            CHECK_CURAND(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size()));
        }
 
 // -----------------------------------------------------------------------------------
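For context on the CURAND_STATUS_LENGTH_NOT_MULTIPLE branch above: cuRAND's normal generator produces values in pairs, so curandGenerateNormal rejects odd lengths, which is exactly what the new error string describes. Below is a hedged raw-cuRAND sketch of the calls these wrappers guard; it assumes a CUDA-capable build and the CHECK_CURAND macro from the file above (which throws dlib::curand_error).

#include <curand.h>
#include <cuda_runtime.h>

int main()
{
    curandGenerator_t gen;
    CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
    CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, 0));

    const size_t n = 1024; // curandGenerateNormal requires an even count
    float* dev = nullptr;
    cudaMalloc(&dev, n*sizeof(float));

    // An odd n would come back as CURAND_STATUS_LENGTH_NOT_MULTIPLE, which
    // the macro turns into a curand_error carrying the file, line, and the
    // "must be a multiple of two" message above.
    CHECK_CURAND(curandGenerateNormal(gen, dev, n, 0 /*mean*/, 1 /*stddev*/));

    cudaFree(dev);
    CHECK_CURAND(curandDestroyGenerator(gen));
    return 0;
}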
dlib/dnn/curand_dlibapi.h
@@ -6,7 +6,7 @@
 #ifdef DLIB_USE_CUDA
 
 #include "tensor.h"
-#include "../error.h"
+#include "cuda_errors.h"
 
 namespace dlib
 {
@@ -15,13 +15,6 @@ namespace dlib
 
     // -----------------------------------------------------------------------------------
 
-        struct curand_error : public error
-        {
-            curand_error(const std::string& message): error(message) {}
-        };
-
-    // ----------------------------------------------------------------------------------------
-
         class curand_generator
         {
         public: