Commit 141b384b authored by Davis King's avatar Davis King

Added binding to cuRAND

parent 6539ea67
...@@ -456,9 +456,14 @@ if (NOT TARGET dlib) ...@@ -456,9 +456,14 @@ if (NOT TARGET dlib)
dnn/cuda_dlib.cu dnn/cuda_dlib.cu
dnn/cudnn_dlibapi.cpp dnn/cudnn_dlibapi.cpp
dnn/cublas_dlibapi.cpp dnn/cublas_dlibapi.cpp
dnn/curand_dlibapi.cpp
dnn/gpu_data.cpp dnn/gpu_data.cpp
) )
set(dlib_needed_libraries ${dlib_needed_libraries} ${CUDA_CUBLAS_LIBRARIES} ${cudnn}) set(dlib_needed_libraries ${dlib_needed_libraries}
${CUDA_CUBLAS_LIBRARIES}
${cudnn}
${CUDA_curand_LIBRARY}
)
include_directories(${cudnn_include}) include_directories(${cudnn_include})
else() else()
set(DLIB_USE_CUDA OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE ) set(DLIB_USE_CUDA OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE )
......
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuRAND_CPP_
#define DLIB_DNN_CuRAND_CPP_
#ifdef DLIB_USE_CUDA
#include "curand_dlibapi.h"
#include <curand.h>
#include "../string.h"
namespace dlib
{
namespace cuda
{
// ----------------------------------------------------------------------------------------
// TODO, make into a macro that prints more information like the line number, etc.
static void check(curandStatus_t s)
{
switch(s)
{
case CURAND_STATUS_SUCCESS: return;
case CURAND_STATUS_NOT_INITIALIZED:
throw curand_error("CUDA Runtime API initialization failed.");
case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
throw curand_error("The requested length must be a multiple of two.");
default:
throw curand_error("A call to cuRAND failed: " + cast_to_string(s));
}
}
// ----------------------------------------------------------------------------------------
curand_generator::
curand_generator(
unsigned long long seed
) : handle(nullptr)
{
curandGenerator_t gen;
check(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
handle = gen;
check(curandSetPseudoRandomGeneratorSeed(gen, seed));
}
curand_generator::
~curand_generator()
{
if (handle)
{
curandDestroyGenerator((curandGenerator_t)handle);
}
}
void curand_generator::
fill_gaussian (
tensor& data,
float mean,
float stddev
)
{
if (data.size() == 0)
return;
check(curandGenerateNormal((curandGenerator_t)handle,
data.device(),
data.size(),
mean,
stddev));
}
void curand_generator::
fill_uniform (
tensor& data
)
{
if (data.size() == 0)
return;
check(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size()));
}
// -----------------------------------------------------------------------------------
}
}
#endif // DLIB_USE_CUDA
#endif // DLIB_DNN_CuRAND_CPP_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuRAND_H_
#define DLIB_DNN_CuRAND_H_
#ifdef DLIB_USE_CUDA
#include "tensor.h"
#include "../error.h"
namespace dlib
{
namespace cuda
{
// -----------------------------------------------------------------------------------
struct curand_error : public error
{
curand_error(const std::string& message): error(message) {}
};
// ----------------------------------------------------------------------------------------
class curand_generator
{
public:
// not copyable
curand_generator(const curand_generator&) = delete;
curand_generator& operator=(const curand_generator&) = delete;
curand_generator() : curand_generator(0) {}
curand_generator(unsigned long long seed);
~curand_generator();
void fill_gaussian (
tensor& data,
float mean,
float stddev
);
/*!
requires
- data.size()%2 == 0
ensures
- Fills data with random numbers drawn from a Gaussian distribution
with the given mean and standard deviation.
!*/
void fill_uniform (
tensor& data
);
/*!
ensures
- Fills data with uniform random numbers in the range (0.0, 1.0].
!*/
private:
void* handle;
};
// -----------------------------------------------------------------------------------
}
}
#endif // DLIB_USE_CUDA
#endif // DLIB_DNN_CuRAND_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment