Commit 5f8e41a8 authored by Davis King's avatar Davis King

Added another version of multiply()

parent 79344339
...@@ -26,6 +26,23 @@ namespace dlib ...@@ -26,6 +26,23 @@ namespace dlib
d[i] *= s[i]; d[i] *= s[i];
} }
// -----------------------------------------------------------------------------------
void multiply (
tensor& dest,
const tensor& src1,
const tensor& src2
)
{
DLIB_CASSERT(dest.size()==src1.size(),"");
DLIB_CASSERT(dest.size()==src2.size(),"");
const auto d = dest.host();
const auto s1 = src1.host();
const auto s2 = src2.host();
for (size_t i = 0; i < src1.size(); ++i)
d[i] = s1[i]*s2[i];
}
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void affine_transform( void affine_transform(
......
...@@ -20,6 +20,12 @@ namespace dlib ...@@ -20,6 +20,12 @@ namespace dlib
const tensor& src const tensor& src
); );
void multiply (
tensor& dest,
const tensor& src1,
const tensor& src2
);
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void affine_transform( void affine_transform(
......
...@@ -46,6 +46,27 @@ namespace dlib ...@@ -46,6 +46,27 @@ namespace dlib
_cuda_multiply<<<512,512>>>(dest.device(), src.device(), src.size()); _cuda_multiply<<<512,512>>>(dest.device(), src.device(), src.size());
} }
// -----------------------------------------------------------------------------------
__global__ void _cuda_multiply(float* d, const float* s1, const float* s2, size_t n)
{
for (auto i : grid_stride_range(0, n))
{
d[i] = s1[i]*s2[i];
}
}
void multiply (
tensor& dest,
const tensor& src1,
const tensor& src2
)
{
DLIB_CASSERT(dest.size()==src1.size(),"");
DLIB_CASSERT(dest.size()==src2.size(),"");
_cuda_multiply<<<512,512>>>(dest.device(), src1.device(), src2.device(), src1.size());
}
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
__global__ void _cuda_affine_transform(float* d, const float* s, size_t n, float A, float B) __global__ void _cuda_affine_transform(float* d, const float* s, size_t n, float A, float B)
......
...@@ -29,6 +29,12 @@ namespace dlib ...@@ -29,6 +29,12 @@ namespace dlib
const tensor& src const tensor& src
); );
void multiply (
tensor& dest,
const tensor& src1,
const tensor& src2
);
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void affine_transform( void affine_transform(
......
...@@ -108,6 +108,24 @@ namespace dlib { namespace tt ...@@ -108,6 +108,24 @@ namespace dlib { namespace tt
} }
// ----------------------------------------------------------------------------------------
void multiply (
tensor& dest,
const tensor& src1,
const tensor& src2
)
{
DLIB_CASSERT(have_same_dimensions(dest,src1) == true,"");
DLIB_CASSERT(have_same_dimensions(dest,src2) == true,"");
#ifdef DLIB_USE_CUDA
cuda::multiply(dest, src1, src2);
#else
cpu::multiply(dest, src1, src2);
#endif
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void affine_transform( void affine_transform(
......
...@@ -105,6 +105,23 @@ namespace dlib { namespace tt ...@@ -105,6 +105,23 @@ namespace dlib { namespace tt
#dest.host()[i] == dest.host()[i]*src.host()[i] #dest.host()[i] == dest.host()[i]*src.host()[i]
!*/ !*/
// ----------------------------------------------------------------------------------------
void multiply (
tensor& dest,
const tensor& src1,
const tensor& src2
);
/*!
requires
- have_same_dimensions(dest,src1) == true
- have_same_dimensions(dest,src2) == true
ensures
- #dest == src1*src2
That is, for all valid i:
#dest.host()[i] == src1.host()[i]*src2.host()[i]
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void affine_transform( void affine_transform(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment