Commit 8c41c850 authored by Davis King's avatar Davis King

Made many of the mat() converters bind the resulting matrix expressions into

the BLAS bindings.
parent b5038f78
......@@ -10,6 +10,7 @@
#include "matrix_assign_fwd.h"
#include "matrix_default_mul.h"
#include "matrix_conj_trans.h"
#include "matrix_mat.h"
namespace dlib
{
......@@ -159,6 +160,29 @@ namespace dlib
const static int value = general_matrix;
};
template < typename T, typename MM >
struct matrix_type_id<matrix_op<op_array2d_to_mat<array2d<T,MM> > > >
{ const static int value = general_matrix; };
template < typename T, typename MM >
struct matrix_type_id<matrix_op<op_array_to_mat<array<T,MM> > > >
{ const static int value = column_matrix; };
template < typename value_type, typename alloc >
struct matrix_type_id<matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > > >
{ const static int value = column_matrix; };
template < typename value_type, typename alloc >
struct matrix_type_id<matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > > >
{ const static int value = column_matrix; };
template < typename T >
struct matrix_type_id<matrix_op<op_pointer_to_col_vect<T> > >
{ const static int value = column_matrix; };
template < typename T >
struct matrix_type_id<matrix_op<op_pointer_to_mat<T> > >
{ const static int value = general_matrix; };
// ------------------------------------------------------------------------------------
template <typename T, typename U>
......
......@@ -408,8 +408,34 @@ namespace dlib
template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); }
template <typename T, typename MM>
int get_ld (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& m) { return m.nc(); }
template <typename T, typename MM>
int get_ld (const matrix_op<op_array_to_mat<array<T,MM> > >& m) { return m.nc(); }
template < typename value_type, typename alloc >
int get_ld (const matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > >& m) { return m.nc(); }
template < typename value_type, typename alloc >
int get_ld (const matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > >& m) { return m.nc(); }
template <typename T>
int get_ld (const matrix_op<op_pointer_to_col_vect<T> >& m) { return m.nc(); }
template <typename T>
int get_ld (const matrix_op<op_pointer_to_mat<T> >& m) { return m.nc(); }
// --------
template <typename T, typename MM>
int get_inc (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& ) { return 1; }
template <typename T, typename MM>
int get_inc (const matrix_op<op_array_to_mat<array<T,MM> > >& ) { return 1; }
template < typename value_type, typename alloc >
int get_inc (const matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > >& ) { return 1; }
template < typename value_type, typename alloc >
int get_inc (const matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > >& ) { return 1; }
template <typename T>
int get_inc (const matrix_op<op_pointer_to_col_vect<T> >& ) { return 1; }
template <typename T>
int get_inc (const matrix_op<op_pointer_to_mat<T> >& ) { return 1; }
template <typename T, long NR, long NC, typename MM, typename L>
int get_inc (const matrix<T,NR,NC,MM,L>& ) { return 1; }
......@@ -522,6 +548,19 @@ namespace dlib
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_sub_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, typename MM>
const T* get_ptr (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& m) { return &m.op.array[0][0]; }
template <typename T, typename MM>
const T* get_ptr (const matrix_op<op_array_to_mat<array<T,MM> > >& m) { return &m.op.vect[0]; }
template < typename T, typename alloc >
const T* get_ptr (const matrix_op<op_std_vect_to_mat<std::vector<T,alloc> > >& m) { return &m.op.vect[0]; }
template < typename T, typename alloc >
const T* get_ptr (const matrix_op<op_std_vect_to_mat<std_vector_c<T,alloc> > >& m) { return &m.op.vect[0]; }
template <typename T>
const T* get_ptr (const matrix_op<op_pointer_to_col_vect<T> >& m) { return m.op.ptr; }
template <typename T>
const T* get_ptr (const matrix_op<op_pointer_to_mat<T> >& m) { return m.op.ptr; }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......
......@@ -42,6 +42,61 @@ namespace
)
{}
void test_mat_bindings()
{
using namespace dlib;
using namespace dlib::blas_bindings;
matrix<double,1,0> rv(10);
matrix<double,0,1> cv(10);
double val;
rv = 1; cv = 1;
counter_dot() = 0;
val = rv*cv;
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
rv = 1; cv = 1;
counter_dot() = 0;
val = rv*mat(&cv(0),cv.size());
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
rv = 1; cv = 1;
counter_dot() = 0;
val = trans(mat(&rv(0),rv.size()))*mat(&cv(0),cv.size());
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
std::vector<double> sv(10,1);
rv = 1;
counter_dot() = 0;
val = trans(mat(&rv(0),rv.size()))*mat(sv);
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = trans(mat(sv))*mat(sv);
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
std_vector_c<double> svc(10,1);
counter_dot() = 0;
val = trans(mat(svc))*mat(svc);
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
dlib::array<double> arr(10);
for (unsigned int i = 0; i < arr.size(); ++i)
arr[i] = 1;
counter_dot() = 0;
val = trans(mat(arr))*mat(arr);
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
}
template <typename matrix_type, typename cv_type, typename rv_type>
void test_dot_stuff(
matrix_type& m,
......@@ -238,6 +293,8 @@ namespace
}
test_mat_bindings();
print_spinner();
}
};
......
......@@ -258,6 +258,21 @@ namespace
test_gemm_stuff_conj(c);
}
{
using namespace dlib;
using namespace dlib::blas_bindings;
array2d<double> a(100,100);
array2d<double> b(100,100);
matrix<double> c;
counter_gemm() = 0;
c = mat(a)*mat(b);
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
c = trans(2*mat(a)*mat(b));
DLIB_TEST(counter_gemm() == 1);
}
print_spinner();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment