Commit f8419124 authored by Davis King's avatar Davis King

Cleaned up the object_detector Python API a little and also pushed the nuclear
norm regularization stuff to Python. This also includes adding
num_separable_filters() and threshold_filter_singular_values() to the Python API.
parent 9a845c51
......@@ -130,6 +130,8 @@ void bind_object_detection(py::module& m)
py::class_<type>(m, "simple_object_detector_training_options",
"This object is a container for the options to the train_simple_object_detector() routine.")
.def(py::init())
.def("__str__", &::print_simple_object_detector_training_options)
.def("__repr__", &::print_simple_object_detector_training_options)
.def_readwrite("be_verbose", &type::be_verbose,
"If true, train_simple_object_detector() will print out a lot of information to the screen while training.")
.def_readwrite("add_left_right_image_flips", &type::add_left_right_image_flips,
......@@ -138,6 +140,26 @@ left/right symmetric and add in left right flips of the training \n\
images. This doubles the size of the training dataset.")
.def_readwrite("detection_window_size", &type::detection_window_size,
"The sliding window used will have about this many pixels inside it.")
.def_readwrite("nuclear_norm_regularization_strength", &type::nuclear_norm_regularization_strength,
"This detector works by convolving a filter over a HOG feature image. If that \n\
filter is separable then the convolution can be performed much faster. The \n\
nuclear_norm_regularization_strength parameter encourages the machine learning \n\
algorithm to learn a separable filter. A value of 0 disables this feature, but \n\
any non-zero value places a nuclear norm regularizer on the objective function \n\
and this encourages the learning of a separable filter. Note that setting \n\
nuclear_norm_regularization_strength to a non-zero value can make the training \n\
process take significantly longer, so be patient when using it."
/*!
This detector works by convolving a filter over a HOG feature image. If that
filter is separable then the convolution can be performed much faster. The
nuclear_norm_regularization_strength parameter encourages the machine learning
algorithm to learn a separable filter. A value of 0 disables this feature, but
any non-zero value places a nuclear norm regularizer on the objective function
and this encourages the learning of a separable filter. Note that setting
nuclear_norm_regularization_strength to a non-zero value can make the training
process take significantly longer, so be patient when using it.
!*/
)
.def_readwrite("C", &type::C,
"C is the usual SVM C regularization parameter. So it is passed to \n\
structural_object_detection_trainer::set_c(). Larger values of C \n\
......@@ -156,16 +178,18 @@ obtain the fastest training speed.")
no more than upsample_limit times. Value 0 will forbid trainer to \n\
upsample any images. If trainer is unable to fit all boxes with \n\
required upsample_limit, exception will be thrown. Higher values \n\
of upsample_limit exponentially increases memory requiremens. \n\
of upsample_limit exponentially increases memory requirements. \n\
Values higher than 2 (default) are not recommended.");
}
{
// Python binding for simple_test_results: exposes precision, recall and
// average_precision as read/write attributes, and uses
// print_simple_test_results for the string conversions registered below.
typedef simple_test_results type;
py::class_<type>(m, "simple_test_results")
.def_readwrite("precision", &type::precision)
.def_readwrite("recall", &type::recall)
.def_readwrite("average_precision", &type::average_precision)
// NOTE(review): the next line (terminated with ';') and the two lines after it
// look like the before/after sides of a diff hunk whose +/- markers were lost;
// presumably only the __str__/__repr__ pair is meant to remain -- confirm
// against the real file.
.def("__str__", &::print_simple_test_results);
.def("__str__", &::print_simple_test_results)
.def("__repr__", &::print_simple_test_results);
}
// Here, kvals is actually the result of linspace(start, end, num) and it is different from kvals used
......@@ -241,11 +265,8 @@ ensures \n\
- The trained object detector is returned.");
m.def("test_simple_object_detector", test_simple_object_detector,
// Please see test_simple_object_detector for the reason upsampling_amount is -1
py::arg("dataset_filename"), py::arg("detector_filename"), py::arg("upsampling_amount")=-1,
"requires \n\
- Optionally, take the number of times to upsample the testing images (upsampling_amount >= 0). \n\
ensures \n\
"ensures \n\
- Loads an image dataset from dataset_filename. We assume dataset_filename is \n\
a file using the XML format written by save_image_dataset_metadata(). \n\
- Loads a simple_object_detector from the file detector_filename. This means \n\
......@@ -256,7 +277,29 @@ ensures \n\
return value of this function is identical to that of dlib's \n\
test_object_detection_function() routine. Therefore, see the documentation \n\
for test_object_detection_function() for a detailed definition of these \n\
metrics. "
metrics. \n\
- if upsampling_amount>=0 then we upsample the data by upsampling_amount rather than \n\
use any upsampling amount that happens to be encoded in the given detector. If upsampling_amount<0 \n\
then we use the upsampling amount the detector wants to use."
);
// Overload of the Python function test_simple_object_detector() that is bound to
// test_simple_object_detector2 and takes a detector object (argument "detector")
// instead of a detector filename.  The docstring therefore must not talk about a
// detector_filename argument, since this overload has none.
m.def("test_simple_object_detector", test_simple_object_detector2,
py::arg("dataset_filename"), py::arg("detector"), py::arg("upsampling_amount")=-1,
"ensures \n\
- Loads an image dataset from dataset_filename. We assume dataset_filename is \n\
a file using the XML format written by save_image_dataset_metadata(). \n\
- Uses the simple_object_detector given by the detector argument. That is, no \n\
detector is loaded from disk by this overload. \n\
- This function tests the detector against the dataset and returns the \n\
precision, recall, and average precision of the detector. In fact, The \n\
return value of this function is identical to that of dlib's \n\
test_object_detection_function() routine. Therefore, see the documentation \n\
for test_object_detection_function() for a detailed definition of these \n\
metrics. \n\
- if upsampling_amount>=0 then we upsample the data by upsampling_amount rather than \n\
use any upsampling amount that happens to be encoded in the given detector. If upsampling_amount<0 \n\
then we use the upsampling amount the detector wants to use."
);
m.def("test_simple_object_detector", test_simple_object_detector_with_images_py,
......@@ -365,8 +408,48 @@ ensures \n\
- This function runs the object detector on the input image and returns \n\
a list of detections.")
.def("save", save_simple_object_detector_py, py::arg("detector_output_filename"), "Save a simple_object_detector to the provided path.")
.def_readwrite("upsampling_amount", &type::upsampling_amount, "The detector upsamples the image this many times before running.")
.def(py::pickle(&getstate<type>, &setstate<type>));
}
// Python binding for num_separable_filters(): unwraps the simple_object_detector_py
// and forwards its inner detector to the C++ function of the same name.
m.def("num_separable_filters",
    [](const simple_object_detector_py& wrapped) {
        const auto& inner_detector = wrapped.detector;
        return num_separable_filters(inner_detector);
    },
    py::arg("detector"),
"Returns the number of separable filters necessary to represent the HOG filters in the given detector."
);
// Python binding for threshold_filter_singular_values().  The lambda copies the
// incoming wrapper, swaps the copy's inner detector for the thresholded one, and
// returns the copy, so the argument itself is never modified.
m.def("threshold_filter_singular_values",
    [](const simple_object_detector_py& wrapped, double min_singular_value) {
        simple_object_detector_py thresholded(wrapped);
        thresholded.detector = threshold_filter_singular_values(wrapped.detector, min_singular_value);
        return thresholded;
    },
    py::arg("detector"), py::arg("thresh"),
"requires \n\
- thresh >= 0 \n\
ensures \n\
- Removes all components of the filters in the given detector that have \n\
singular values that are smaller than the given threshold. Therefore, this \n\
function allows you to control how many separable filters are in a detector. \n\
In particular, as thresh gets larger the quantity \n\
num_separable_filters(threshold_filter_singular_values(detector,thresh)) \n\
will generally get smaller and therefore give a faster running detector. \n\
However, note that at some point a large enough thresh will drop too much \n\
information from the filters and their accuracy will suffer. \n\
- returns the updated detector"
/*!
requires
- thresh >= 0
ensures
- Removes all components of the filters in the given detector that have
singular values that are smaller than the given threshold. Therefore, this
function allows you to control how many separable filters are in a detector.
In particular, as thresh gets larger the quantity
num_separable_filters(threshold_filter_singular_values(detector,thresh))
will generally get smaller and therefore give a faster running detector.
However, note that at some point a large enough thresh will drop too much
information from the filters and their accuracy will suffer.
- returns the updated detector
!*/
);
}
// ----------------------------------------------------------------------------------------
......@@ -12,6 +12,7 @@
#include "dlib/image_processing/remove_unobtainable_rectangles.h"
#include "serialize_object_detector.h"
#include "dlib/svm.h"
#include <sstream>
namespace dlib
......@@ -34,6 +35,7 @@ namespace dlib
C = 1;
epsilon = 0.01;
upsample_limit = 2;
nuclear_norm_regularization_strength = 0;
}
bool be_verbose;
......@@ -43,8 +45,26 @@ namespace dlib
double C;
double epsilon;
unsigned long upsample_limit;
double nuclear_norm_regularization_strength;
};
// Builds the text used for the Python __str__/__repr__ of
// simple_object_detector_training_options: the type name followed by a
// comma-separated "field=value" list in parentheses.  Values are formatted with
// the default std::ostream rules.
inline std::string print_simple_object_detector_training_options(const simple_object_detector_training_options& o)
{
    std::ostringstream text;

    text << "simple_object_detector_training_options(";
    text << "be_verbose=" << o.be_verbose << ", ";
    text << "add_left_right_image_flips=" << o.add_left_right_image_flips << ", ";
    text << "num_threads=" << o.num_threads << ", ";
    text << "detection_window_size=" << o.detection_window_size << ", ";
    text << "C=" << o.C << ", ";
    text << "epsilon=" << o.epsilon << ", ";
    text << "upsample_limit=" << o.upsample_limit << ", ";
    // Last field: no trailing separator before the closing parenthesis.
    text << "nuclear_norm_regularization_strength=" << o.nuclear_norm_regularization_strength;
    text << ")";

    return text.str();
}
// ----------------------------------------------------------------------------------------
namespace impl
......@@ -143,6 +163,9 @@ namespace dlib
if (options.epsilon <= 0)
throw error("Invalid epsilon value given to train_simple_object_detector(), epsilon must be > 0.");
// Reject negative regularization strengths before any training starts; 0 is
// allowed.  (Message fixed: it used to read "must be must be".)
if (options.nuclear_norm_regularization_strength < 0)
throw error("Invalid nuclear_norm_regularization_strength value given to train_simple_object_detector(), it must be >= 0.");
if (images.size() != boxes.size())
throw error("The list of images must have the same length as the list of boxes.");
if (images.size() != ignore.size())
......@@ -156,6 +179,7 @@ namespace dlib
unsigned long width, height;
impl::pick_best_window_size(boxes, width, height, options.detection_window_size);
scanner.set_detection_window_size(width, height);
scanner.set_nuclear_norm_regularization_strength(options.nuclear_norm_regularization_strength);
structural_object_detection_trainer<image_scanner_type> trainer(scanner);
trainer.set_num_threads(options.num_threads);
trainer.set_c(options.C);
......@@ -265,12 +289,30 @@ namespace dlib
return ret;
}
// Tests an in-memory detector against the image dataset described by
// dataset_filename and returns the result of
// test_simple_object_detector_with_images().
//
//  - dataset_filename: passed to load_image_dataset() to obtain the images, the
//    truth boxes and the ignore boxes.
//  - detector: the wrapper whose .detector is evaluated and whose
//    .upsampling_amount is the fallback upsampling amount.
//  - upsample_amount: if >= 0 the images are upsampled this many times; if < 0
//    the detector's own upsampling_amount is used instead.
//
// Fix: previously a non-negative upsample_amount was silently dropped and 0 was
// always used in that case.
inline const simple_test_results test_simple_object_detector2 (
    const std::string& dataset_filename,
    simple_object_detector_py& detector,
    const int upsample_amount
)
{
    // Load all the testing images
    dlib::array<array2d<rgb_pixel> > images;
    std::vector<std::vector<rectangle> > boxes, ignore;
    ignore = load_image_dataset(images, boxes, dataset_filename);

    // A negative request means "use what the detector carries"; any other value
    // is an explicit override from the caller and must be honored.
    const unsigned int final_upsampling_amount =
        (upsample_amount < 0) ? detector.upsampling_amount
                              : static_cast<unsigned int>(upsample_amount);

    return test_simple_object_detector_with_images(images, final_upsampling_amount, boxes, ignore, detector.detector);
}
inline const simple_test_results test_simple_object_detector (
const std::string& dataset_filename,
const std::string& detector_filename,
const int upsample_amount
)
{
// Load all the testing images
dlib::array<array2d<rgb_pixel> > images;
std::vector<std::vector<rectangle> > boxes, ignore;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment