Commit 2f257cb0 authored by Davis King's avatar Davis King

Added a dot() function for matrix objects.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403158
parent 36a8b0dc
......@@ -370,6 +370,98 @@ namespace dlib
return exp(m.ref());
}
// ----------------------------------------------------------------------------------------
template <
typename EXP1,
typename EXP2
>
typename enable_if_c<(EXP1::NR != 1 && EXP1::NC != 1) || (EXP2::NR != 1 && EXP2::NC != 1),
typename EXP1::type>::type
dot (
const matrix_exp<EXP1>& m1,
const matrix_exp<EXP2>& m2
)
{
// You are getting an error on this line because you are trying to
// compute the dot product between two matrices that aren't both vectors (i.e.
// they aren't column or row matrices).
COMPILE_TIME_ASSERT(EXP1::NR*EXP1::NC == 0 ||
EXP2::NR*EXP2::NC == 0);
DLIB_ASSERT(is_vector(m1) && is_vector(m2) && m1.size() == m2.size(),
"\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
<< "\n\t You can only compute the dot product between vectors of equal length"
<< "\n\t is_vector(m1): " << is_vector(m1)
<< "\n\t is_vector(m2): " << is_vector(m2)
<< "\n\t m1.size(): " << m1.size()
<< "\n\t m2.size(): " << m2.size()
);
if (is_col_vector(m1) && is_col_vector(m2)) return (trans(m1)*m2)(0);
if (is_col_vector(m1) && is_row_vector(m2)) return (m2*m1)(0);
if (is_row_vector(m1) && is_col_vector(m2)) return (m1*m2)(0);
//if (is_row_vector(m1) && is_row_vector(m2))
return (m1*trans(m2))(0);
}
template < typename EXP1, typename EXP2 >
typename enable_if_c<EXP1::NR == 1 && EXP2::NR == 1, typename EXP1::type>::type
dot ( const matrix_exp<EXP1>& m1, const matrix_exp<EXP2>& m2)
{
DLIB_ASSERT(m1.size() == m2.size(),
"\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
<< "\n\t You can only compute the dot product between vectors of equal length"
<< "\n\t m1.size(): " << m1.size()
<< "\n\t m2.size(): " << m2.size()
);
return m1*trans(m2);
}
template < typename EXP1, typename EXP2 >
typename enable_if_c<EXP1::NR == 1 && EXP2::NC == 1, typename EXP1::type>::type
dot ( const matrix_exp<EXP1>& m1, const matrix_exp<EXP2>& m2)
{
DLIB_ASSERT(m1.size() == m2.size(),
"\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
<< "\n\t You can only compute the dot product between vectors of equal length"
<< "\n\t m1.size(): " << m1.size()
<< "\n\t m2.size(): " << m2.size()
);
return m1*m2;
}
template < typename EXP1, typename EXP2 >
typename enable_if_c<EXP1::NC == 1 && EXP2::NR == 1, typename EXP1::type>::type
dot ( const matrix_exp<EXP1>& m1, const matrix_exp<EXP2>& m2)
{
DLIB_ASSERT(m1.size() == m2.size(),
"\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
<< "\n\t You can only compute the dot product between vectors of equal length"
<< "\n\t m1.size(): " << m1.size()
<< "\n\t m2.size(): " << m2.size()
);
return m2*m1;
}
template < typename EXP1, typename EXP2 >
typename enable_if_c<EXP1::NC == 1 && EXP2::NC == 1, typename EXP1::type>::type
dot ( const matrix_exp<EXP1>& m1, const matrix_exp<EXP2>& m2)
{
DLIB_ASSERT(m1.size() == m2.size(),
"\t type dot(const matrix_exp& m1, const matrix_exp& m2)"
<< "\n\t You can only compute the dot product between vectors of equal length"
<< "\n\t m1.size(): " << m1.size()
<< "\n\t m2.size(): " << m2.size()
);
return trans(m1)*m2;
}
// ----------------------------------------------------------------------------------------
template <long R, long C>
......
......@@ -52,6 +52,21 @@ namespace dlib
- returns the transpose of the matrix m
!*/
// ----------------------------------------------------------------------------------------
const matrix_type::type dot (
const matrix_exp& m1,
const matrix_exp& m2
);
/*!
requires
- is_vector(m1) == true
- is_vector(m2) == true
- m1.size() == m2.size()
ensures
- returns the dot product between m1 and m2.
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp lowerm (
......
......@@ -1146,6 +1146,56 @@ namespace
}
{
matrix<double> m1, m2;
m1.set_size(3,1);
m2.set_size(1,3);
m1 = 1,2,3;
m2 = 4,5,6;
DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6);
}
{
matrix<double,3,1> m1, m2;
m1.set_size(3,1);
m2.set_size(3,1);
m1 = 1,2,3;
m2 = 4,5,6;
DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6);
}
{
matrix<double,1,3> m1, m2;
m1.set_size(1,3);
m2.set_size(1,3);
m1 = 1,2,3;
m2 = 4,5,6;
DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6);
}
{
matrix<double,1,3> m1;
matrix<double> m2;
m1.set_size(1,3);
m2.set_size(3,1);
m1 = 1,2,3;
m2 = 4,5,6;
DLIB_TEST(dot(m1, m2) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(m1, trans(m2)) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(trans(m1), m2) == 1*4 + 2*5 + 3*6);
DLIB_TEST(dot(trans(m1), trans(m2)) == 1*4 + 2*5 + 3*6);
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment