Commit 195c7694 authored by Davis King's avatar Davis King

Made column major matrices directly wrap matlab matrix objects when used

inside mex files.  This way, if you use matrix_colmajor or fmatrix_colmajor
in a mex file it will not do any unnecessary copying or transposing.
parent 6178838e
......@@ -261,6 +261,20 @@ namespace mex_binding
std::string msg;
};
// -------------------------------------------------------
template <typename T>
struct is_column_major_matrix : public default_is_kind_value {};
template <
typename T,
long num_rows,
long num_cols,
typename mem_manager
>
struct is_column_major_matrix<matrix<T,num_rows,num_cols,mem_manager,column_major_layout> >
{ static const bool value = true; };
// -------------------------------------------------------
template <
......@@ -563,7 +577,10 @@ namespace mex_binding
sout << " argument " << arg_idx+1 << " must be a matrix of doubles";
throw invalid_args_exception(sout.str());
}
assign_mat(arg_idx, arg , pointer_to_matrix(mxGetPr(prhs), nc, nr));
if (is_column_major_matrix<T>::value)
arg._private_set_mxArray((mxArray*)prhs);
else
assign_mat(arg_idx, arg , pointer_to_matrix(mxGetPr(prhs), nc, nr));
}
else if (is_same_type<type, float>::value)
{
......@@ -574,7 +591,10 @@ namespace mex_binding
throw invalid_args_exception(sout.str());
}
assign_mat(arg_idx, arg , pointer_to_matrix((const float*)mxGetData(prhs), nc, nr));
if (is_column_major_matrix<T>::value)
arg._private_set_mxArray((mxArray*)prhs);
else
assign_mat(arg_idx, arg , pointer_to_matrix((const float*)mxGetData(prhs), nc, nr));
}
else if (is_same_type<type, bool>::value)
{
......@@ -922,6 +942,26 @@ namespace mex_binding
}
}
void assign_to_matlab(
mxArray*& plhs,
matrix_colmajor& item
)
{
// Don't need to do a copy if it's this kind of matrix since we can just
// pull the underlying mxArray out directly and thus avoid a copy.
plhs = item._private_release_mxArray();
}
void assign_to_matlab(
mxArray*& plhs,
fmatrix_colmajor& item
)
{
// Don't need to do a copy if it's this kind of matrix since we can just
// pull the underlying mxArray out directly and thus avoid a copy.
plhs = item._private_release_mxArray();
}
void assign_to_matlab(
mxArray*& plhs,
matlab_struct& item
......@@ -989,6 +1029,14 @@ namespace mex_binding
{
}
// ----------------------------------------------------------------------------------------
template <typename T>
void mark_non_persistent (const T&){}
void mark_non_persistent(matrix_colmajor& item) { item._private_mark_non_persistent(); }
void mark_non_persistent(fmatrix_colmajor& item) { item._private_mark_non_persistent(); }
// ----------------------------------------------------------------------------------------
template <
......@@ -1010,6 +1058,8 @@ namespace mex_binding
typename basic_type<arg1_type>::type A1;
mark_non_persistent(A1);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
......@@ -1036,6 +1086,9 @@ namespace mex_binding
typename basic_type<arg1_type>::type A1;
typename basic_type<arg2_type>::type A2;
mark_non_persistent(A1);
mark_non_persistent(A2);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1066,6 +1119,10 @@ namespace mex_binding
typename basic_type<arg2_type>::type A2;
typename basic_type<arg3_type>::type A3;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1100,6 +1157,11 @@ namespace mex_binding
typename basic_type<arg3_type>::type A3;
typename basic_type<arg4_type>::type A4;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1139,6 +1201,12 @@ namespace mex_binding
typename basic_type<arg4_type>::type A4;
typename basic_type<arg5_type>::type A5;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1183,6 +1251,13 @@ namespace mex_binding
typename basic_type<arg5_type>::type A5;
typename basic_type<arg6_type>::type A6;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
mark_non_persistent(A6);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1231,6 +1306,14 @@ namespace mex_binding
typename basic_type<arg6_type>::type A6;
typename basic_type<arg7_type>::type A7;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
mark_non_persistent(A6);
mark_non_persistent(A7);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1283,6 +1366,15 @@ namespace mex_binding
typename basic_type<arg7_type>::type A7;
typename basic_type<arg8_type>::type A8;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
mark_non_persistent(A6);
mark_non_persistent(A7);
mark_non_persistent(A8);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1339,6 +1431,16 @@ namespace mex_binding
typename basic_type<arg8_type>::type A8;
typename basic_type<arg9_type>::type A9;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
mark_non_persistent(A6);
mark_non_persistent(A7);
mark_non_persistent(A8);
mark_non_persistent(A9);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......@@ -1400,6 +1502,17 @@ namespace mex_binding
typename basic_type<arg9_type>::type A9;
typename basic_type<arg10_type>::type A10;
mark_non_persistent(A1);
mark_non_persistent(A2);
mark_non_persistent(A3);
mark_non_persistent(A4);
mark_non_persistent(A5);
mark_non_persistent(A6);
mark_non_persistent(A7);
mark_non_persistent(A8);
mark_non_persistent(A9);
mark_non_persistent(A10);
int i = 0;
if (i < nrhs && is_input_type<arg1_type>::value) {validate_and_populate_arg(i,prhs[i],A1); ++i;} ELSE_ASSIGN_ARG_1;
if (i < nrhs && is_input_type<arg2_type>::value) {validate_and_populate_arg(i,prhs[i],A2); ++i;} ELSE_ASSIGN_ARG_2;
......
......@@ -17,6 +17,10 @@
#include "matrix_op.h"
#include <utility>
#ifdef MATLAB_MEX_FILE
#include <mex.h>
#endif
#ifdef _MSC_VER
// Disable the following warnings for Visual Studio
......@@ -1239,6 +1243,26 @@ namespace dlib
return data(0);
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray(
mxArray* mem
)
{
data._private_set_mxArray(mem);
}
mxArray* _private_release_mxArray(
)
{
return data._private_release_mxArray();
}
void _private_mark_non_persistent()
{
data._private_mark_non_persistent();
}
#endif
void set_size (
long rows,
long cols
......@@ -1971,6 +1995,9 @@ namespace dlib
// ----------------------------------------------------------------------------------------
typedef matrix<double,0,0,default_memory_manager,column_major_layout> matrix_colmajor;
typedef matrix<float,0,0,default_memory_manager,column_major_layout> fmatrix_colmajor;
}
#ifdef _MSC_VER
......
......@@ -645,7 +645,18 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
/*!A matrix_colmajor
This is just a typedef of the matrix object that uses column major layout.
!*/
typedef matrix<double,0,0,default_memory_manager,column_major_layout> matrix_colmajor;
/*!A fmatrix_colmajor
This is just a typedef of the matrix object that uses column major layout.
!*/
typedef matrix<float,0,0,default_memory_manager,column_major_layout> fmatrix_colmajor;
// ----------------------------------------------------------------------------------------
template <
typename T,
long NR,
long NC,
......
......@@ -6,6 +6,9 @@
#include "../algs.h"
#include "matrix_fwd.h"
#include "matrix_data_layout_abstract.h"
#ifdef MATLAB_MEX_FILE
#include <mex.h>
#endif
// GCC 4.8 gives false alarms about some matrix operations going out of bounds. Disable
// these false warnings.
......@@ -180,6 +183,12 @@ namespace dlib
{
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T data[num_rows*num_cols];
};
......@@ -243,6 +252,12 @@ namespace dlib
{
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -318,6 +333,12 @@ namespace dlib
nr_ = nr;
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -396,6 +417,12 @@ namespace dlib
nc_ = nc;
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -476,6 +503,11 @@ namespace dlib
nc_ = nc;
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
long nr_;
......@@ -593,6 +625,12 @@ namespace dlib
{
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T data[num_cols*num_rows];
};
......@@ -656,6 +694,12 @@ namespace dlib
{
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -731,6 +775,12 @@ namespace dlib
nr_ = nr;
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -809,6 +859,12 @@ namespace dlib
nc_ = nc;
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
private:
T* data;
......@@ -869,6 +925,12 @@ namespace dlib
pool.swap(item.pool);
}
#ifdef MATLAB_MEX_FILE
void _private_set_mxArray ( mxArray* ) { DLIB_CASSERT(false, "This function should never be called."); }
mxArray* _private_release_mxArray(){DLIB_CASSERT(false, "This function should never be called."); }
void _private_mark_non_persistent() {DLIB_CASSERT(false, "This function should never be called."); }
#endif
long nr (
) const { return nr_; }
......@@ -894,7 +956,251 @@ namespace dlib
long nr_;
long nc_;
typename mem_manager::template rebind<T>::other pool;
};
};
#ifdef MATLAB_MEX_FILE
template <
long num_rows,
long num_cols
>
class layout<double,num_rows,num_cols,default_memory_manager,5> : noncopyable // when num_rows == 0 && num_cols == 0
{
public:
const static long NR = num_rows;
const static long NC = num_cols;
layout (
): data(0), nr_(0), nc_(0), make_persistent(true),set_by_private_set_mxArray(false),mem(0) { }
~layout ()
{
if (!set_by_private_set_mxArray && mem)
{
mxDestroyArray(mem);
mem = 0;
data = 0;
}
}
double& operator() (
long r,
long c
) { return data[c*nr_ + r]; }
const double& operator() (
long r,
long c
) const { return data[c*nr_ + r]; }
double& operator() (
long i
) { return data[i]; }
const double& operator() (
long i
) const { return data[i]; }
void _private_set_mxArray (
mxArray* mem_
)
{
// We don't own the pointer, so make note of that so we won't try to free
// it.
set_by_private_set_mxArray = true;
mem = mem_;
data = mxGetPr(mem);
nr_ = mxGetM(mem);
nc_ = mxGetN(mem);
}
mxArray* _private_release_mxArray()
{
DLIB_CASSERT(!make_persistent,"");
mxArray* temp = mem;
mem = 0;
set_by_private_set_mxArray = false;
data = 0;
nr_ = 0;
nc_ = 0;
return temp;
}
void _private_mark_non_persistent()
{
make_persistent = false;
}
void swap(
layout& item
)
{
std::swap(item.make_persistent,make_persistent);
std::swap(item.set_by_private_set_mxArray,set_by_private_set_mxArray);
std::swap(item.mem,mem);
std::swap(item.data,data);
std::swap(item.nc_,nc_);
std::swap(item.nr_,nr_);
}
long nr (
) const { return nr_; }
long nc (
) const { return nc_; }
void set_size (
long nr,
long nc
)
{
if (!set_by_private_set_mxArray && mem)
{
mxDestroyArray(mem);
mem = 0;
data = 0;
}
set_by_private_set_mxArray = false;
mem = mxCreateDoubleMatrix(nr, nc, mxREAL);
if (mem == 0)
throw std::bad_alloc();
if (make_persistent)
mexMakeArrayPersistent(mem);
data = mxGetPr(mem);
nr_ = nr;
nc_ = nc;
}
private:
double* data;
long nr_;
long nc_;
bool make_persistent;
bool set_by_private_set_mxArray;
mxArray* mem;
};
template <
long num_rows,
long num_cols
>
class layout<float,num_rows,num_cols,default_memory_manager,5> : noncopyable // when num_rows == 0 && num_cols == 0
{
public:
const static long NR = num_rows;
const static long NC = num_cols;
layout (
): data(0), nr_(0), nc_(0), make_persistent(true),set_by_private_set_mxArray(false),mem(0) { }
~layout ()
{
if (!set_by_private_set_mxArray && mem)
{
mxDestroyArray(mem);
mem = 0;
data = 0;
}
}
float& operator() (
long r,
long c
) { return data[c*nr_ + r]; }
const float& operator() (
long r,
long c
) const { return data[c*nr_ + r]; }
float& operator() (
long i
) { return data[i]; }
const float& operator() (
long i
) const { return data[i]; }
void _private_set_mxArray (
mxArray* mem_
)
{
// We don't own the pointer, so make note of that so we won't try to free
// it.
set_by_private_set_mxArray = true;
mem = mem_;
data = (float*)mxGetData(mem);
nr_ = mxGetM(mem);
nc_ = mxGetN(mem);
}
mxArray* _private_release_mxArray()
{
DLIB_CASSERT(!make_persistent,"");
mxArray* temp = mem;
mem = 0;
set_by_private_set_mxArray = false;
data = 0;
nr_ = 0;
nc_ = 0;
return temp;
}
void _private_mark_non_persistent()
{
make_persistent = false;
}
void swap(
layout& item
)
{
std::swap(item.make_persistent,make_persistent);
std::swap(item.set_by_private_set_mxArray,set_by_private_set_mxArray);
std::swap(item.mem,mem);
std::swap(item.data,data);
std::swap(item.nc_,nc_);
std::swap(item.nr_,nr_);
}
long nr (
) const { return nr_; }
long nc (
) const { return nc_; }
void set_size (
long nr,
long nc
)
{
if (!set_by_private_set_mxArray && mem)
{
mxDestroyArray(mem);
mem = 0;
data = 0;
}
set_by_private_set_mxArray = false;
mem = mxCreateNumericMatrix(nr, nc, mxSINGLE_CLASS, mxREAL);
if (mem == 0)
throw std::bad_alloc();
if (make_persistent)
mexMakeArrayPersistent(mem);
data = (float*)mxGetData(mem);
nr_ = nr;
nc_ = nc;
}
private:
float* data;
long nr_;
long nc_;
bool make_persistent;
bool set_by_private_set_mxArray;
mxArray* mem;
};
#endif
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment