Commit f467236c authored by jpblackburn's avatar jpblackburn Committed by Davis E. King

Allow cross_validate_trainer_threaded to use non-double data (#883)

parent 2b9b3fef
......@@ -47,12 +47,12 @@ namespace dlib
{
template <
typename trainer_type,
typename matrix_type,
typename mem_manager_type,
typename in_sample_vector_type
>
void operator()(
job<trainer_type,in_sample_vector_type>& j,
matrix_type& result
matrix<double,1,2,mem_manager_type>& result
)
{
try
......@@ -83,7 +83,7 @@ namespace dlib
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type>
const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer_threaded_impl (
const trainer_type& trainer,
const in_sample_vector_type& x,
......@@ -93,7 +93,6 @@ namespace dlib
)
{
using namespace dlib::cvtti_helpers;
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::mem_manager_type mem_manager_type;
// make sure requires clause is not broken
......@@ -137,7 +136,7 @@ namespace dlib
std::vector<future<job<trainer_type,in_sample_vector_type> > > jobs(folds);
std::vector<future<matrix<scalar_type, 1, 2, mem_manager_type> > > results(folds);
std::vector<future<matrix<double, 1, 2, mem_manager_type> > > results(folds);
for (long i = 0; i < folds; ++i)
......@@ -212,7 +211,7 @@ namespace dlib
} // for (long i = 0; i < folds; ++i)
matrix<scalar_type, 1, 2, mem_manager_type> res;
matrix<double, 1, 2, mem_manager_type> res;
set_all_elements(res,0);
// now compute the total results
......@@ -221,7 +220,7 @@ namespace dlib
res += results[i].get();
}
return res/(scalar_type)folds;
return res/(double)folds;
}
template <
......@@ -229,7 +228,7 @@ namespace dlib
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type>
const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer_threaded (
const trainer_type& trainer,
const in_sample_vector_type& x,
......
......@@ -17,7 +17,7 @@ namespace dlib
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type>
const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer_threaded (
const trainer_type& trainer,
const in_sample_vector_type& x,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment