Commit 7d8d7fe4 authored by Davis King's avatar Davis King

Added cuda_data, a templated container for GPU memory.

parent c1bc27c5
......@@ -11,6 +11,10 @@
#include <cuda_runtime.h>
#include <sstream>
#include <iostream>
#include <memory>
#include <vector>
#include <type_traits>
// Check the return value of a call to the CUDA runtime for an error condition.
......@@ -26,6 +30,91 @@ do{
} \
}while(false)
// ----------------------------------------------------------------------------------------
namespace dlib
{
namespace cuda
{
template <typename T>
class cuda_data
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a block of memory on a CUDA device.
!*/
public:
static_assert(std::is_standard_layout<T>::value, "You can only create basic standard layout types on the GPU");
cuda_data() = default;
cuda_data(size_t n) : num(n)
/*!
ensures
- This object will allocate a device memory buffer of n T objects.
- #size() == n
!*/
{
if (n == 0)
return;
T* data = nullptr;
CHECK_CUDA(cudaMalloc(&data, n*sizeof(T)));
pdata.reset((T*)data, [](T* ptr){
auto err = cudaFree(ptr);
if(err!=cudaSuccess)
std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(err) << std::endl;
});
}
T* data() { return pdata.get(); }
const T* data() const { return pdata.get(); }
operator T*() { return pdata.get(); }
operator const T*() const { return pdata.get(); }
void reset() { pdata.reset(); }
size_t size() const { return num; }
private:
size_t num = 0;
std::shared_ptr<T> pdata;
};
template <typename T>
void memcpy(
std::vector<T>& dest,
const cuda_data<T>& src
)
{
dest.resize(src.size());
if (src.size() != 0)
{
CHECK_CUDA(cudaMemcpy(dest.data(), src.data(), src.size()*sizeof(T), cudaMemcpyDefault));
}
}
template <typename T>
void memcpy(
cuda_data<T>& dest,
const std::vector<T>& src
)
{
if (dest.size() != src.size())
dest = cuda_data<T>(src.size());
if (src.size() != 0)
{
CHECK_CUDA(cudaMemcpy(dest.data(), src.data(), src.size()*sizeof(T), cudaMemcpyDefault));
}
}
}
}
// ----------------------------------------------------------------------------------------
#ifdef __CUDACC__
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment