Commit a92477b9 authored by Davis King's avatar Davis King

Improved cuda_data_ptr.

parent bb8e0bc8
......@@ -37,28 +37,52 @@ namespace dlib
void memcpy(
void* dest,
const cuda_data_void_ptr& src
const cuda_data_void_ptr& src,
const size_t num
)
{
DLIB_ASSERT(num <= src.size());
if (src.size() != 0)
{
CHECK_CUDA(cudaMemcpy(dest, src.data(), src.size(), cudaMemcpyDefault));
CHECK_CUDA(cudaMemcpy(dest, src.data(), num, cudaMemcpyDefault));
}
}
// ------------------------------------------------------------------------------------
void memcpy(
cuda_data_void_ptr& dest,
const void* src
void* dest,
const cuda_data_void_ptr& src
)
{
memcpy(dest, src, src.size());
}
// ------------------------------------------------------------------------------------
void memcpy(
cuda_data_void_ptr dest,
const void* src,
const size_t num
)
{
DLIB_ASSERT(num <= dest.size());
if (dest.size() != 0)
{
CHECK_CUDA(cudaMemcpy(dest.data(), src, dest.size(), cudaMemcpyDefault));
CHECK_CUDA(cudaMemcpy(dest.data(), src, num, cudaMemcpyDefault));
}
}
// ------------------------------------------------------------------------------------
void memcpy(
cuda_data_void_ptr dest,
const void* src
)
{
memcpy(dest,src,dest.size());
}
// ------------------------------------------------------------------------------------
class cudnn_device_buffer
......
......@@ -7,6 +7,7 @@
#include <memory>
#include <vector>
#include "../assert.h"
namespace dlib
{
......@@ -45,12 +46,29 @@ namespace dlib
- returns the length of this buffer, in bytes.
!*/
cuda_data_void_ptr operator+ (size_t offset) const
/*!
requires
- offset < size()
ensures
- returns a pointer that is offset by the given amount.
!*/
{
DLIB_CASSERT(offset < num);
cuda_data_void_ptr temp;
temp.num = num-offset;
temp.pdata = std::shared_ptr<void>(pdata, ((char*)pdata.get())+offset);
return temp;
}
private:
size_t num = 0;
std::shared_ptr<void> pdata;
};
inline cuda_data_void_ptr operator+(size_t offset, const cuda_data_void_ptr& rhs) { return rhs+offset; }
// ------------------------------------------------------------------------------------
void memcpy(
......@@ -62,12 +80,27 @@ namespace dlib
- dest == a pointer to at least src.size() bytes on the host machine.
ensures
- copies the GPU data from src into dest.
- This routine is equivalent to performing: memcpy(dest,src,src.size())
!*/
void memcpy(
void* dest,
const cuda_data_void_ptr& src,
const size_t num
);
/*!
requires
- dest == a pointer to at least num bytes on the host machine.
- num <= src.size()
ensures
- copies the GPU data from src into dest. Copies only the first num bytes
of src to dest.
!*/
// ------------------------------------------------------------------------------------
void memcpy(
cuda_data_void_ptr& dest,
cuda_data_void_ptr dest,
const void* src
);
/*!
......@@ -75,6 +108,21 @@ namespace dlib
- dest == a pointer to at least src.size() bytes on the host machine.
ensures
- copies the host data from src to the GPU memory buffer dest.
- This routine is equivalent to performing: memcpy(dest,src,dest.size())
!*/
void memcpy(
cuda_data_void_ptr dest,
const void* src,
const size_t num
);
/*!
requires
- dest == a pointer to at least num bytes on the host machine.
- num <= dest.size()
ensures
- copies the host data from src to the GPU memory buffer dest. Copies only
the first num bytes of src to dest.
!*/
// ------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment