Commit 28b5980f authored by Davis King's avatar Davis King

Added a new scale_rows() function. I also overloaded the * operator so that the

expressions mat*diagm(v) and diagm(v)*mat get bound to calls to scale_columns() and
scale_rows() respectively.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403864
parent 11657624
......@@ -2565,13 +2565,20 @@ namespace dlib
const matrix_exp<EXP2>& v
)
{
// Both arguments to this function must contain the same type of element
COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
COMPILE_TIME_ASSERT(EXP2::NC == 1 || EXP2::NC == 0);
COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NR || EXP1::NC == 0 || EXP2::NR == 0);
// The v argument must be a row or column vector.
COMPILE_TIME_ASSERT((EXP2::NC == 1 || EXP2::NC == 0) || (EXP2::NR == 1 || EXP2::NR == 0));
DLIB_ASSERT(is_col_vector(v) == true && v.size() == m.nc(),
// figure out the compile time known length of v
const long v_len = ((EXP2::NR)*(EXP2::NC) == 0)? 0 : (tmax<EXP2::NR,EXP2::NC>::value);
// the length of v must match the number of columns in m
COMPILE_TIME_ASSERT(EXP1::NC == v_len || EXP1::NC == 0 || v_len == 0);
DLIB_ASSERT(is_vector(v) == true && v.size() == m.nc(),
"\tconst matrix_exp scale_columns(m, v)"
<< "\n\tv must be a column vector and its length must match the number of columns in m"
<< "\n\tv must be a row or column vector and its length must match the number of columns in m"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tv.nr(): " << v.nr()
......@@ -2581,6 +2588,98 @@ namespace dlib
return matrix_op<op>(op(m.ref(),v.ref()));
}
// ----------------------------------------------------------------------------------------
// turn expressions of the form mat*diagm(v) into scale_columns(mat, v)
template <
typename EXP1,
typename EXP2
>
const matrix_op<op_scale_columns<EXP1,EXP2> > operator* (
const matrix_exp<EXP1>& m,
const matrix_exp<matrix_op<op_diagm<EXP2> > >& v
)
{
std::cout << "yay" << std::endl;
return scale_columns(m,v.ref().op.m);
}
// ----------------------------------------------------------------------------------------
template <typename M1, typename M2>
struct op_scale_rows
{
op_scale_rows(const M1& m1_, const M2& m2_) : m1(m1_), m2(m2_) {}
const M1& m1;
const M2& m2;
const static long cost = M1::cost + M2::cost + 1;
typedef typename M1::type type;
typedef const typename M1::type const_ret_type;
typedef typename M1::mem_manager_type mem_manager_type;
typedef typename M1::layout_type layout_type;
const static long NR = M1::NR;
const static long NC = M1::NC;
const_ret_type apply ( long r, long c) const { return m1(r,c)*m2(r); }
long nr () const { return m1.nr(); }
long nc () const { return m1.nc(); }
template <typename U> bool aliases ( const matrix_exp<U>& item) const
{ return m1.aliases(item) || m2.aliases(item) ; }
template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
{ return m1.destructively_aliases(item) || m2.aliases(item); }
};
template <
typename EXP1,
typename EXP2
>
const matrix_op<op_scale_rows<EXP1,EXP2> > scale_rows (
const matrix_exp<EXP1>& m,
const matrix_exp<EXP2>& v
)
{
// Both arguments to this function must contain the same type of element
COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
// The v argument must be a row or column vector.
COMPILE_TIME_ASSERT((EXP2::NC == 1 || EXP2::NC == 0) || (EXP2::NR == 1 || EXP2::NR == 0));
// figure out the compile time known length of v
const long v_len = ((EXP2::NR)*(EXP2::NC) == 0)? 0 : (tmax<EXP2::NR,EXP2::NC>::value);
// the length of v must match the number of rows in m
COMPILE_TIME_ASSERT(EXP1::NR == v_len || EXP1::NR == 0 || v_len == 0);
DLIB_ASSERT(is_vector(v) == true && v.size() == m.nr(),
"\tconst matrix_exp scale_rows(m, v)"
<< "\n\tv must be a row or column vector and its length must match the number of rows in m"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tv.nr(): " << v.nr()
<< "\n\tv.nc(): " << v.nc()
);
typedef op_scale_rows<EXP1,EXP2> op;
return matrix_op<op>(op(m.ref(),v.ref()));
}
// ----------------------------------------------------------------------------------------
// turn expressions of the form diagm(v)*mat into scale_rows(mat, v)
template <
typename EXP1,
typename EXP2
>
const matrix_op<op_scale_rows<EXP1,EXP2> > operator* (
const matrix_exp<matrix_op<op_diagm<EXP2> > >& v,
const matrix_exp<EXP1>& m
)
{
std::cout << "yay" << std::endl;
return scale_rows(m,v.ref().op.m);
}
// ----------------------------------------------------------------------------------------
struct sort_columns_sort_helper
......
......@@ -806,7 +806,7 @@ namespace dlib
);
/*!
requires
- is_col_vector(v) == true
- is_vector(v) == true
- v.size() == m.nc()
- m and v both contain the same type of element
ensures
......@@ -817,6 +817,35 @@ namespace dlib
R(r,c) == m(r,c) * v(c)
- i.e. R is the result of multiplying each of m's columns by
the corresponding scalar in v.
- Note that this function is identical to the expression m*diagm(v).
That is, the * operator is overloaded for this case and will invoke
scale_columns() automatically as appropriate.
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp scale_rows (
const matrix_exp& m,
const matrix_exp& v
);
/*!
requires
- is_vector(v) == true
- v.size() == m.nr()
- m and v both contain the same type of element
ensures
- returns a matrix R such that:
- R::type == the same type that was in m and v.
- R has the same dimensions as m.
- for all valid r and c:
R(r,c) == m(r,c) * v(r)
- i.e. R is the result of multiplying each of m's rows by
the corresponding scalar in v.
- Note that this function is identical to the expression diagm(v)*m.
That is, the * operator is overloaded for this case and will invoke
scale_rows() automatically as appropriate.
!*/
// ----------------------------------------------------------------------------------------
......
......@@ -66,6 +66,7 @@ set (tests
rand.cpp
read_write_mutex.cpp
reference_counter.cpp
scale_rows_columns.cpp
sequence.cpp
serialize.cpp
set.cpp
......
......@@ -76,6 +76,7 @@ SRC += queue.cpp
SRC += rand.cpp
SRC += read_write_mutex.cpp
SRC += reference_counter.cpp
SRC += scale_rows_columns.cpp
SRC += sequence.cpp
SRC += serialize.cpp
SRC += set.cpp
......
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <dlib/matrix.h>
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <vector>
#include "../stl_checked.h"
#include "../array.h"
#include "../rand.h"
#include "tester.h"
#include <dlib/memory_manager_stateless.h>
#include <dlib/array2d.h>
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
logger dlog("test.scale_rows_columns");
void matrix_test (
)
/*!
ensures
- runs tests on the matrix stuff compliance with the specs
!*/
{
typedef memory_manager_stateless<char>::kernel_2_2a MM;
print_spinner();
{
matrix<double,3,3> m = round(10*randm(3,3));
matrix<double,3,1> v = round(10*randm(3,1));
DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
}
{
matrix<double,3,3> m = round(10*randm(3,3));
matrix<double,1,3> v = round(10*randm(1,3));
DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
}
{
matrix<double> m = round(10*randm(3,3));
matrix<double,1,3> v = round(10*randm(1,3));
DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
}
{
matrix<double> m = round(10*randm(3,3));
matrix<double,0,3> v = round(10*randm(1,3));
DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
}
{
matrix<double> m = round(10*randm(3,3));
matrix<double,1,0> v = round(10*randm(1,3));
DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
}
{
matrix<double> m = round(10*randm(3,3));
matrix<double,3,0> v = round(10*randm(3,1));
DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
}
{
matrix<double> m = round(10*randm(3,3));
matrix<double,0,1> v = round(10*randm(3,1));
DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
}
{
matrix<double,3,3> m = round(10*randm(3,3));
matrix<double,3,0> v = round(10*randm(3,1));
DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
}
{
matrix<double,3,3> m = round(10*randm(3,3));
matrix<double,0,1> v = round(10*randm(3,1));
DLIB_TEST(equal( m*diagm(v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( scale_columns(m,v) , m*tmp(diagm(v)) ));
DLIB_TEST(equal( diagm(v)*m , tmp(diagm(v))*m ));
DLIB_TEST(equal( scale_rows(m,v) , tmp(diagm(v))*m ));
}
{
matrix<double,3,5> m = round(10*randm(3,5));
matrix<double,0,1> v1 = round(10*randm(5,1));
matrix<double,0,1> v2 = round(10*randm(3,1));
DLIB_TEST(equal( m*diagm(v1) , m*tmp(diagm(v1)) ));
DLIB_TEST(equal( scale_columns(m,v1) , m*tmp(diagm(v1)) ));
DLIB_TEST(equal( diagm(v2)*m , tmp(diagm(v2))*m ));
DLIB_TEST(equal( scale_rows(m,v2) , tmp(diagm(v2))*m ));
}
{
matrix<double,3,5> m = round(10*randm(3,5));
matrix<double,5,1> v1 = round(10*randm(5,1));
matrix<double,3,1> v2 = round(10*randm(3,1));
DLIB_TEST(equal( m*diagm(v1) , m*tmp(diagm(v1)) ));
DLIB_TEST(equal( scale_columns(m,v1) , m*tmp(diagm(v1)) ));
DLIB_TEST(equal( diagm(v2)*m , tmp(diagm(v2))*m ));
DLIB_TEST(equal( scale_rows(m,v2) , tmp(diagm(v2))*m ));
}
}
class matrix_tester : public tester
{
public:
matrix_tester (
) :
tester ("test_scale_rows_columns",
"Runs tests on the scale_rows and scale_columns functions.")
{}
void perform_test (
)
{
for (int i = 0; i < 10; ++i)
matrix_test();
}
} a;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment