Commit 45dd580b authored by Davis King

Wrote replacements for set_tensor() and scale_tensor()

The previous versions called into cuDNN, but the cuDNN functions for this are
horrifically slow, well over 100x slower than they should be, which is
surprising since these operations are so trivial.
parent 452b188d
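For context, the slowdown being described is the gap between cudnnSetTensor()/cudnnScaleTensor() and a trivial hand-written fill/scale kernel. The timing harness below is a minimal sketch of that kind of comparison; it is not part of the commit, and the buffer size, launch configuration, and descriptor shape are arbitrary illustrative choices.

#include <cuda_runtime.h>
#include <cudnn.h>
#include <cstdio>

__global__ void fill_kernel(float* out, size_t n, float val)
{
    // grid-stride loop: any launch size covers the whole buffer
    for (size_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += (size_t)gridDim.x*blockDim.x)
        out[i] = val;
}

int main()
{
    const size_t n = 1 << 24;            // arbitrary 16M-element buffer
    float* d;
    cudaMalloc(&d, n*sizeof(float));

    cudnnHandle_t handle;
    cudnnCreate(&handle);
    cudnnTensorDescriptor_t desc;
    cudnnCreateTensorDescriptor(&desc);
    cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 1, 1, (int)n);

    cudaEvent_t t0, t1;
    cudaEventCreate(&t0);
    cudaEventCreate(&t1);
    const float value = 2.0f;

    // time the cuDNN path
    // (a warm-up pass for each path would make the comparison fairer; omitted for brevity)
    cudaEventRecord(t0);
    cudnnSetTensor(handle, desc, d, &value);
    cudaEventRecord(t1);
    cudaEventSynchronize(t1);
    float ms_cudnn = 0;
    cudaEventElapsedTime(&ms_cudnn, t0, t1);

    // time the hand-written kernel
    cudaEventRecord(t0);
    fill_kernel<<<512,256>>>(d, n, value);
    cudaEventRecord(t1);
    cudaEventSynchronize(t1);
    float ms_kernel = 0;
    cudaEventElapsedTime(&ms_kernel, t0, t1);

    printf("cudnnSetTensor: %.3f ms   fill_kernel: %.3f ms\n", ms_cudnn, ms_kernel);

    cudnnDestroyTensorDescriptor(desc);
    cudnnDestroy(handle);
    cudaFree(d);
    return 0;
}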
@@ -864,6 +864,38 @@ namespace dlib
launch_kernel(_add_bias_gradient,max_jobs(grad.size()),grad.device(), gradient_input.device(), grad.size(), gradient_input.size());
}
// ----------------------------------------------------------------------------------------
__global__ void _set_tensor(float* out, size_t n, const float val)
{
for (auto i : grid_stride_range(0, n))
out[i] = val;
}
void set_tensor (
tensor& t,
float value
)
{
launch_kernel(_set_tensor, max_jobs(t.size()), t.device(), t.size(), value);
}
// ----------------------------------------------------------------------------------------
__global__ void _scale_tensor(float* out, size_t n, const float val)
{
for (auto i : grid_stride_range(0, n))
out[i] *= val;
}
void scale_tensor (
tensor& t,
float value
)
{
launch_kernel(_scale_tensor, max_jobs(t.size()), t.device(), t.size(), value);
}
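The kernels above lean on dlib's grid_stride_range, launch_kernel, and max_jobs helpers to hide the launch boilerplate. As a rough guide to what they correspond to, here is a hand-written equivalent of the scale path; this is a sketch only, and the block/grid sizing below is illustrative rather than dlib's actual max_jobs policy.

#include <algorithm>
#include <cstddef>

__global__ void _scale_tensor_raw(float* out, size_t n, const float val)
{
    // equivalent of: for (auto i : grid_stride_range(0, n))
    for (size_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += (size_t)gridDim.x*blockDim.x)
        out[i] *= val;
}

void scale_tensor_raw(float* dev_ptr, size_t n, float value)
{
    if (n == 0)
        return;
    // illustrative launch sizing; dlib's max_jobs() may choose differently
    const int threads = 512;
    const int blocks  = (int)std::min<size_t>((n + threads - 1)/threads, 1024);
    _scale_tensor_raw<<<blocks,threads>>>(dev_ptr, n, value);
}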
// -----------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------
...
@@ -141,6 +141,18 @@ namespace dlib
const tensor& v2
);
// ------------------------------------------------------------------------------------
void set_tensor (
tensor& t,
float value
);
void scale_tensor (
tensor& t,
float value
);
// ------------------------------------------------------------------------------------
void multiply (
...
@@ -289,32 +289,6 @@ namespace dlib
dest.device()));
}
void set_tensor (
tensor& t,
float value
)
{
if (t.size() == 0)
return;
CHECK_CUDNN(cudnnSetTensor(context(),
descriptor(t),
t.device_write_only(),
&value));
}
void scale_tensor (
tensor& t,
float value
)
{
if (t.size() == 0)
return;
CHECK_CUDNN(cudnnScaleTensor(context(),
descriptor(t),
t.device(),
&value));
}
void assign_conv_bias_gradient (
tensor& grad,
const tensor& gradient_input
...
@@ -89,26 +89,6 @@ namespace dlib
add into the dest tensor.
!*/
void set_tensor (
tensor& t,
float value
);
/*!
ensures
- sets all elements in t equal to value.
!*/
void scale_tensor (
tensor& t,
float value
);
/*!
ensures
- scales all elements of t by the given value. I.e. for all elements E in
t, this function performs:
- E = E*value
!*/
// ------------------------------------------------------------------------------------
void assign_conv_bias_gradient (
...
@@ -14,6 +14,22 @@
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace cuda
{
void set_tensor (
tensor& t,
float value
);
void scale_tensor (
tensor& t,
float value
);
}
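The forward declarations above are placed in the tensor header ahead of the tensor class, presumably so the tensor's scalar assignment (t = value) and scaling (t *= value) operators can route to these CUDA implementations without pulling in the CUDA headers. A hypothetical usage sketch follows; the include path and tensor dimensions are illustrative, not taken from the commit.

#include <dlib/dnn.h>   // assumed umbrella header providing resizable_tensor

int main()
{
    dlib::resizable_tensor src(3,4);   // dimensions chosen to mirror the test hunk below
    src = 2;     // scalar assignment; dispatches to cuda::set_tensor in a CUDA-enabled build
    src *= 2;    // scalar scaling; dispatches to cuda::scale_tensor, leaving every element equal to 4
    return 0;
}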
// ----------------------------------------------------------------------------------------
class tensor
...
@@ -339,6 +339,14 @@ namespace
dlog << LINFO << mat(dest);
matrix<float> truth1(3,4), truth2(3,4);
truth1 = 2;
DLIB_TEST(max(abs(truth1-mat(src))) < 1e-5);
src *= 2;
truth1 = 4;
DLIB_TEST(max(abs(truth1-mat(src))) < 1e-5);
src = 2;
truth1 = 7;
truth2 = 7, 10, 7, 7,
7, 10, 7, 7,
...