Commit e1b2c950 authored by Davis King's avatar Davis King

Made the tensor contain a cudnn tensor descriptor and also added a few other

minor refinements.
parent 94979c79
...@@ -24,8 +24,12 @@ namespace dlib ...@@ -24,8 +24,12 @@ namespace dlib
class cublas_context class cublas_context
{ {
public: public:
// not copyable
cublas_context(const cublas_context&) = delete; cublas_context(const cublas_context&) = delete;
cublas_context& operator=(const cublas_context&) = delete; cublas_context& operator=(const cublas_context&) = delete;
// but is movable
cublas_context(const cublas_context&&) = default;
cublas_context& operator=(const cublas_context&&) = default;
cublas_context() cublas_context()
{ {
...@@ -36,6 +40,10 @@ namespace dlib ...@@ -36,6 +40,10 @@ namespace dlib
{ {
// TODO // TODO
} }
const void* get_handle (
) const { return handle; }
private: private:
void* handle; void* handle;
......
...@@ -108,8 +108,12 @@ namespace dlib ...@@ -108,8 +108,12 @@ namespace dlib
!*/ !*/
public: public:
// not copyable
dropout(const dropout&) = delete; dropout(const dropout&) = delete;
dropout& operator=(const dropout&) = delete; dropout& operator=(const dropout&) = delete;
// but is movable
dropout(const dropout&&) = default;
dropout& operator=(const dropout&&) = default;
dropout(float drop_rate = 0.5); dropout(float drop_rate = 0.5);
dropout(float drop_rate, int seed); dropout(float drop_rate, int seed);
......
...@@ -9,6 +9,9 @@ ...@@ -9,6 +9,9 @@
namespace dlib namespace dlib
{ {
class tensor;
class resizable_tensor;
namespace cuda namespace cuda
{ {
...@@ -24,8 +27,12 @@ namespace dlib ...@@ -24,8 +27,12 @@ namespace dlib
class cudnn_context class cudnn_context
{ {
public: public:
// not copyable
cudnn_context(const cudnn_context&) = delete; cudnn_context(const cudnn_context&) = delete;
cudnn_context& operator=(const cudnn_context&) = delete; cudnn_context& operator=(const cudnn_context&) = delete;
// but is movable
cudnn_context(cudnn_context&&) = default;
cudnn_context& operator=(cudnn_context&&) = default;
cudnn_context() cudnn_context()
{ {
...@@ -38,6 +45,10 @@ namespace dlib ...@@ -38,6 +45,10 @@ namespace dlib
// TODO // TODO
// cudnnDestroy() // cudnnDestroy()
} }
const void* get_handle (
) const { return handle; }
private: private:
void* handle; void* handle;
...@@ -53,8 +64,12 @@ namespace dlib ...@@ -53,8 +64,12 @@ namespace dlib
!*/ !*/
public: public:
// not copyable
tensor_descriptor(const tensor_descriptor&) = delete; tensor_descriptor(const tensor_descriptor&) = delete;
tensor_descriptor& operator=(const tensor_descriptor&) = delete; tensor_descriptor& operator=(const tensor_descriptor&) = delete;
// but is movable
tensor_descriptor(tensor_descriptor&&) = default;
tensor_descriptor& operator=(tensor_descriptor&&) = default;
tensor_descriptor() tensor_descriptor()
{ {
...@@ -69,6 +84,23 @@ namespace dlib ...@@ -69,6 +84,23 @@ namespace dlib
// cudnnDestroyTensorDescriptor() // cudnnDestroyTensorDescriptor()
} }
void set_size(
int n,
int nr,
int nc,
int k
);
void get_size (
int& n,
int& nr,
int& nc,
int& k
) const;
const void* get_handle (
) const { return handle; }
private: private:
void* handle; void* handle;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <memory> #include <memory>
#include <cstring> #include <cstring>
#include "../matrix.h" #include "../matrix.h"
#include "cudnn.h"
namespace dlib namespace dlib
{ {
...@@ -246,6 +247,11 @@ namespace dlib ...@@ -246,6 +247,11 @@ namespace dlib
} }
#ifdef DLIB_USE_CUDA
const cuda::tensor_descriptor& get_cudnn_tensor_descriptor (
) const { return cudnn_descriptor; }
#endif
protected: protected:
...@@ -258,6 +264,10 @@ namespace dlib ...@@ -258,6 +264,10 @@ namespace dlib
m_k = item.m_k; m_k = item.m_k;
data.set_size(item.data.size()); data.set_size(item.data.size());
std::memcpy(data.host(), item.data.host(), data.size()*sizeof(float)); std::memcpy(data.host(), item.data.host(), data.size()*sizeof(float));
#ifdef DLIB_USE_CUDA
cudnn_descriptor.set_size(m_n,m_nr,m_nc,m_k);
#endif
return *this; return *this;
} }
...@@ -277,6 +287,9 @@ namespace dlib ...@@ -277,6 +287,9 @@ namespace dlib
long m_nc; long m_nc;
long m_k; long m_k;
gpu_data data; gpu_data data;
#ifdef DLIB_USE_CUDA
cuda::tensor_descriptor cudnn_descriptor;
#endif
}; };
tensor::~tensor() tensor::~tensor()
...@@ -431,6 +444,10 @@ namespace dlib ...@@ -431,6 +444,10 @@ namespace dlib
m_nc = nc_; m_nc = nc_;
m_k = k_; m_k = k_;
data.set_size(m_n*m_nr*m_nc*m_k); data.set_size(m_n*m_nr*m_nc*m_k);
#ifdef DLIB_USE_CUDA
cudnn_descriptor.set_size(m_n,m_nr,m_nc,m_k);
#endif
} }
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment