Commit ddc44067 authored by Davis King's avatar Davis King

Added a simple python interface for training fhog object detectors.

parent 15207aad
......@@ -8,6 +8,7 @@
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
#include <dlib/image_processing/frontal_face_detector.h>
#include <dlib/gui_widgets.h>
#include "simple_object_detector.h"
using namespace dlib;
......@@ -120,6 +121,38 @@ std::vector<rectangle> run_detector (
}
// ----------------------------------------------------------------------------------------
struct simple_object_detector_py
{
simple_object_detector detector;
unsigned int upsampling_amount;
std::vector<rectangle> run_detector1 (object img, const unsigned int upsampling_amount_)
{ return ::run_detector(detector, img, upsampling_amount_); }
std::vector<rectangle> run_detector2 (object img)
{ return ::run_detector(detector, img, upsampling_amount); }
};
void serialize (const simple_object_detector_py& item, std::ostream& out)
{
int version = 1;
serialize(item.detector, out);
serialize(version, out);
serialize(item.upsampling_amount, out);
}
void deserialize (simple_object_detector_py& item, std::istream& in)
{
int version = 0;
deserialize(item.detector, in);
deserialize(version, in);
if (version != 1)
throw dlib::serialization_error("Unexpected version found while deserializing a simple_object_detector.");
deserialize(item.upsampling_amount, in);
}
// ----------------------------------------------------------------------------------------
void image_window_set_image (
......@@ -163,16 +196,11 @@ boost::shared_ptr<image_window> make_image_window_from_image_and_title(object im
// ----------------------------------------------------------------------------------------
boost::shared_ptr<frontal_face_detector> load_fhog_object_detector_from_file (
const std::string& filename
)
string print_simple_test_results(const simple_test_results& r)
{
ifstream fin(filename.c_str(), ios::binary);
if (!fin)
throw dlib::error("Unable to open " + filename);
boost::shared_ptr<frontal_face_detector> detector(new frontal_face_detector());
deserialize(*detector, fin);
return detector;
std::ostringstream sout;
sout << "precision: "<<r.precision << ", recall: "<< r.recall << ", average precision: " << r.average_precision;
return sout.str();
}
// ----------------------------------------------------------------------------------------
......@@ -181,6 +209,23 @@ void bind_object_detection()
{
using boost::python::arg;
class_<simple_object_detector_training_options>("simple_object_detector_training_options")
.add_property("be_verbose", &simple_object_detector_training_options::be_verbose,
&simple_object_detector_training_options::be_verbose)
.add_property("add_left_right_image_flips", &simple_object_detector_training_options::add_left_right_image_flips,
&simple_object_detector_training_options::add_left_right_image_flips)
.add_property("detection_window_size", &simple_object_detector_training_options::detection_window_size,
&simple_object_detector_training_options::detection_window_size)
.add_property("num_threads", &simple_object_detector_training_options::num_threads,
&simple_object_detector_training_options::num_threads);
class_<simple_test_results>("simple_test_results")
.add_property("precision", &simple_test_results::precision)
.add_property("recall", &simple_test_results::recall)
.add_property("average_precision", &simple_test_results::average_precision)
.def("__str__", &::print_simple_test_results);
{
typedef rectangle type;
class_<type>("rectangle", "This object represents a rectangular area of an image.")
......@@ -199,12 +244,77 @@ void bind_object_detection()
def("get_frontal_face_detector", get_frontal_face_detector,
"Returns the default face detector");
def("train_simple_object_detector", train_simple_object_detector,
(arg("dataset_filename"), arg("detector_output_filename"), arg("C"), arg("options")=simple_object_detector_training_options()),
"whatever");
def("test_simple_object_detector", test_simple_object_detector,
(arg("dataset_filename"), arg("detector_filename")),
"whatever");
{
typedef simple_object_detector_py type;
class_<type>("simple_object_detector",
"This object represents a sliding window histogram-of-oriented-gradients based object detector.")
.def("__init__", make_constructor(&load_object_from_file<type>),
"Loads a simple_object_detector from a file that contains the output of the \n\
train_simple_object_detector() routine."
/*!
Loads a simple_object_detector from a file that contains the output of the
train_simple_object_detector() routine.
!*/)
.def("__call__", &type::run_detector1, (arg("image"), arg("upsample_num_times")),
"requires \n\
- image is a numpy ndarray containing either an 8bit grayscale or RGB \n\
image. \n\
- upsample_num_times >= 0 \n\
ensures \n\
- This function runs the object detector on the input image and returns \n\
a list of detections. \n\
- Upsamples the image upsample_num_times before running the basic \n\
detector. If you don't know how many times you want to upsample then \n\
don't provide a value for upsample_num_times and an appropriate \n\
default will be used."
/*!
requires
- image is a numpy ndarray containing either an 8bit grayscale or RGB
image.
- upsample_num_times >= 0
ensures
- This function runs the object detector on the input image and returns
a list of detections.
- Upsamples the image upsample_num_times before running the basic
detector. If you don't know how many times you want to upsample then
don't provide a value for upsample_num_times and an appropriate
default will be used.
!*/
)
.def("__call__", &type::run_detector2, (arg("image")),
"requires \n\
- image is a numpy ndarray containing either an 8bit grayscale or RGB \n\
image. \n\
ensures \n\
- This function runs the object detector on the input image and returns \n\
a list of detections. "
/*!
requires
- image is a numpy ndarray containing either an 8bit grayscale or RGB
image.
ensures
- This function runs the object detector on the input image and returns
a list of detections.
!*/
)
.def_pickle(serialize_pickle<type>());
}
{
typedef frontal_face_detector type;
class_<type>("fhog_object_detector",
"This object represents a sliding window histogram-of-oriented-gradients based object detector.")
.def("__init__", make_constructor(&load_fhog_object_detector_from_file),
"Loads a fhog_object_detector from a file.")
.def("__init__", make_constructor(&load_object_from_file<type>),
"Loads a fhog_object_detector from a file that contains a serialized \n\
object_detector<scan_fhog_pyramid<pyramid_down<6>>> object. " )
.def("__call__", &::run_detector, (arg("image"), arg("upsample_num_times")=0),
"requires \n\
- image is a numpy ndarray containing either an 8bit \n\
......
This diff is collapsed.
// Copyright (C) 2014 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_SIMPLE_ObJECT_DETECTOR_ABSTRACT_H__
#ifdef DLIB_SIMPLE_ObJECT_DETECTOR_ABSTRACT_H__
#include <dlib/image_processing/object_detector_abstract.h>
#include <dlib/image_processing/scan_fhog_pyramid_abstract.h>
#include <dlib/svm/structural_object_detection_trainer_abstract.h>
#include <dlib/data_io/image_dataset_metadata.h>
#include <dlib/matrix.h>
namespace dlib
{
// ----------------------------------------------------------------------------------------
struct fhog_training_options
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is a container for the more advanced options to the
train_simple_object_detector() routine. The parameters have the following
interpretations:
- be_verbose: If true, train_simple_object_detector() will print out a
lot of information to the screen while training.
- add_left_right_image_flips: if true, train_simple_object_detector()
will assume the objects are left/right symmetric and add in left
right flips of the training images. This doubles the size of the
training dataset.
- num_threads: train_simple_object_detector() will use this many
threads of execution. Set this to the number of CPU cores on your
machine to obtain the fastest training speed.
- detection_window_size: The sliding window used will have about this
many pixels inside it.
!*/
fhog_training_options()
{
be_verbose = false;
add_left_right_image_flips = false;
num_threads = 4;
detection_window_size = 80*80;
}
bool be_verbose;
bool add_left_right_image_flips;
unsigned long num_threads;
unsigned long detection_window_size;
};
// ----------------------------------------------------------------------------------------
typedef object_detector<scan_fhog_pyramid<pyramid_down<6> > > simple_object_detector;
// ----------------------------------------------------------------------------------------
void train_simple_object_detector (
const std::string& dataset_filename,
const std::string& detector_output_filename,
const double C,
const fhog_training_options& options = fhog_training_options()
);
/*!
requires
- C > 0
ensures
- Uses the structural_object_detection_trainer to train a
simple_object_detector based on the labeled images in the XML file
dataset_filename. This function assumes the file dataset_filename is in the
XML format produced by the save_image_dataset_metadata() routine.
- This function will apply a reasonable set of default parameters and
preprocessing techniques to the training procedure for simple_object_detector
objects. So the point of this function is to provide you with a very easy
way to train a basic object detector.
- C is the usual SVM C regularization parameter. So it is passed to
structural_object_detection_trainer::set_c(). Larger values of C will
encourage the trainer to fit the data better but might lead to overfitting.
Therefore, you must determine the proper setting of this parameter
experimentally.
- The trained object detector is serialized to the file detector_output_filename.
!*/
// ----------------------------------------------------------------------------------------
struct simple_test_results
{
double precision;
double recall;
double average_precision;
};
inline const simple_test_results test_simple_object_detector (
const std::string& dataset_filename,
const std::string& detector_filename
);
/*!
ensures
- Loads an image dataset from dataset_filename. We assume dataset_filename is
a file using the XML format written by save_image_dataset_metadata().
- Loads a simple_object_detector from the file detector_filename. This means
detector_filename should be a file produced by the train_simple_object_detector()
routine defined above.
- This function tests the detector against the dataset and returns three
numbers that tell you how well the detector does at detecting the objects in
the dataset. The return value of this function is identical to that of
test_object_detection_function(). Therefore, see the documentation for
test_object_detection_function() for an extended definition of these metrics.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_SIMPLE_ObJECT_DETECTOR_ABSTRACT_H__
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment