Commit e1b2c950 authored by Davis King's avatar Davis King

Made the tensor contain a cudnn tensor descriptor and also added a few other

minor refinements.
parent 94979c79
......@@ -24,8 +24,12 @@ namespace dlib
class cublas_context
{
public:
// not copyable
cublas_context(const cublas_context&) = delete;
cublas_context& operator=(const cublas_context&) = delete;
// but is movable
cublas_context(const cublas_context&&) = default;
cublas_context& operator=(const cublas_context&&) = default;
cublas_context()
{
......@@ -36,6 +40,10 @@ namespace dlib
{
// TODO
}
const void* get_handle (
) const { return handle; }
private:
void* handle;
......
......@@ -108,8 +108,12 @@ namespace dlib
!*/
public:
// not copyable
dropout(const dropout&) = delete;
dropout& operator=(const dropout&) = delete;
// but is movable
dropout(const dropout&&) = default;
dropout& operator=(const dropout&&) = default;
dropout(float drop_rate = 0.5);
dropout(float drop_rate, int seed);
......
......@@ -9,6 +9,9 @@
namespace dlib
{
class tensor;
class resizable_tensor;
namespace cuda
{
......@@ -24,8 +27,12 @@ namespace dlib
class cudnn_context
{
public:
// not copyable
cudnn_context(const cudnn_context&) = delete;
cudnn_context& operator=(const cudnn_context&) = delete;
// but is movable
cudnn_context(cudnn_context&&) = default;
cudnn_context& operator=(cudnn_context&&) = default;
cudnn_context()
{
......@@ -38,6 +45,10 @@ namespace dlib
// TODO
// cudnnDestroy()
}
const void* get_handle (
) const { return handle; }
private:
void* handle;
......@@ -53,8 +64,12 @@ namespace dlib
!*/
public:
// not copyable
tensor_descriptor(const tensor_descriptor&) = delete;
tensor_descriptor& operator=(const tensor_descriptor&) = delete;
// but is movable
tensor_descriptor(tensor_descriptor&&) = default;
tensor_descriptor& operator=(tensor_descriptor&&) = default;
tensor_descriptor()
{
......@@ -69,6 +84,23 @@ namespace dlib
// cudnnDestroyTensorDescriptor()
}
void set_size(
int n,
int nr,
int nc,
int k
);
void get_size (
int& n,
int& nr,
int& nc,
int& k
) const;
const void* get_handle (
) const { return handle; }
private:
void* handle;
......
......@@ -6,6 +6,7 @@
#include <memory>
#include <cstring>
#include "../matrix.h"
#include "cudnn.h"
namespace dlib
{
......@@ -246,6 +247,11 @@ namespace dlib
}
#ifdef DLIB_USE_CUDA
const cuda::tensor_descriptor& get_cudnn_tensor_descriptor (
) const { return cudnn_descriptor; }
#endif
protected:
......@@ -258,6 +264,10 @@ namespace dlib
m_k = item.m_k;
data.set_size(item.data.size());
std::memcpy(data.host(), item.data.host(), data.size()*sizeof(float));
#ifdef DLIB_USE_CUDA
cudnn_descriptor.set_size(m_n,m_nr,m_nc,m_k);
#endif
return *this;
}
......@@ -277,6 +287,9 @@ namespace dlib
long m_nc;
long m_k;
gpu_data data;
#ifdef DLIB_USE_CUDA
cuda::tensor_descriptor cudnn_descriptor;
#endif
};
tensor::~tensor()
......@@ -431,6 +444,10 @@ namespace dlib
m_nc = nc_;
m_k = k_;
data.set_size(m_n*m_nr*m_nc*m_k);
#ifdef DLIB_USE_CUDA
cudnn_descriptor.set_size(m_n,m_nr,m_nc,m_k);
#endif
}
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment