Commit f467236c authored by jpblackburn's avatar jpblackburn Committed by Davis E. King

Allow cross_validate_trainer_threaded to use non-double data (#883)

parent 2b9b3fef
...@@ -47,12 +47,12 @@ namespace dlib ...@@ -47,12 +47,12 @@ namespace dlib
{ {
template < template <
typename trainer_type, typename trainer_type,
typename matrix_type, typename mem_manager_type,
typename in_sample_vector_type typename in_sample_vector_type
> >
void operator()( void operator()(
job<trainer_type,in_sample_vector_type>& j, job<trainer_type,in_sample_vector_type>& j,
matrix_type& result matrix<double,1,2,mem_manager_type>& result
) )
{ {
try try
...@@ -83,7 +83,7 @@ namespace dlib ...@@ -83,7 +83,7 @@ namespace dlib
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer_threaded_impl ( cross_validate_trainer_threaded_impl (
const trainer_type& trainer, const trainer_type& trainer,
const in_sample_vector_type& x, const in_sample_vector_type& x,
...@@ -93,7 +93,6 @@ namespace dlib ...@@ -93,7 +93,6 @@ namespace dlib
) )
{ {
using namespace dlib::cvtti_helpers; using namespace dlib::cvtti_helpers;
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::mem_manager_type mem_manager_type; typedef typename trainer_type::mem_manager_type mem_manager_type;
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -137,7 +136,7 @@ namespace dlib ...@@ -137,7 +136,7 @@ namespace dlib
std::vector<future<job<trainer_type,in_sample_vector_type> > > jobs(folds); std::vector<future<job<trainer_type,in_sample_vector_type> > > jobs(folds);
std::vector<future<matrix<scalar_type, 1, 2, mem_manager_type> > > results(folds); std::vector<future<matrix<double, 1, 2, mem_manager_type> > > results(folds);
for (long i = 0; i < folds; ++i) for (long i = 0; i < folds; ++i)
...@@ -212,7 +211,7 @@ namespace dlib ...@@ -212,7 +211,7 @@ namespace dlib
} // for (long i = 0; i < folds; ++i) } // for (long i = 0; i < folds; ++i)
matrix<scalar_type, 1, 2, mem_manager_type> res; matrix<double, 1, 2, mem_manager_type> res;
set_all_elements(res,0); set_all_elements(res,0);
// now compute the total results // now compute the total results
...@@ -221,7 +220,7 @@ namespace dlib ...@@ -221,7 +220,7 @@ namespace dlib
res += results[i].get(); res += results[i].get();
} }
return res/(scalar_type)folds; return res/(double)folds;
} }
template < template <
...@@ -229,7 +228,7 @@ namespace dlib ...@@ -229,7 +228,7 @@ namespace dlib
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer_threaded ( cross_validate_trainer_threaded (
const trainer_type& trainer, const trainer_type& trainer,
const in_sample_vector_type& x, const in_sample_vector_type& x,
......
...@@ -17,7 +17,7 @@ namespace dlib ...@@ -17,7 +17,7 @@ namespace dlib
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer_threaded ( cross_validate_trainer_threaded (
const trainer_type& trainer, const trainer_type& trainer,
const in_sample_vector_type& x, const in_sample_vector_type& x,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment