Commit 4805117a authored by Davis King's avatar Davis King

Added a probabilistic trainer adapter.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404075
parent d1873299
......@@ -646,6 +646,48 @@ namespace dlib
return probabilistic_function<typename trainer_type::trained_function_type>( A, B, trainer.train(x,y) );
}
// ----------------------------------------------------------------------------------------
template <typename trainer_type>
struct trainer_adapter_probabilistic
{
typedef probabilistic_function<typename trainer_type::trained_function_type> trained_function_type;
const trainer_type& trainer;
const long folds;
trainer_adapter_probabilistic (
const trainer_type& trainer_,
const long folds_
) : trainer(trainer_),folds(folds_) {}
template <
typename sample_type,
typename scalar_type,
typename alloc_type1,
typename alloc_type2
>
const trained_function_type train (
const std::vector<sample_type,alloc_type1>& samples,
const std::vector<scalar_type,alloc_type2>& labels
) const
{
return train_probabilistic_decision_function(trainer, samples, labels, folds);
}
};
template <
typename trainer_type
>
trainer_adapter_probabilistic<trainer_type> probabilistic (
const trainer_type& trainer,
const long folds
)
{
return trainer_adapter_probabilistic<trainer_type>(trainer,folds);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......
......@@ -96,6 +96,24 @@ namespace dlib
- std::bad_alloc
!*/
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
trainer_adapter_probabilistic<trainer_type> probabilistic (
const trainer_type& trainer,
const long folds
);
/*!
requires
- 1 < folds <= x.size()
- trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer)
ensures
- returns a trainer adapter TA such that calling TA.train(samples, labels)
returns the same object as calling train_probabilistic_decision_function(trainer,samples,labels,folds).
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Miscellaneous functions
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment