Commit cbf420f3 authored by Davis King's avatar Davis King

- Made scale_by() work on dlib::matrix objects.

 - Added an add() and subtract() that works on
   sparse and dense vectors.
parent 0d3efbbd
...@@ -354,7 +354,7 @@ namespace dlib ...@@ -354,7 +354,7 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template <typename T, typename U> template <typename T, typename U>
void scale_by ( typename disable_if<is_matrix<T>,void>::type scale_by (
T& a, T& a,
const U& value const U& value
) )
...@@ -365,8 +365,127 @@ namespace dlib ...@@ -365,8 +365,127 @@ namespace dlib
} }
} }
template <typename T, typename U>
typename enable_if<is_matrix<T>,void>::type scale_by (
T& a,
const U& value
)
{
a *= value;
}
// ------------------------------------------------------------------------------------
template <typename T>
typename disable_if<is_matrix<T>,T>::type add (
const T& a,
const T& b
)
{
T temp;
typename T::const_iterator i = a.begin();
typename T::const_iterator j = b.begin();
while (i != a.end() && j != b.end())
{
if (i->first == j->first)
{
temp.insert(temp.end(), std::make_pair(i->first, i->second + j->second));
++i;
++j;
}
else if (i->first < j->first)
{
temp.insert(temp.end(), *i);
++i;
}
else
{
temp.insert(temp.end(), *j);
++j;
}
}
while (i != a.end())
{
temp.insert(temp.end(), *i);
++i;
}
while (j != b.end())
{
temp.insert(temp.end(), *j);
++j;
}
return temp;
}
template <typename T, typename U>
typename enable_if_c<is_matrix<T>::value && is_matrix<U>::value, matrix_add_exp<T,U> >::type add (
const T& a,
const U& b
)
{
return matrix_add_exp<T,U>(a.ref(),b.ref());
}
// ------------------------------------------------------------------------------------
template <typename T>
typename disable_if<is_matrix<T>,T>::type subtract (
const T& a,
const T& b
)
{
T temp;
typename T::const_iterator i = a.begin();
typename T::const_iterator j = b.begin();
while (i != a.end() && j != b.end())
{
if (i->first == j->first)
{
temp.insert(temp.end(), std::make_pair(i->first, i->second - j->second));
++i;
++j;
}
else if (i->first < j->first)
{
temp.insert(temp.end(), *i);
++i;
}
else
{
temp.insert(temp.end(), std::make_pair(j->first, -j->second));
++j;
}
}
while (i != a.end())
{
temp.insert(temp.end(), *i);
++i;
}
while (j != b.end())
{
temp.insert(temp.end(), std::make_pair(j->first, -j->second));
++j;
}
return temp;
}
template <typename T, typename U>
typename enable_if_c<is_matrix<T>::value && is_matrix<U>::value, matrix_subtract_exp<T,U> >::type subtract (
const T& a,
const U& b
)
{
return matrix_subtract_exp<T,U>(a.ref(),b.ref());
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
namespace impl namespace impl
{ {
......
...@@ -254,12 +254,44 @@ namespace dlib ...@@ -254,12 +254,44 @@ namespace dlib
); );
/*! /*!
requires requires
- a is an unsorted sparse vector - a is an unsorted sparse vector or a dlib::matrix
ensures ensures
- #a == a*value - #a == a*value
(i.e. multiplies every element of the vector a by value) (i.e. multiplies every element of the vector a by value)
!*/ !*/
// ----------------------------------------------------------------------------------------
template <typename T>
T add (
const T& a,
const T& b
);
/*!
requires
- a is a sparse vector or dlib::matrix
- b is a sparse vector or dlib::matrix
ensures
- returns a vector or matrix which represents a+b. If the inputs are
sparse vectors then the result is a sparse vector.
!*/
// ----------------------------------------------------------------------------------------
template <typename T>
T subtract (
const T& a,
const T& b
);
/*!
requires
- a is a sparse vector or dlib::matrix
- b is a sparse vector or dlib::matrix
ensures
- returns a vector or matrix which represents a-b. If the inputs are
sparse vectors then the result is a sparse vector.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename T> template <typename T>
......
...@@ -55,6 +55,92 @@ namespace ...@@ -55,6 +55,92 @@ namespace
DLIB_TEST(max(v) == 0); DLIB_TEST(max(v) == 0);
DLIB_TEST(min(v) == -9); DLIB_TEST(min(v) == -9);
{
matrix<double> a(2,2), b(2,2);
a = randm(2,2);
b = randm(2,2);
DLIB_TEST(equal(a-b, subtract(a,b)));
DLIB_TEST(equal(a+b, add(a,b)));
DLIB_TEST(equal(a-(b+b), subtract(a,b+b)));
DLIB_TEST(equal(a+b+b, add(a,b+b)));
}
{
std::map<unsigned long,double> a, b, c;
a[1] = 2;
a[3] = 5;
b[0] = 3;
b[1] = 1;
c = add(a,b);
DLIB_TEST(c.size() == 3);
DLIB_TEST(c[0] == 3);
DLIB_TEST(c[1] == 3);
DLIB_TEST(c[3] == 5);
c = subtract(a,b);
DLIB_TEST(c.size() == 3);
DLIB_TEST(c[0] == -3);
DLIB_TEST(c[1] == 1);
DLIB_TEST(c[3] == 5);
c = add(b,a);
DLIB_TEST(c.size() == 3);
DLIB_TEST(c[0] == 3);
DLIB_TEST(c[1] == 3);
DLIB_TEST(c[3] == 5);
c = subtract(b,a);
DLIB_TEST(c.size() == 3);
DLIB_TEST(c[0] == 3);
DLIB_TEST(c[1] == -1);
DLIB_TEST(c[3] == -5);
std::vector<std::pair<unsigned long,double> > aa, bb, cc;
aa.assign(a.begin(), a.end());
bb.assign(b.begin(), b.end());
cc = add(aa,bb);
DLIB_TEST(cc.size() == 3);
DLIB_TEST(cc[0].first == 0);
DLIB_TEST(cc[1].first == 1);
DLIB_TEST(cc[2].first == 3);
DLIB_TEST(cc[0].second == 3);
DLIB_TEST(cc[1].second == 3);
DLIB_TEST(cc[2].second == 5);
cc = subtract(aa,bb);
DLIB_TEST(cc.size() == 3);
DLIB_TEST(cc[0].first == 0);
DLIB_TEST(cc[1].first == 1);
DLIB_TEST(cc[2].first == 3);
DLIB_TEST(cc[0].second == -3);
DLIB_TEST(cc[1].second == 1);
DLIB_TEST(cc[2].second == 5);
cc = add(bb,aa);
DLIB_TEST(cc.size() == 3);
DLIB_TEST(cc[0].first == 0);
DLIB_TEST(cc[1].first == 1);
DLIB_TEST(cc[2].first == 3);
DLIB_TEST(cc[0].second == 3);
DLIB_TEST(cc[1].second == 3);
DLIB_TEST(cc[2].second == 5);
cc = subtract(bb,aa);
DLIB_TEST(cc.size() == 3);
DLIB_TEST(cc[0].first == 0);
DLIB_TEST(cc[1].first == 1);
DLIB_TEST(cc[2].first == 3);
DLIB_TEST(cc[0].second == 3);
DLIB_TEST(cc[1].second == -1);
DLIB_TEST(cc[2].second == -5);
}
} }
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment