Commit cb2f9de6 authored by Davis King's avatar Davis King

Added part of the tensor_tools implementations

parent 76433858
...@@ -138,6 +138,7 @@ if (NOT TARGET dlib) ...@@ -138,6 +138,7 @@ if (NOT TARGET dlib)
if (COMPILER_CAN_DO_CPP_11) if (COMPILER_CAN_DO_CPP_11)
set(source_files ${source_files} set(source_files ${source_files}
dnn/cpu_dlib.cpp dnn/cpu_dlib.cpp
dnn/tensor_tools.cpp
) )
endif() endif()
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
// Stuff that requires C++11 // Stuff that requires C++11
#if __cplusplus >= 201103 #if __cplusplus >= 201103
#include "../dnn/cpu_dlib.cpp" #include "../dnn/cpu_dlib.cpp"
#include "../dnn/tensor_tools.cpp"
#endif #endif
#ifndef DLIB_ISO_CPP_ONLY #ifndef DLIB_ISO_CPP_ONLY
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "dnn/solvers.h" #include "dnn/solvers.h"
#include "dnn/trainer.h" #include "dnn/trainer.h"
#include "dnn/cpu_dlib.h" #include "dnn/cpu_dlib.h"
#include "dnn/tensor_tools.h"
#endif // DLIB_DNn_ #endif // DLIB_DNn_
......
...@@ -91,6 +91,9 @@ namespace dlib ...@@ -91,6 +91,9 @@ namespace dlib
} }
} }
#ifdef NO_MAKEFILE
#include "cpu_dlib.cpp"
#endif
#endif // DLIB_DNN_CPU_H_ #endif // DLIB_DNN_CPU_H_
......
This diff is collapsed.
...@@ -4,8 +4,12 @@ ...@@ -4,8 +4,12 @@
#define DLIB_TeNSOR_TOOLS_H_ #define DLIB_TeNSOR_TOOLS_H_
#include "tensor.h" #include "tensor.h"
#include "cudnn_dlibapi.h"
#include "cublas_dlibapi.h"
#include "curand_dlibapi.h"
#include "../rand.h"
namespace dlib namespace dlib { namespace tt
{ {
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -37,7 +41,51 @@ namespace dlib ...@@ -37,7 +41,51 @@ namespace dlib
class tensor_rand class tensor_rand
{ {
/*!
WHAT THIS OBJECT REPRESENTS
This is a tool for filling a tensor with random numbers.
Note that the sequence of random numbers output by this object is different
when dlib is compiled with DLIB_USE_CUDA. So you should not write code
that depends on any specific sequence of numbers coming out of a
tensor_rand.
!*/
public: public:
// not copyable
tensor_rand(const tensor_rand&) = delete;
tensor_rand& operator=(const tensor_rand&) = delete;
tensor_rand() : tensor_rand(0) {}
tensor_rand(unsigned long long seed);
void fill_gaussian (
tensor& data,
float mean,
float stddev
);
/*!
requires
- data.size()%2 == 0
ensures
- Fills data with random numbers drawn from a Gaussian distribution
with the given mean and standard deviation.
!*/
void fill_uniform (
tensor& data
);
/*!
ensures
- Fills data with uniform random numbers in the range (0.0, 1.0].
!*/
#ifdef DLIB_USE_CUDA
cuda::curand_generator rnd;
#else
dlib::rand rnd;
#endif
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -278,13 +326,13 @@ namespace dlib ...@@ -278,13 +326,13 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class conv class tensor_conv
{ {
public: public:
conv(const conv&) = delete; tensor_conv(const tensor_conv&) = delete;
conv& operator=(const conv&) = delete; tensor_conv& operator=(const tensor_conv&) = delete;
conv(); tensor_conv();
void clear( void clear(
); );
...@@ -302,9 +350,6 @@ namespace dlib ...@@ -302,9 +350,6 @@ namespace dlib
- stride_x > 0 - stride_x > 0
!*/ !*/
~conv (
);
void operator() ( void operator() (
resizable_tensor& output, resizable_tensor& output,
const tensor& data, const tensor& data,
...@@ -362,6 +407,11 @@ namespace dlib ...@@ -362,6 +407,11 @@ namespace dlib
!*/ !*/
private: private:
#ifdef DLIB_USE_CUDA
cuda::tensor_conv impl;
#else
// TODO
#endif
}; };
...@@ -379,9 +429,6 @@ namespace dlib ...@@ -379,9 +429,6 @@ namespace dlib
max_pool ( max_pool (
); );
~max_pool(
);
void clear( void clear(
); );
...@@ -429,6 +476,11 @@ namespace dlib ...@@ -429,6 +476,11 @@ namespace dlib
!*/ !*/
private: private:
#ifdef DLIB_USE_CUDA
cuda::max_pool impl;
#else
// TODO
#endif
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -564,8 +616,11 @@ namespace dlib ...@@ -564,8 +616,11 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }}
}
#ifdef NO_MAKEFILE
#include "tensor_tools.cpp"
#endif
#endif // DLIB_TeNSOR_TOOLS_H_ #endif // DLIB_TeNSOR_TOOLS_H_
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment