Commit 2e75e57f authored by Davis King's avatar Davis King

Fixed a compile time bug and another bug where the code inappropriately assumed a

sample_type was a dlib matrix.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403628
parent ac9f61b3
...@@ -296,17 +296,6 @@ namespace dlib ...@@ -296,17 +296,6 @@ namespace dlib
const in_scalar_vector_type& y const in_scalar_vector_type& y
) const ) const
{ {
// make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
"\t decision_function svm_c_ekm_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
);
scalar_type obj; scalar_type obj;
if (basis_loaded()) if (basis_loaded())
return do_train_user_basis(vector_to_matrix(x),vector_to_matrix(y),obj); return do_train_user_basis(vector_to_matrix(x),vector_to_matrix(y),obj);
...@@ -324,17 +313,6 @@ namespace dlib ...@@ -324,17 +313,6 @@ namespace dlib
scalar_type& svm_objective scalar_type& svm_objective
) const ) const
{ {
// make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
"\t decision_function svm_c_ekm_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
);
if (basis_loaded()) if (basis_loaded())
return do_train_user_basis(vector_to_matrix(x),vector_to_matrix(y),svm_objective); return do_train_user_basis(vector_to_matrix(x),vector_to_matrix(y),svm_objective);
else else
...@@ -360,6 +338,17 @@ namespace dlib ...@@ -360,6 +338,17 @@ namespace dlib
- trains an SVM with the user supplied basis - trains an SVM with the user supplied basis
!*/ !*/
{ {
// make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
"\t decision_function svm_c_ekm_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
);
if (ekm_stale) if (ekm_stale)
{ {
ekm.load(kern, basis); ekm.load(kern, basis);
...@@ -415,6 +404,17 @@ namespace dlib ...@@ -415,6 +404,17 @@ namespace dlib
scalar_type& svm_objective scalar_type& svm_objective
) const ) const
{ {
// make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
"\t decision_function svm_c_ekm_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
);
std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples(x.size()); std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples(x.size());
decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df; decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df;
...@@ -492,7 +492,7 @@ namespace dlib ...@@ -492,7 +492,7 @@ namespace dlib
break; break;
prev_svm_objective = svm_objective; prev_svm_objective = svm_objective;
std::vector<sample_type> new_basis_elements; std::vector<matrix<scalar_type,0,1, mem_manager_type> > new_basis_elements;
// now add more elements to the basis // now add more elements to the basis
unsigned long count = 0; unsigned long count = 0;
...@@ -534,7 +534,7 @@ namespace dlib ...@@ -534,7 +534,7 @@ namespace dlib
prev_ekm.get_transformation_to(ekm, prev_to_new, proj_part); prev_ekm.get_transformation_to(ekm, prev_to_new, proj_part);
sample_type temp; matrix<scalar_type,0,1, mem_manager_type> temp;
for (long i = 0; i < x.size(); ++i) for (long i = 0; i < x.size(); ++i)
{ {
// assign to temporary to avoid memory allocation that would result if we // assign to temporary to avoid memory allocation that would result if we
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment