Commit d3f12d83 authored by Davis King's avatar Davis King

Added more overloads of affine_transform()

parent 7d5fb9c6
...@@ -267,6 +267,16 @@ namespace dlib ...@@ -267,6 +267,16 @@ namespace dlib
launch_kernel(_cuda_affine_transform1_0,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A); launch_kernel(_cuda_affine_transform1_0,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A);
} }
void affine_transform(
tensor& dest,
const tensor& src,
const float A
)
{
DLIB_CASSERT(dest.size()==src.size(),"");
launch_kernel(_cuda_affine_transform1_0,max_jobs(dest.size()),dest.device(), src.device(), src.size(), A);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
__global__ void _cuda_affine_transform4(float* d, const float* s1, const float* s2, size_t n, float A, float B, float C) __global__ void _cuda_affine_transform4(float* d, const float* s1, const float* s2, size_t n, float A, float B, float C)
...@@ -302,6 +312,19 @@ namespace dlib ...@@ -302,6 +312,19 @@ namespace dlib
launch_kernel(_cuda_affine_transform4_0,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B); launch_kernel(_cuda_affine_transform4_0,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B);
} }
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const float A,
const float B
)
{
DLIB_CASSERT(dest.size()==src1.size(),"");
DLIB_CASSERT(dest.size()==src2.size(),"");
launch_kernel(_cuda_affine_transform4_0,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
__global__ void _cuda_add_scaled(float* d, const float* s, size_t n, float scale) __global__ void _cuda_add_scaled(float* d, const float* s, size_t n, float scale)
......
...@@ -54,6 +54,12 @@ namespace dlib ...@@ -54,6 +54,12 @@ namespace dlib
const float B const float B
); );
void affine_transform(
tensor& dest,
const tensor& src,
const float A
);
void affine_transform( void affine_transform(
tensor& dest, tensor& dest,
const tensor& src1, const tensor& src1,
...@@ -63,6 +69,14 @@ namespace dlib ...@@ -63,6 +69,14 @@ namespace dlib
const float C const float C
); );
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const float A,
const float B
);
void affine_transform( void affine_transform(
tensor& dest, tensor& dest,
const tensor& src1, const tensor& src1,
......
...@@ -176,6 +176,19 @@ namespace dlib { namespace tt ...@@ -176,6 +176,19 @@ namespace dlib { namespace tt
#endif #endif
} }
void affine_transform(
tensor& dest,
const tensor& src,
const float A
)
{
#ifdef DLIB_USE_CUDA
cuda::affine_transform(dest,src,A);
#else
cpu::affine_transform(dest,src,A,0);
#endif
}
void affine_transform( void affine_transform(
tensor& dest, tensor& dest,
const tensor& src1, const tensor& src1,
...@@ -192,6 +205,21 @@ namespace dlib { namespace tt ...@@ -192,6 +205,21 @@ namespace dlib { namespace tt
#endif #endif
} }
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const float A,
const float B
)
{
#ifdef DLIB_USE_CUDA
cuda::affine_transform(dest,src1,src2,A,B);
#else
cpu::affine_transform(dest,src1,src2,A,B,0);
#endif
}
void affine_transform( void affine_transform(
tensor& dest, tensor& dest,
const tensor& src1, const tensor& src1,
......
...@@ -168,6 +168,18 @@ namespace dlib { namespace tt ...@@ -168,6 +168,18 @@ namespace dlib { namespace tt
- #dest == A*src + B - #dest == A*src + B
!*/ !*/
void affine_transform(
tensor& dest,
const tensor& src,
const float A
);
/*!
requires
- dest.size()==src.size()
ensures
- #dest == A*src
!*/
void affine_transform( void affine_transform(
tensor& dest, tensor& dest,
const tensor& src1, const tensor& src1,
...@@ -181,7 +193,22 @@ namespace dlib { namespace tt ...@@ -181,7 +193,22 @@ namespace dlib { namespace tt
- dest.size()==src1.size() - dest.size()==src1.size()
- dest.size()==src2.size() - dest.size()==src2.size()
ensures ensures
- #dest == A*src1 + src2*B + C - #dest == A*src1 + B*src2 + C
!*/
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const float A,
const float B
);
/*!
requires
- dest.size()==src1.size()
- dest.size()==src2.size()
ensures
- #dest == A*src1 + B*src2
!*/ !*/
void affine_transform( void affine_transform(
...@@ -195,12 +222,11 @@ namespace dlib { namespace tt ...@@ -195,12 +222,11 @@ namespace dlib { namespace tt
const float D const float D
); );
/*! /*!
requires requires - dest.size()==src1.size()
- dest.size()==src1.size()
- dest.size()==src2.size() - dest.size()==src2.size()
- dest.size()==src3.size() - dest.size()==src3.size()
ensures ensures
- #dest == A*src1 + src2*B + src3*C + D - #dest == A*src1 + B*src2 + C*src3 + D
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment