Commit 88f5d9a3 authored by Davis King's avatar Davis King

Implemented more cuDNN bindings and cleaned up code a bit.

parent e9efffff
...@@ -36,6 +36,17 @@ namespace dlib ...@@ -36,6 +36,17 @@ namespace dlib
} }
} }
// ------------------------------------------------------------------------------------
// Fetch the cuDNN tensor descriptor associated with a dlib tensor and cast
// it to the type the cuDNN C API expects.
static const cudnnTensorDescriptor_t descriptor(const tensor& t)
{
    auto handle = t.get_cudnn_tensor_descriptor().get_handle();
    return (const cudnnTensorDescriptor_t)handle;
}
// Overload for a bare tensor_descriptor: it already wraps the cuDNN handle,
// so just cast the opaque handle back to the cuDNN descriptor type.
static const cudnnTensorDescriptor_t descriptor(const tensor_descriptor& t)
{
    auto handle = t.get_handle();
    return (const cudnnTensorDescriptor_t)handle;
}
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
class cudnn_context class cudnn_context
...@@ -155,6 +166,13 @@ namespace dlib ...@@ -155,6 +166,13 @@ namespace dlib
const tensor& src const tensor& src
) )
{ {
check(cudnnAddTensor_v3(context(),
&alpha,
descriptor(src),
src.device(),
&beta,
descriptor(dest),
dest.device()));
} }
void set_tensor ( void set_tensor (
...@@ -162,6 +180,12 @@ namespace dlib ...@@ -162,6 +180,12 @@ namespace dlib
float value float value
) )
{ {
if (t.size() == 0)
return;
check(cudnnSetTensor(context(),
descriptor(t),
t.device(),
&value));
} }
void scale_tensor ( void scale_tensor (
...@@ -169,6 +193,12 @@ namespace dlib ...@@ -169,6 +193,12 @@ namespace dlib
float value float value
) )
{ {
if (t.size() == 0)
return;
check(cudnnScaleTensor(context(),
descriptor(t),
t.device(),
&value));
} }
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
...@@ -246,7 +276,7 @@ namespace dlib ...@@ -246,7 +276,7 @@ namespace dlib
check(cudnnGetConvolution2dForwardOutputDim( check(cudnnGetConvolution2dForwardOutputDim(
(const cudnnConvolutionDescriptor_t)conv_handle, (const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(), descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle, (const cudnnFilterDescriptor_t)filter_handle,
&out_num_samples, &out_num_samples,
&out_k, &out_k,
...@@ -259,10 +289,10 @@ namespace dlib ...@@ -259,10 +289,10 @@ namespace dlib
cudnnConvolutionFwdAlgo_t forward_best_algo; cudnnConvolutionFwdAlgo_t forward_best_algo;
check(cudnnGetConvolutionForwardAlgorithm( check(cudnnGetConvolutionForwardAlgorithm(
context(), context(),
(const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(), descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle, (const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle, (const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnTensorDescriptor_t)dest_desc.get_handle(), descriptor(dest_desc),
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, // or CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, // or CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
&forward_best_algo)); &forward_best_algo));
...@@ -271,10 +301,10 @@ namespace dlib ...@@ -271,10 +301,10 @@ namespace dlib
check(cudnnGetConvolutionForwardWorkspaceSize( check(cudnnGetConvolutionForwardWorkspaceSize(
context(), context(),
(const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(), descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle, (const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle, (const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnTensorDescriptor_t)dest_desc.get_handle(), descriptor(dest_desc),
forward_best_algo, forward_best_algo,
&forward_workspace_size_in_bytes)); &forward_workspace_size_in_bytes));
...@@ -313,7 +343,7 @@ namespace dlib ...@@ -313,7 +343,7 @@ namespace dlib
check(cudnnConvolutionForward( check(cudnnConvolutionForward(
context(), context(),
&alpha, &alpha,
(const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(), descriptor(data),
data.device(), data.device(),
(const cudnnFilterDescriptor_t)filter_handle, (const cudnnFilterDescriptor_t)filter_handle,
filters.device(), filters.device(),
...@@ -322,7 +352,7 @@ namespace dlib ...@@ -322,7 +352,7 @@ namespace dlib
forward_workspace, forward_workspace,
forward_workspace_size_in_bytes, forward_workspace_size_in_bytes,
&beta, &beta,
(const cudnnTensorDescriptor_t)output.get_cudnn_tensor_descriptor().get_handle(), descriptor(output),
output.device())); output.device()));
} }
......
...@@ -83,15 +83,12 @@ namespace dlib ...@@ -83,15 +83,12 @@ namespace dlib
- dest.k()==src.k() || src.k()==1 - dest.k()==src.k() || src.k()==1
ensures ensures
- performs: dest = beta*dest + alpha*src - performs: dest = beta*dest + alpha*src
However, how the addition happens depends on the dimensions of src.  In
particular, this function adds the scaled values of one src tensor to
dest.  Each dimension of the src tensor must match the corresponding
dimension of the dest tensor or must be equal to 1.  In the latter case,
the same value from the src tensor, for those dimensions, will be used to
add into the dest tensor.
!*/ !*/
void set_tensor ( void set_tensor (
...@@ -101,7 +98,6 @@ namespace dlib ...@@ -101,7 +98,6 @@ namespace dlib
/*! /*!
ensures ensures
- sets all elements in t equal to value. - sets all elements in t equal to value.
Uses cudnnSetTensor().
!*/ !*/
void scale_tensor ( void scale_tensor (
...@@ -113,8 +109,6 @@ namespace dlib ...@@ -113,8 +109,6 @@ namespace dlib
- scales all elements of t by the given value. I.e. for all elements E in - scales all elements of t by the given value. I.e. for all elements E in
t, this function performs: t, this function performs:
- E = E*value - E = E*value
uses cudnnScaleTensor()
!*/ !*/
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#ifndef DLIB_DNn_TENSOR_H_ #ifndef DLIB_DNn_TENSOR_H_
#define DLIB_DNn_TENSOR_H_ #define DLIB_DNn_TENSOR_H_
#include "tensor_abstract.h"
#include <cstring> #include <cstring>
#include "../matrix.h" #include "../matrix.h"
#include "cudnn_dlibapi.h" #include "cudnn_dlibapi.h"
...@@ -46,13 +47,42 @@ namespace dlib ...@@ -46,13 +47,42 @@ namespace dlib
// Assign the scalar `val` to every element of this tensor.
// Returns *this so assignments can be chained.
// NOTE(review): this span was a side-by-side diff scrape with the old/new
// columns fused onto single lines; the post-commit implementation is
// reconstructed below.
tensor& operator= (float val)
{
#ifdef DLIB_USE_CUDA
    // If you are using CUDA then presumably you will be mostly using tensors on
    // the GPU.  So unless you seem to be actively working with the host side's
    // data then we do this initialization on the device side since this avoids a
    // host to device transfer that would likely immediately follow.
    if (data.device_ready())
    {
        cuda::set_tensor(*this, val);
        return *this;
    }
#endif
    // Host-side fallback: fill the host buffer element by element.
    auto d = data.host();
    for (size_t i = 0; i < data.size(); ++i)
        d[i] = val;
    return *this;
}
// Multiply every element of this tensor by `val` in place and return *this.
tensor& operator*= (float val)
{
#ifdef DLIB_USE_CUDA
    // Scale on the device; avoids pulling the data back to the host.
    cuda::scale_tensor(*this, val);
#else
    // CPU path: scale the host buffer directly.
    auto ptr = data.host();
    for (size_t idx = 0; idx < data.size(); ++idx)
        ptr[idx] *= val;
#endif
    return *this;
}
// Divide every element of this tensor by `val` in place and return *this.
// Implemented as multiplication by the reciprocal, delegating to operator*=.
tensor& operator/= (float val)
{
    return *this *= 1.0/val;
}
template <typename EXP> template <typename EXP>
tensor& operator= (const matrix_exp<EXP>& item) tensor& operator= (const matrix_exp<EXP>& item)
{ {
......
...@@ -161,6 +161,24 @@ namespace dlib ...@@ -161,6 +161,24 @@ namespace dlib
- returns *this - returns *this
!*/ !*/
tensor& operator*= (
float val
);
/*!
    ensures
        - pointwise multiplies all elements of *this tensor by val.
        - returns *this
!*/
tensor& operator/= (
float val
);
/*!
    ensures
        - pointwise divides all elements of *this tensor by val.
        - returns *this
!*/
template <typename EXP> template <typename EXP>
tensor& operator= ( tensor& operator= (
const matrix_exp<EXP>& item const matrix_exp<EXP>& item
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment