Commit cb2f9de6 authored by Davis King's avatar Davis King

Added part of the tensor_tools implementations

parent 76433858
......@@ -138,6 +138,7 @@ if (NOT TARGET dlib)
if (COMPILER_CAN_DO_CPP_11)
set(source_files ${source_files}
dnn/cpu_dlib.cpp
dnn/tensor_tools.cpp
)
endif()
......
......@@ -21,6 +21,7 @@
// Stuff that requires C++11
#if __cplusplus >= 201103
#include "../dnn/cpu_dlib.cpp"
#include "../dnn/tensor_tools.cpp"
#endif
#ifndef DLIB_ISO_CPP_ONLY
......
......@@ -11,6 +11,7 @@
#include "dnn/solvers.h"
#include "dnn/trainer.h"
#include "dnn/cpu_dlib.h"
#include "dnn/tensor_tools.h"
#endif // DLIB_DNn_
......
......@@ -91,6 +91,9 @@ namespace dlib
}
}
#ifdef NO_MAKEFILE
#include "cpu_dlib.cpp"
#endif
#endif // DLIB_DNN_CPU_H_
......
This diff is collapsed.
......@@ -4,8 +4,12 @@
#define DLIB_TeNSOR_TOOLS_H_
#include "tensor.h"
#include "cudnn_dlibapi.h"
#include "cublas_dlibapi.h"
#include "curand_dlibapi.h"
#include "../rand.h"
namespace dlib
namespace dlib { namespace tt
{
// ----------------------------------------------------------------------------------------
......@@ -37,7 +41,51 @@ namespace dlib
class tensor_rand
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a tool for filling a tensor with random numbers.
Note that the sequence of random numbers output by this object is different
when dlib is compiled with DLIB_USE_CUDA. So you should not write code
that depends on any specific sequence of numbers coming out of a
tensor_rand.
!*/
public:
// not copyable
tensor_rand(const tensor_rand&) = delete;
tensor_rand& operator=(const tensor_rand&) = delete;
tensor_rand() : tensor_rand(0) {}
tensor_rand(unsigned long long seed);
void fill_gaussian (
tensor& data,
float mean,
float stddev
);
/*!
requires
- data.size()%2 == 0
ensures
- Fills data with random numbers drawn from a Gaussian distribution
with the given mean and standard deviation.
!*/
void fill_uniform (
tensor& data
);
/*!
ensures
- Fills data with uniform random numbers in the range (0.0, 1.0].
!*/
#ifdef DLIB_USE_CUDA
cuda::curand_generator rnd;
#else
dlib::rand rnd;
#endif
};
// ----------------------------------------------------------------------------------------
......@@ -278,13 +326,13 @@ namespace dlib
// ----------------------------------------------------------------------------------------
class conv
class tensor_conv
{
public:
conv(const conv&) = delete;
conv& operator=(const conv&) = delete;
tensor_conv(const tensor_conv&) = delete;
tensor_conv& operator=(const tensor_conv&) = delete;
conv();
tensor_conv();
void clear(
);
......@@ -302,9 +350,6 @@ namespace dlib
- stride_x > 0
!*/
~conv (
);
void operator() (
resizable_tensor& output,
const tensor& data,
......@@ -362,6 +407,11 @@ namespace dlib
!*/
private:
#ifdef DLIB_USE_CUDA
cuda::tensor_conv impl;
#else
// TODO
#endif
};
......@@ -379,9 +429,6 @@ namespace dlib
max_pool (
);
~max_pool(
);
void clear(
);
......@@ -429,6 +476,11 @@ namespace dlib
!*/
private:
#ifdef DLIB_USE_CUDA
cuda::max_pool impl;
#else
// TODO
#endif
};
// ----------------------------------------------------------------------------------------
......@@ -564,8 +616,11 @@ namespace dlib
// ----------------------------------------------------------------------------------------
}
}
}}
#ifdef NO_MAKEFILE
#include "tensor_tools.cpp"
#endif
#endif // DLIB_TeNSOR_TOOLS_H_
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment