Commit 9de4e129 authored by Davis King's avatar Davis King

Added a spec for the assignment problem validation functions and added

missing asserts.
parent 25e976fe
...@@ -23,6 +23,28 @@ namespace dlib ...@@ -23,6 +23,28 @@ namespace dlib
const std::vector<typename assignment_function::label_type>& labels const std::vector<typename assignment_function::label_type>& labels
) )
{ {
// make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
if (assigner.forces_assignment())
{
DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
"\t double test_assignment_function()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
}
else
{
DLIB_ASSERT(is_assignment_problem(samples, labels),
"\t double test_assignment_function()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
}
#endif
double total_right = 0; double total_right = 0;
double total = 0; double total = 0;
for (unsigned long i = 0; i < samples.size(); ++i) for (unsigned long i = 0; i < samples.size(); ++i)
...@@ -55,6 +77,37 @@ namespace dlib ...@@ -55,6 +77,37 @@ namespace dlib
const long folds const long folds
) )
{ {
// make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
if (trainer.forces_assignment())
{
DLIB_ASSERT(is_forced_assignment_problem(samples, labels) &&
1 < folds && folds <= static_cast<long>(samples.size()),
"\t double cross_validate_assignment_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t samples.size(): " << samples.size()
<< "\n\t folds: " << folds
<< "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
}
else
{
DLIB_ASSERT(is_assignment_problem(samples, labels) &&
1 < folds && folds <= static_cast<long>(samples.size()),
"\t double cross_validate_assignment_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t samples.size(): " << samples.size()
<< "\n\t folds: " << folds
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
}
#endif
typedef typename trainer_type::sample_type sample_type; typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::label_type label_type; typedef typename trainer_type::label_type label_type;
......
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_H__
#ifdef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_H__
#include <vector>
#include "../matrix.h"
#include "svm.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename assignment_function
>
double test_assignment_function (
const assignment_function& assigner,
const std::vector<typename assignment_function::sample_type>& samples,
const std::vector<typename assignment_function::label_type>& labels
);
/*!
requires
- is_assignment_problem(samples, labels)
- if (assigner.forces_assignment()) then
- is_forced_assignment_problem(samples, labels)
- assignment_function == an instantiation of the dlib::assignment_function
template or an object with a compatible interface.
ensures
- Tests assigner against the given samples and labels and returns the fraction
of assignments predicted correctly.
!*/
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
double cross_validate_assignment_trainer (
const trainer_type& trainer,
const std::vector<typename trainer_type::sample_type>& samples,
const std::vector<typename trainer_type::label_type>& labels,
const long folds
);
/*!
requires
- is_assignment_problem(samples, labels)
- if (trainer.forces_assignment()) then
- is_forced_assignment_problem(samples, labels)
- 1 < folds <= samples.size()
- trainer_type == dlib::structural_assignment_trainer or an object
with a compatible interface.
ensures
- performs k-fold cross validation by using the given trainer to solve the
given assignment learning problem for the given number of folds. Each fold
is tested using the output of the trainer and the fraction of assignments
predicted correctly is returned.
- The number of folds used is given by the folds argument.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_ABSTRACT_H__
...@@ -29,6 +29,8 @@ namespace dlib ...@@ -29,6 +29,8 @@ namespace dlib
typedef assignment_function<feature_extractor> trained_function_type; typedef assignment_function<feature_extractor> trained_function_type;
bool forces_assignment(
) const { return false; } // TODO
const assignment_function<feature_extractor> train ( const assignment_function<feature_extractor> train (
const std::vector<sample_type>& x, const std::vector<sample_type>& x,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment