Commit 15cff080 authored by Davis King's avatar Davis King

Added the roc_trainer_type object.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403043
parent 8901e32c
......@@ -16,6 +16,7 @@
#include "svm/pegasos.h"
#include "svm/sparse_kernel.h"
#include "svm/null_trainer.h"
#include "svm/roc_trainer.h"
#endif // DLIB_SVm_HEADER
......
// Copyright (C) 2009 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ROC_TRAINEr_H_
#define DLIB_ROC_TRAINEr_H_
#include "roc_trainer_abstract.h"
#include "../algs.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
class roc_trainer_type
{
public:
typedef typename trainer_type::kernel_type kernel_type;
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type;
typedef typename trainer_type::trained_function_type trained_function_type;
roc_trainer_type (
) : desired_accuracy(0), class_selection(0){}
roc_trainer_type (
const trainer_type& trainer_,
const scalar_type& desired_accuracy_,
const scalar_type& class_selection_
) : trainer(trainer_), desired_accuracy(desired_accuracy_), class_selection(class_selection_)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 <= desired_accuracy && desired_accuracy <= 1 &&
(class_selection == -1 || class_selection == +1),
"\t roc_trainer_type::roc_trainer_type()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t desired_accuracy: " << desired_accuracy
<< "\n\t class_selection: " << class_selection
);
}
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const trained_function_type train (
const in_sample_vector_type& samples,
const in_scalar_vector_type& labels
) const
/*!
requires
- is_binary_classification_problem(samples, labels) == true
!*/
{
// make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(samples, labels),
"\t roc_trainer_type::train()"
<< "\n\t invalid inputs were given to this function"
);
return do_train(vector_to_matrix(samples), vector_to_matrix(labels));
}
private:
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const trained_function_type do_train (
const in_sample_vector_type& samples,
const in_scalar_vector_type& labels
) const
{
trained_function_type df = trainer.train(samples, labels);
// clear out the old bias
df.b = 0;
// obtain all the scores from the df using all the class_selection labeled samples
std::vector<double> scores;
for (long i = 0; i < samples.size(); ++i)
{
if (labels(i) == class_selection)
scores.push_back(df(samples(i)));
}
if (class_selection == +1)
std::sort(scores.rbegin(), scores.rend());
else
std::sort(scores.begin(), scores.end());
// now pick out the index that gives us the desired accuracy with regards to selected class
unsigned long idx = static_cast<unsigned long>(desired_accuracy*scores.size() + 0.5);
if (idx >= scores.size())
idx = scores.size()-1;
df.b = scores[idx];
return df;
}
trainer_type trainer;
scalar_type desired_accuracy;
scalar_type class_selection;
};
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const roc_trainer_type<trainer_type> roc_c1_trainer (
const trainer_type& trainer,
const typename trainer_type::scalar_type& desired_accuracy
) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, +1); }
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const roc_trainer_type<trainer_type> roc_c2_trainer (
const trainer_type& trainer,
const typename trainer_type::scalar_type& desired_accuracy
) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, -1); }
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ROC_TRAINEr_H_
// Copyright (C) 2009 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_ROC_TRAINEr_ABSTRACT_
#ifdef DLIB_ROC_TRAINEr_ABSTRACT_
#include "../algs.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
class roc_trainer_type
{
/*!
REQUIREMENTS ON trainer_type
- trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer)
WHAT THIS OBJECT REPRESENTS
This object is a simple trainer post processor that allows you to
easily adjust the bias term in a trained decision_function object.
That is, this object lets you pick a point on the ROC curve and
it will adjust the bias term appropriately.
So for example, suppose you wanted to set the bias term so that
the accuracy of your decision function on +1 labeled samples was 99%.
To do this you would use an instance of this object declared as follows:
roc_trainer_type(your_trainer, 0.99, +1);
!*/
public:
typedef typename trainer_type::kernel_type kernel_type;
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type;
typedef typename trainer_type::trained_function_type trained_function_type;
roc_trainer_type (
);
/*!
ensures
- This object is in an uninitialized state. You must
construct a real one with the other constructor and assign it
to this instance before you use this object.
!*/
roc_trainer_type (
const trainer_type& trainer_,
const scalar_type& desired_accuracy_,
const scalar_type& class_selection_
);
/*!
requires
- 0 <= desired_accuracy_ <= 1
- class_selection_ == +1 or -1
ensures
- when training is performed using this object it will automatically
adjust the bias term in the returned decision function so that it
achieves the desired accuracy on the selected class type.
!*/
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const trained_function_type train (
const in_sample_vector_type& samples,
const in_scalar_vector_type& labels
) const
/*!
requires
- is_binary_classification_problem(samples, labels) == true
- x == a matrix or something convertible to a matrix via vector_to_matrix().
Also, x should contain sample_type objects.
- y == a matrix or something convertible to a matrix via vector_to_matrix().
Also, y should contain scalar_type objects.
ensures
- performs training using the trainer object given to this object's
constructor, then modifies the bias term in the returned decision function
as discussed above, and finally returns the decision function.
!*/
};
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const roc_trainer_type<trainer_type> roc_c1_trainer (
const trainer_type& trainer,
const typename trainer_type::scalar_type& desired_accuracy
) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, +1); }
/*!
requires
- 0 <= desired_accuracy <= 1
- trainer_type == some kind of batch trainer object that creates decision_function
objects (e.g. svm_nu_trainer)
ensures
- returns a roc_trainer_type object that has been instantiated with the given
arguments. The returned roc trainer will select the decision function
bias that gives the desired accuracy with respect to the +1 class.
!*/
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const roc_trainer_type<trainer_type> roc_c2_trainer (
const trainer_type& trainer,
const typename trainer_type::scalar_type& desired_accuracy
) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, -1); }
/*!
requires
- 0 <= desired_accuracy <= 1
- trainer_type == some kind of batch trainer object that creates decision_function
objects (e.g. svm_nu_trainer)
ensures
- returns a roc_trainer_type object that has been instantiated with the given
arguments. The returned roc trainer will select the decision function
bias that gives the desired accuracy with respect to the -1 class.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ROC_TRAINEr_ABSTRACT_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment