Commit ccb148b4 authored by Davis King's avatar Davis King

Cleaned up cuda error handling code

parent dbbce825
...@@ -9,27 +9,36 @@ ...@@ -9,27 +9,36 @@
#include <cublas_v2.h> #include <cublas_v2.h>
namespace dlib static const char* cublas_get_error_string(cublasStatus_t s)
{ {
namespace cuda switch(s)
{ {
case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUDA Runtime API initialization failed.";
case CUBLAS_STATUS_ALLOC_FAILED:
return "CUDA Resources could not be allocated.";
default:
return "A call to cuBLAS failed";
}
}
// ---------------------------------------------------------------------------------------- // Check the return value of a call to the cuBLAS runtime for an error condition.
#define CHECK_CUBLAS(call) \
{ \
const cublasStatus_t error = call; \
if (error != CUBLAS_STATUS_SUCCESS) \
{ \
std::ostringstream sout; \
sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
sout << "code: " << error << ", reason: " << cublas_get_error_string(error);\
throw dlib::cublas_error(sout.str()); \
} \
}
// TODO, make into a macro that prints more information like the line number, etc. namespace dlib
static void check(cublasStatus_t s) {
{ namespace cuda
switch(s) {
{
case CUBLAS_STATUS_SUCCESS: return;
case CUBLAS_STATUS_NOT_INITIALIZED:
throw cublas_error("CUDA Runtime API initialization failed.");
case CUBLAS_STATUS_ALLOC_FAILED:
throw cublas_error("CUDA Resources could not be allocated.");
default:
throw cublas_error("A call to cuBLAS failed");
}
}
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
...@@ -42,7 +51,7 @@ namespace dlib ...@@ -42,7 +51,7 @@ namespace dlib
cublas_context() cublas_context()
{ {
check(cublasCreate(&handle)); CHECK_CUBLAS(cublasCreate(&handle));
} }
~cublas_context() ~cublas_context()
{ {
...@@ -117,7 +126,7 @@ namespace dlib ...@@ -117,7 +126,7 @@ namespace dlib
} }
const int k = trans_rhs ? rhs_nc : rhs_nr; const int k = trans_rhs ? rhs_nc : rhs_nr;
check(cublasSgemm(context(), CHECK_CUBLAS(cublasSgemm(context(),
transb, transb,
transa, transa,
dest_nc, dest_nr, k, dest_nc, dest_nr, k,
......
...@@ -6,20 +6,13 @@ ...@@ -6,20 +6,13 @@
#ifdef DLIB_USE_CUDA #ifdef DLIB_USE_CUDA
#include "tensor.h" #include "tensor.h"
#include "../error.h" #include "cuda_errors.h"
namespace dlib namespace dlib
{ {
namespace cuda namespace cuda
{ {
// -----------------------------------------------------------------------------------
struct cublas_error : public error
{
cublas_error(const std::string& message): error(message) {}
};
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void gemm ( void gemm (
......
...@@ -30,6 +30,28 @@ namespace dlib ...@@ -30,6 +30,28 @@ namespace dlib
cudnn_error(const std::string& message): cuda_error(message) {} cudnn_error(const std::string& message): cuda_error(message) {}
}; };
struct curand_error : public cuda_error
{
/*!
WHAT THIS OBJECT REPRESENTS
This is the exception thrown if any calls to the NVIDIA cuRAND library
returns an error.
!*/
curand_error(const std::string& message): cuda_error(message) {}
};
struct cublas_error : public cuda_error
{
/*!
WHAT THIS OBJECT REPRESENTS
This is the exception thrown if any calls to the NVIDIA cuBLAS library
returns an error.
!*/
cublas_error(const std::string& message): cuda_error(message) {}
};
} }
......
This diff is collapsed.
...@@ -9,27 +9,36 @@ ...@@ -9,27 +9,36 @@
#include <curand.h> #include <curand.h>
#include "../string.h" #include "../string.h"
namespace dlib static const char* curand_get_error_string(curandStatus_t s)
{ {
namespace cuda switch(s)
{ {
case CURAND_STATUS_NOT_INITIALIZED:
return "CUDA Runtime API initialization failed.";
case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
return "The requested length must be a multiple of two.";
default:
return "A call to cuRAND failed";
}
}
// ---------------------------------------------------------------------------------------- // Check the return value of a call to the cuDNN runtime for an error condition.
#define CHECK_CURAND(call) \
{ \
const curandStatus_t error = call; \
if (error != CURAND_STATUS_SUCCESS) \
{ \
std::ostringstream sout; \
sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
sout << "code: " << error << ", reason: " << curand_get_error_string(error);\
throw dlib::curand_error(sout.str()); \
} \
}
// TODO, make into a macro that prints more information like the line number, etc. namespace dlib
static void check(curandStatus_t s) {
{ namespace cuda
switch(s) {
{
case CURAND_STATUS_SUCCESS: return;
case CURAND_STATUS_NOT_INITIALIZED:
throw curand_error("CUDA Runtime API initialization failed.");
case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
throw curand_error("The requested length must be a multiple of two.");
default:
throw curand_error("A call to cuRAND failed: " + cast_to_string(s));
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -39,10 +48,10 @@ namespace dlib ...@@ -39,10 +48,10 @@ namespace dlib
) : handle(nullptr) ) : handle(nullptr)
{ {
curandGenerator_t gen; curandGenerator_t gen;
check(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT)); CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
handle = gen; handle = gen;
check(curandSetPseudoRandomGeneratorSeed(gen, seed)); CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, seed));
} }
curand_generator:: curand_generator::
...@@ -64,7 +73,7 @@ namespace dlib ...@@ -64,7 +73,7 @@ namespace dlib
if (data.size() == 0) if (data.size() == 0)
return; return;
check(curandGenerateNormal((curandGenerator_t)handle, CHECK_CURAND(curandGenerateNormal((curandGenerator_t)handle,
data.device(), data.device(),
data.size(), data.size(),
mean, mean,
...@@ -79,7 +88,7 @@ namespace dlib ...@@ -79,7 +88,7 @@ namespace dlib
if (data.size() == 0) if (data.size() == 0)
return; return;
check(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size())); CHECK_CURAND(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size()));
} }
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#ifdef DLIB_USE_CUDA #ifdef DLIB_USE_CUDA
#include "tensor.h" #include "tensor.h"
#include "../error.h" #include "cuda_errors.h"
namespace dlib namespace dlib
{ {
...@@ -15,13 +15,6 @@ namespace dlib ...@@ -15,13 +15,6 @@ namespace dlib
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
struct curand_error : public error
{
curand_error(const std::string& message): error(message) {}
};
// ----------------------------------------------------------------------------------------
class curand_generator class curand_generator
{ {
public: public:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment