Commit 29f56b12 authored by Davis King's avatar Davis King

Made the affine_transform functions consistent.

parent 76786430
...@@ -102,12 +102,13 @@ namespace dlib ...@@ -102,12 +102,13 @@ namespace dlib
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void affine_transform( void affine_transform(
resizable_tensor& dest, tensor& dest,
const tensor& src, const tensor& src,
const tensor& A, const tensor& A,
const tensor& B const tensor& B
) )
{ {
DLIB_CASSERT(have_same_dimensions(dest,src),"");
DLIB_CASSERT( DLIB_CASSERT(
((A.num_samples()==1 && B.num_samples()==1) || ((A.num_samples()==1 && B.num_samples()==1) ||
(A.num_samples()==src.num_samples() && B.num_samples()==src.num_samples())) && (A.num_samples()==src.num_samples() && B.num_samples()==src.num_samples())) &&
...@@ -115,7 +116,6 @@ namespace dlib ...@@ -115,7 +116,6 @@ namespace dlib
A.nc()==B.nc() && B.nc()==src.nc() && A.nc()==B.nc() && B.nc()==src.nc() &&
A.k() ==B.k() && B.k()==src.k(),""); A.k() ==B.k() && B.k()==src.k(),"");
dest.copy_size(src);
auto d = dest.host(); auto d = dest.host();
auto s = src.host(); auto s = src.host();
const auto a = A.host(); const auto a = A.host();
......
...@@ -58,7 +58,7 @@ namespace dlib ...@@ -58,7 +58,7 @@ namespace dlib
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void affine_transform( void affine_transform(
resizable_tensor& dest, tensor& dest,
const tensor& src, const tensor& src,
const tensor& A, const tensor& A,
const tensor& B const tensor& B
......
...@@ -160,12 +160,13 @@ namespace dlib ...@@ -160,12 +160,13 @@ namespace dlib
} }
void affine_transform( void affine_transform(
resizable_tensor& dest, tensor& dest,
const tensor& src, const tensor& src,
const tensor& A, const tensor& A,
const tensor& B const tensor& B
) )
{ {
DLIB_CASSERT(have_same_dimensions(dest, src),"");
DLIB_CASSERT( DLIB_CASSERT(
((A.num_samples()==1 && B.num_samples()==1) || ((A.num_samples()==1 && B.num_samples()==1) ||
(A.num_samples()==src.num_samples() && B.num_samples()==src.num_samples())) && (A.num_samples()==src.num_samples() && B.num_samples()==src.num_samples())) &&
...@@ -173,7 +174,6 @@ namespace dlib ...@@ -173,7 +174,6 @@ namespace dlib
A.nc()==B.nc() && B.nc()==src.nc() && A.nc()==B.nc() && B.nc()==src.nc() &&
A.k() ==B.k() && B.k()==src.k(),""); A.k() ==B.k() && B.k()==src.k(),"");
dest.copy_size(src);
if (A.num_samples() == 1) if (A.num_samples() == 1)
{ {
_cuda_affine_transform3<<<512,512>>>(dest.device(), src.device(), src.size(), A.device(), B.device(), A.size()); _cuda_affine_transform3<<<512,512>>>(dest.device(), src.device(), src.size(), A.device(), B.device(), A.size());
......
...@@ -68,7 +68,7 @@ namespace dlib ...@@ -68,7 +68,7 @@ namespace dlib
// ----------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------
void affine_transform( void affine_transform(
resizable_tensor& dest, tensor& dest,
const tensor& src, const tensor& src,
const tensor& A, const tensor& A,
const tensor& B const tensor& B
......
...@@ -179,7 +179,7 @@ namespace dlib { namespace tt ...@@ -179,7 +179,7 @@ namespace dlib { namespace tt
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void affine_transform( void affine_transform(
resizable_tensor& dest, tensor& dest,
const tensor& src, const tensor& src,
const tensor& A, const tensor& A,
const tensor& B const tensor& B
......
...@@ -175,13 +175,14 @@ namespace dlib { namespace tt ...@@ -175,13 +175,14 @@ namespace dlib { namespace tt
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void affine_transform( void affine_transform(
resizable_tensor& dest, tensor& dest,
const tensor& src, const tensor& src,
const tensor& A, const tensor& A,
const tensor& B const tensor& B
); );
/*! /*!
requires requires
- have_same_dimensions(dest,src) == true
- if (A.num_samples() == 1) then - if (A.num_samples() == 1) then
- B.num_samples() == 1 - B.num_samples() == 1
- else - else
...@@ -191,7 +192,6 @@ namespace dlib { namespace tt ...@@ -191,7 +192,6 @@ namespace dlib { namespace tt
- A.nc() == B.nc() == src.nc() - A.nc() == B.nc() == src.nc()
- A.k() == B.k() == src.k() - A.k() == B.k() == src.k()
ensures ensures
- have_same_dimensions(#dest,src) == true
- if (A.num_samples() == 1) then - if (A.num_samples() == 1) then
- #dest == A*src + B - #dest == A*src + B
(done for each sample in src) (done for each sample in src)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment