Commit a66a7dd8 authored by Davis King's avatar Davis King

Added an initial version of the structural_object_detection_trainer. This is

a tool for learning the parameters for an object like scan_image_pyramid.
parent b795c19b
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H__
#define DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H__
#include "structural_object_detection_trainer_abstract.h"
#include "../algs.h"
#include "../optimization.h"
#include "structural_svm_object_detection_problem.h"
#include "../image_processing/object_detector.h"
#include "../image_processing/box_overlap_testing.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename image_scanner_type,
typename overlap_tester_type = test_box_overlap
>
class structural_object_detection_trainer : noncopyable
{
public:
typedef double scalar_type;
typedef default_memory_manager mem_manager_type;
typedef object_detector<image_scanner_type,overlap_tester_type> trained_function_type;
explicit structural_object_detection_trainer (
const image_scanner_type& scanner_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(scanner_.get_num_detection_templates() > 0,
"\t structural_object_detection_trainer::structural_object_detection_trainer(scanner_)"
<< "\n\t You can't have zero detection templates"
<< "\n\t this: " << this
);
C = 1;
verbose = false;
eps = 0.3;
num_threads = 2;
max_cache_size = 40;
overlap_eps = 0.5;
loss_per_missed_target = 1;
loss_per_false_alarm = 1;
scanner.copy_configuration(scanner_);
}
void set_overlap_tester (
const overlap_tester_type& tester
)
{
overlap_tester = tester;
}
overlap_tester_type get_overlap_tester (
) const
{
return overlap_tester;
}
void set_num_threads (
unsigned long num
)
{
num_threads = num;
}
unsigned long get_num_threads (
) const
{
return num_threads;
}
void set_epsilon (
scalar_type eps_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(eps_ > 0,
"\t void structural_object_detection_trainer::set_epsilon()"
<< "\n\t eps_ must be greater than 0"
<< "\n\t eps_: " << eps_
<< "\n\t this: " << this
);
eps = eps_;
}
const scalar_type get_epsilon (
) const { return eps; }
void set_max_cache_size (
unsigned long max_size
)
{
max_cache_size = max_size;
}
unsigned long get_max_cache_size (
) const
{
return max_cache_size;
}
void be_verbose (
)
{
verbose = true;
}
void be_quiet (
)
{
verbose = false;
}
void set_oca (
const oca& item
)
{
solver = item;
}
const oca get_oca (
) const
{
return solver;
}
void set_c (
scalar_type C_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(C_ > 0,
"\t void structural_object_detection_trainer::set_c()"
<< "\n\t C_ must be greater than 0"
<< "\n\t C_: " << C_
<< "\n\t this: " << this
);
C = C_;
}
const scalar_type get_c (
) const
{
return C;
}
void set_overlap_eps (
double eps
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < eps && eps < 1,
"\t void structural_object_detection_trainer::set_overlap_eps(eps)"
<< "\n\t Invalid inputs were given to this function "
<< "\n\t eps: " << eps
<< "\n\t this: " << this
);
overlap_eps = eps;
}
double get_overlap_eps (
) const
{
return overlap_eps;
}
double get_loss_per_missed_target (
) const
{
return loss_per_missed_target;
}
void set_loss_per_missed_target (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss > 0,
"\t void structural_object_detection_trainer::set_loss_per_missed_target(loss)"
<< "\n\t Invalid inputs were given to this function "
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_missed_target = loss;
}
double get_loss_per_false_alarm (
) const
{
return loss_per_false_alarm;
}
void set_loss_per_false_alarm (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss > 0,
"\t void structural_object_detection_trainer::set_loss_per_false_alarm(loss)"
<< "\n\t Invalid inputs were given to this function "
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_false_alarm = loss;
}
template <
typename image_array_type
>
const trained_function_type train (
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_learning_problem(images,truth_rects) == true,
"\t trained_function_type structural_object_detection_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t images.size(): " << images.size()
<< "\n\t truth_rects.size(): " << truth_rects.size()
<< "\n\t is_learning_problem(images,truth_rects): " << is_learning_problem(images,truth_rects)
);
structural_svm_object_detection_problem<image_scanner_type,overlap_tester_type,image_array_type >
svm_prob(scanner, overlap_tester, images, truth_rects, num_threads);
if (verbose)
svm_prob.be_verbose();
svm_prob.set_c(C);
svm_prob.set_epsilon(eps);
svm_prob.set_max_cache_size(max_cache_size);
svm_prob.set_overlap_eps(overlap_eps);
svm_prob.set_loss_per_missed_target(loss_per_missed_target);
svm_prob.set_loss_per_false_alarm(loss_per_false_alarm);
matrix<double,0,1> w;
solver(svm_prob,w);
return object_detector<image_scanner_type,overlap_tester_type>(scanner, overlap_tester, w);
}
private:
image_scanner_type scanner;
overlap_tester_type overlap_tester;
double C;
oca solver;
double eps;
double overlap_eps;
bool verbose;
unsigned long num_threads;
unsigned long max_cache_size;
double loss_per_missed_target;
double loss_per_false_alarm;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H__
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H_ABSTRACT__
#ifdef DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H_ABSTRACT__
#include "structural_svm_object_detection_problem_abstract.h"
#include "../image_processing/object_detector_abstract.h"
#include "../image_processing/box_overlap_testing_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename image_scanner_type,
typename overlap_tester_type = test_box_overlap
>
class structural_object_detection_trainer : noncopyable
{
/*!
WHAT THIS OBJECT REPRESENTS
!*/
public:
typedef double scalar_type;
typedef default_memory_manager mem_manager_type;
typedef object_detector<image_scanner_type,overlap_tester_type> trained_function_type;
explicit structural_object_detection_trainer (
const image_scanner_type& scanner
);
/*!
requires
- scanner.get_num_detection_templates() > 0
ensures
- #get_c() == 1
- this object isn't verbose
- #get_epsilon() == 0.3
- #get_num_threads() == 2
- #get_max_cache_size() == 40
- #get_overlap_eps() == 0.5
- #get_loss_per_missed_target() == 1
- #get_loss_per_false_alarm() == 1
- This object will attempt to learn a model for the given
scanner object when train() is called.
!*/
void set_overlap_tester (
const overlap_tester_type& tester
);
overlap_tester_type get_overlap_tester (
) const;
void set_num_threads (
unsigned long num
);
unsigned long get_num_threads (
) const;
void set_epsilon (
scalar_type eps
);
/*!
requires
- eps > 0
ensures
- #get_epsilon() == eps
!*/
const scalar_type get_epsilon (
) const;
void set_max_cache_size (
unsigned long max_size
);
unsigned long get_max_cache_size (
) const;
void be_verbose (
);
void be_quiet (
);
void set_oca (
const oca& item
);
const oca get_oca (
) const;
void set_c (
scalar_type C
);
/*!
requires
- C > 0
ensures
- #get_c() = C
!*/
const scalar_type get_c (
) const;
void set_overlap_eps (
double eps
);
/*!
requires
- 0 < eps < 1
ensures
- #get_overlap_eps() == eps
!*/
double get_overlap_eps (
) const;
double get_loss_per_missed_target (
) const;
void set_loss_per_missed_target (
double loss
);
/*!
requires
- loss > 0
ensures
- #get_loss_per_missed_target() == loss
!*/
double get_loss_per_false_alarm (
) const;
void set_loss_per_false_alarm (
double loss
);
/*!
requires
- loss > 0
ensures
- #get_loss_per_false_alarm() == loss
!*/
template <
typename image_array_type
>
const trained_function_type train (
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects
) const;
/*!
requires
- is_learning_problem(images, truth_rects) == true
- it must be valid to pass images[0] into the image_scanner_type::load() method.
ensures
-
!*/
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_STRUCTURAL_OBJECT_DETECTION_TRAiNER_H_ABSTRACT__
......@@ -8,6 +8,7 @@
#include "svm/structural_svm_problem_threaded.h"
#include "svm/structural_svm_distributed.h"
#include "svm/structural_svm_object_detection_problem.h"
#include "svm/structural_object_detection_trainer.h"
#endif // DLIB_SVm_THREADED_HEADER
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment