Commit 7e12a32b authored by Davis King

Added more kinds of affine_transform(), made the solver use affine_transform() so it runs on the GPU, and made affine_transform() take only tensors.
parent c8948ee0
......@@ -19,6 +19,7 @@ namespace dlib
const tensor& src
)
{
DLIB_CASSERT(dest.size()==src.size(),"");
const auto d = dest.host();
const auto s = src.host();
for (size_t i = 0; i < src.size(); ++i)
......@@ -28,19 +29,59 @@ namespace dlib
// -----------------------------------------------------------------------------------
// Host implementation of the single-source affine map: for every element i of
// src, stores A*src[i] + B into element i of dest (both accessed via host()).
// NOTE(review): this hunk is rendered without +/- markers, so the pre-change
// lines (`resizable_tensor& dest`, `dest.copy_size(src)`) appear next to their
// replacements (`tensor& dest`, the size DLIB_CASSERT).  The commit message
// says affine_transform() now takes only tensors, so presumably the
// resizable_tensor/copy_size pair is the removed half -- confirm against the
// actual file.
void affine_transform(
resizable_tensor& dest,
tensor& dest,
const tensor& src,
const float A,
const float B
)
{
dest.copy_size(src);
DLIB_CASSERT(dest.size()==src.size(),"");
const auto d = dest.host();
const auto s = src.host();
// One pass over the flat element range; dest and src are indexed identically.
for (size_t i = 0; i < src.size(); ++i)
d[i] = A*s[i] + B;
}
// Host implementation of the two-source affine map.
// Every element of dest is overwritten with A*src1 + B*src2 + C, built from the
// elements of src1 and src2 at the same flat index.  dest, src1 and src2 must
// all report the same size(); this is enforced with DLIB_CASSERT.
void affine_transform(
    tensor& dest,
    const tensor& src1,
    const tensor& src2,
    const float A,
    const float B,
    const float C
)
{
    DLIB_CASSERT(dest.size()==src1.size(),"");
    DLIB_CASSERT(dest.size()==src2.size(),"");

    // Grab the host buffers in the same order as before: destination first.
    const auto out = dest.host();
    const auto in1 = src1.host();
    const auto in2 = src2.host();

    const auto count = src1.size();
    size_t idx = 0;
    while (idx < count)
    {
        out[idx] = A*in1[idx] + B*in2[idx] + C;
        ++idx;
    }
}
// Host implementation of the three-source affine map.
// Every element of dest is overwritten with A*src1 + B*src2 + C*src3 + D, built
// from the elements of the three sources at the same flat index.  All four
// tensors must report the same size(); this is enforced with DLIB_CASSERT.
void affine_transform(
    tensor& dest,
    const tensor& src1,
    const tensor& src2,
    const tensor& src3,
    const float A,
    const float B,
    const float C,
    const float D
)
{
    DLIB_CASSERT(dest.size()==src1.size(),"");
    DLIB_CASSERT(dest.size()==src2.size(),"");
    DLIB_CASSERT(dest.size()==src3.size(),"");

    // Grab the host buffers in the same order as before: destination first.
    const auto out = dest.host();
    const auto in1 = src1.host();
    const auto in2 = src2.host();
    const auto in3 = src3.host();

    const auto count = src1.size();
    size_t idx = 0;
    while (idx < count)
    {
        out[idx] = A*in1[idx] + B*in2[idx] + C*in3[idx] + D;
        ++idx;
    }
}
// -----------------------------------------------------------------------------------
void affine_transform(
......
......@@ -23,12 +23,32 @@ namespace dlib
// -----------------------------------------------------------------------------------
// Single-source affine map declaration (dest receives A*src + B per the
// tt::affine_transform spec in this change set).
// NOTE(review): the diff markers are missing here, so both the old
// `resizable_tensor& dest` and the new `tensor& dest` parameter lines are
// shown; presumably only `tensor& dest` remains after the commit -- confirm.
void affine_transform(
resizable_tensor& dest,
tensor& dest,
const tensor& src,
const float A,
const float B
);
// Two-source affine map declaration.
// Per the tt::affine_transform spec in this change set: requires dest, src1
// and src2 to have equal size(), and ensures #dest == A*src1 + B*src2 + C.
// NOTE(review): the definition behind this particular header is not visible
// here -- confirm it honours the same contract.
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const float A,
const float B,
const float C
);
// Three-source affine map declaration.
// Per the tt::affine_transform spec in this change set: requires dest, src1,
// src2 and src3 to have equal size(), and ensures
// #dest == A*src1 + B*src2 + C*src3 + D.
// NOTE(review): the definition behind this particular header is not visible
// here -- confirm it honours the same contract.
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const tensor& src3,
const float A,
const float B,
const float C,
const float D
);
// -----------------------------------------------------------------------------------
void affine_transform(
......
......@@ -22,12 +22,33 @@ namespace dlib
// -----------------------------------------------------------------------------------
// Single-source affine map declaration (dest receives A*src + B per the
// tt::affine_transform spec in this change set).
// NOTE(review): the diff markers are missing here, so both the old
// `resizable_tensor& dest` and the new `tensor& dest` parameter lines are
// shown; presumably only `tensor& dest` remains after the commit -- confirm.
void affine_transform(
resizable_tensor& dest,
tensor& dest,
const tensor& src,
const float A,
const float B
);
// Two-source affine map declaration.
// Per the tt::affine_transform spec in this change set: requires dest, src1
// and src2 to have equal size(), and ensures #dest == A*src1 + B*src2 + C.
// NOTE(review): the definition behind this particular header is not visible
// here -- confirm it honours the same contract.
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const float A,
const float B,
const float C
);
// Three-source affine map declaration.
// Per the tt::affine_transform spec in this change set: requires dest, src1,
// src2 and src3 to have equal size(), and ensures
// #dest == A*src1 + B*src2 + C*src3 + D.
// NOTE(review): the definition behind this particular header is not visible
// here -- confirm it honours the same contract.
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const tensor& src3,
const float A,
const float B,
const float C,
const float D
);
// -----------------------------------------------------------------------------------
void affine_transform(
......
......@@ -40,12 +40,17 @@ namespace dlib
)
{
DLIB_CASSERT(l.get_layer_params().size() != 0,"");
if (v.size() != 0)
v = momentum*v - weight_decay*learning_rate*mat(l.get_layer_params()) - learning_rate*mat(params_grad);
else
v = - weight_decay*learning_rate*mat(l.get_layer_params()) - learning_rate*mat(params_grad);
if (v.size() == 0)
{
v.copy_size(params_grad);
v = 0;
}
tt::affine_transform(v, v, l.get_layer_params(), params_grad,
momentum, -weight_decay*learning_rate, -learning_rate, 0);
// perform l.get_layer_params() += v;
tt::affine_transform(l.get_layer_params(), l.get_layer_params(), v, 1, 1, 0);
l.get_layer_params() += v;
}
friend void serialize(const sgd& item, std::ostream& out)
......@@ -70,7 +75,7 @@ namespace dlib
}
private:
matrix<float> v;
resizable_tensor v;
float weight_decay;
float learning_rate;
float momentum;
......
......@@ -97,7 +97,7 @@ namespace dlib { namespace tt
// ----------------------------------------------------------------------------------------
void affine_transform(
resizable_tensor& dest,
tensor& dest,
const tensor& src,
const float A,
const float B
......@@ -110,6 +110,40 @@ namespace dlib { namespace tt
#endif
}
// Dispatching front-end for the two-source affine map:
//     #dest == A*src1 + B*src2 + C
// (preconditions are listed in the spec next to the declaration).
//
// The CUDA overload for this signature is not wired up yet: its call was
// commented out, which left this function as a silent no-op whenever
// DLIB_USE_CUDA was defined.  The sgd solver uses this overload to apply
// `params += v`, so on such builds the parameters would never change.
// Until cuda::affine_transform() exists for this signature, run the host
// implementation in every build configuration.
void affine_transform(
    tensor& dest,
    const tensor& src1,
    const tensor& src2,
    const float A,
    const float B,
    const float C
)
{
    // TODO: under DLIB_USE_CUDA, dispatch to
    //       cuda::affine_transform(dest,src1,src2,A,B,C) once it is implemented.
    // NOTE(review): assumes cpu::affine_transform is declared in CUDA builds
    // as well -- confirm the include of the cpu header is unconditional.
    cpu::affine_transform(dest,src1,src2,A,B,C);
}
// Dispatching front-end for the three-source affine map:
//     #dest == A*src1 + B*src2 + C*src3 + D
// (preconditions are listed in the spec next to the declaration).
//
// The CUDA overload for this signature is not wired up yet: its call was
// commented out, which left this function as a silent no-op whenever
// DLIB_USE_CUDA was defined.  The sgd solver uses this overload for its
// momentum/weight-decay update of `v`, so on such builds `v` would stay at
// its initial value.  Until cuda::affine_transform() exists for this
// signature, run the host implementation in every build configuration.
void affine_transform(
    tensor& dest,
    const tensor& src1,
    const tensor& src2,
    const tensor& src3,
    const float A,
    const float B,
    const float C,
    const float D
)
{
    // TODO: under DLIB_USE_CUDA, dispatch to
    //       cuda::affine_transform(dest,src1,src2,src3,A,B,C,D) once it is
    //       implemented.
    // NOTE(review): assumes cpu::affine_transform is declared in CUDA builds
    // as well -- confirm the include of the cpu header is unconditional.
    cpu::affine_transform(dest,src1,src2,src3,A,B,C,D);
}
// ----------------------------------------------------------------------------------------
void affine_transform(
......
......@@ -108,17 +108,53 @@ namespace dlib { namespace tt
// ----------------------------------------------------------------------------------------
void affine_transform(
resizable_tensor& dest,
tensor& dest,
const tensor& src,
const float A,
const float B
);
/*!
requires
- dest.size()==src.size()
ensures
- have_same_dimensions(#dest,src) == true
- #dest == A*src + B
!*/
// NOTE(review): this hunk is shown without +/- markers.  The
// `resizable_tensor& dest` parameter and the `have_same_dimensions(#dest,src)`
// guarantee look like the pre-change lines (from when dest was resized to
// match src); with a plain `tensor& dest` and only a size() precondition that
// guarantee presumably no longer applies -- confirm and drop it if stale.
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const float A,
const float B,
const float C
);
/*!
    requires
        - dest.size()==src1.size()
        - dest.size()==src2.size()
    ensures
        - #dest == A*src1 + B*src2 + C
          (The sgd solver in this change passes the same tensor as both dest
          and src1, so implementations need to tolerate that aliasing.)
!*/
void affine_transform(
tensor& dest,
const tensor& src1,
const tensor& src2,
const tensor& src3,
const float A,
const float B,
const float C,
const float D
);
/*!
    requires
        - dest.size()==src1.size()
        - dest.size()==src2.size()
        - dest.size()==src3.size()
    ensures
        - #dest == A*src1 + B*src2 + C*src3 + D
          (The sgd solver in this change passes the same tensor as both dest
          and src1, so implementations need to tolerate that aliasing.)
!*/
// ----------------------------------------------------------------------------------------
void affine_transform(
......
......@@ -201,6 +201,7 @@ namespace
print_spinner();
resizable_tensor dest, src(3,4), A(1,4), B(1,4);
src = 2;
dest.copy_size(src);
affine_transform(dest, src, 2, 3);
dlog << LINFO << mat(dest);
matrix<float> truth1(3,4), truth2(3,4);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment