Commit 4b907532 authored by Davis King's avatar Davis King

Allow python to set the padding mode of the shape_predictor_trainer.

parent 8798955d
...@@ -197,9 +197,42 @@ void bind_shape_predictors(py::module &m) ...@@ -197,9 +197,42 @@ void bind_shape_predictors(py::module &m)
"Controls how tight the feature sampling should be. Lower values enforce closer features.") "Controls how tight the feature sampling should be. Lower values enforce closer features.")
.def_readwrite("num_test_splits", &type::num_test_splits, .def_readwrite("num_test_splits", &type::num_test_splits,
"Number of split features at each node to sample. The one that gives the best split is chosen.") "Number of split features at each node to sample. The one that gives the best split is chosen.")
.def_readwrite("landmark_relative_padding_mode", &type::landmark_relative_padding_mode,
"If True then features are drawn only from the box around the landmarks, otherwise they come from the bounding box and landmarks together. See feature_pool_region_padding doc for more details.")
.def_readwrite("feature_pool_region_padding", &type::feature_pool_region_padding, .def_readwrite("feature_pool_region_padding", &type::feature_pool_region_padding,
"Size of region within which to sample features for the feature pool, \ /*!
e.g a padding of 0.5 would cause the algorithm to sample pixels from a box that was 2x2 pixels") This algorithm works by comparing the relative intensity of pairs of
pixels in the input image. To decide which pixels to look at, the
training algorithm randomly selects pixels from a box roughly centered
around the object of interest. We call this box the feature pool region
box.
Each object of interest is defined by a full_object_detection, which
contains a bounding box and a list of landmarks. If
landmark_relative_padding_mode==True then the feature pool region box is
the tightest box that contains the landmarks inside the
full_object_detection. In this mode the full_object_detection's bounding
box is ignored. Otherwise, if the padding mode is bounding_box_relative
then the feature pool region box is the tightest box that contains BOTH
the landmarks and the full_object_detection's bounding box.
Additionally, you can adjust the size of the feature pool padding region
by setting feature_pool_region_padding to some value. If
feature_pool_region_padding then the feature pool region box is
unmodified and defined exactly as stated above. However, you can expand
the size of the box by setting the padding > 0 or shrink it by setting it
to something < 0.
To explain this precisely, for a padding of 0 we say that the pixels are
sampled from a box of size 1x1. The padding value is added to each side
of the box. So a padding of 0.5 would cause the algorithm to sample
pixels from a box that was 2x2, effectively multiplying the area pixels
are sampled from by 4. Similarly, setting the padding to -0.2 would
cause it to sample from a box 0.6x0.6 in size.
!*/
"Size of region within which to sample features for the feature pool. \
positive values increase the sampling region while negative values decrease it. E.g. padding of 0 means we \
sample fr")
.def_readwrite("random_seed", &type::random_seed, .def_readwrite("random_seed", &type::random_seed,
"The random seed used by the internal random number generator") "The random seed used by the internal random number generator")
.def_readwrite("num_threads", &type::num_threads, .def_readwrite("num_threads", &type::num_threads,
......
...@@ -32,6 +32,7 @@ namespace dlib ...@@ -32,6 +32,7 @@ namespace dlib
feature_pool_region_padding = 0; feature_pool_region_padding = 0;
random_seed = ""; random_seed = "";
num_threads = 0; num_threads = 0;
landmark_relative_padding_mode = true;
} }
bool be_verbose; bool be_verbose;
...@@ -46,6 +47,7 @@ namespace dlib ...@@ -46,6 +47,7 @@ namespace dlib
unsigned long num_test_splits; unsigned long num_test_splits;
double feature_pool_region_padding; double feature_pool_region_padding;
std::string random_seed; std::string random_seed;
bool landmark_relative_padding_mode;
// not serialized // not serialized
unsigned long num_threads; unsigned long num_threads;
...@@ -58,7 +60,7 @@ namespace dlib ...@@ -58,7 +60,7 @@ namespace dlib
{ {
try try
{ {
serialize("shape_predictor_training_options", out); serialize("shape_predictor_training_options_v2", out);
serialize(item.be_verbose,out); serialize(item.be_verbose,out);
serialize(item.cascade_depth,out); serialize(item.cascade_depth,out);
serialize(item.tree_depth,out); serialize(item.tree_depth,out);
...@@ -71,6 +73,7 @@ namespace dlib ...@@ -71,6 +73,7 @@ namespace dlib
serialize(item.num_test_splits,out); serialize(item.num_test_splits,out);
serialize(item.feature_pool_region_padding,out); serialize(item.feature_pool_region_padding,out);
serialize(item.random_seed,out); serialize(item.random_seed,out);
serialize(item.landmark_relative_padding_mode,out);
} }
catch (serialization_error& e) catch (serialization_error& e)
{ {
...@@ -85,7 +88,7 @@ namespace dlib ...@@ -85,7 +88,7 @@ namespace dlib
{ {
try try
{ {
check_serialized_version("shape_predictor_training_options", in); check_serialized_version("shape_predictor_training_options_v2", in);
deserialize(item.be_verbose,in); deserialize(item.be_verbose,in);
deserialize(item.cascade_depth,in); deserialize(item.cascade_depth,in);
deserialize(item.tree_depth,in); deserialize(item.tree_depth,in);
...@@ -98,6 +101,7 @@ namespace dlib ...@@ -98,6 +101,7 @@ namespace dlib
deserialize(item.num_test_splits,in); deserialize(item.num_test_splits,in);
deserialize(item.feature_pool_region_padding,in); deserialize(item.feature_pool_region_padding,in);
deserialize(item.random_seed,in); deserialize(item.random_seed,in);
deserialize(item.landmark_relative_padding_mode,in);
} }
catch (serialization_error& e) catch (serialization_error& e)
{ {
...@@ -122,6 +126,7 @@ namespace dlib ...@@ -122,6 +126,7 @@ namespace dlib
<< "feature_pool_region_padding=" << o.feature_pool_region_padding << ", " << "feature_pool_region_padding=" << o.feature_pool_region_padding << ", "
<< "random_seed=" << o.random_seed << ", " << "random_seed=" << o.random_seed << ", "
<< "num_threads=" << o.num_threads << "num_threads=" << o.num_threads
<< "landmark_relative_padding_mode=" << o.landmark_relative_padding_mode
<< ")"; << ")";
return sout.str(); return sout.str();
} }
...@@ -179,6 +184,10 @@ namespace dlib ...@@ -179,6 +184,10 @@ namespace dlib
trainer.set_lambda(options.lambda_param); trainer.set_lambda(options.lambda_param);
trainer.set_num_test_splits(options.num_test_splits); trainer.set_num_test_splits(options.num_test_splits);
trainer.set_num_threads(options.num_threads); trainer.set_num_threads(options.num_threads);
if (options.landmark_relative_padding_mode)
trainer.set_padding_mode(shape_predictor_trainer::landmark_relative);
else
trainer.set_padding_mode(shape_predictor_trainer::bounding_box_relative);
if (options.be_verbose) if (options.be_verbose)
{ {
...@@ -189,6 +198,7 @@ namespace dlib ...@@ -189,6 +198,7 @@ namespace dlib
std::cout << "Training with random seed: " << options.random_seed << std::endl; std::cout << "Training with random seed: " << options.random_seed << std::endl;
std::cout << "Training with oversampling amount: " << options.oversampling_amount << std::endl; std::cout << "Training with oversampling amount: " << options.oversampling_amount << std::endl;
std::cout << "Training with oversampling translation jitter: " << options.oversampling_translation_jitter << std::endl; std::cout << "Training with oversampling translation jitter: " << options.oversampling_translation_jitter << std::endl;
std::cout << "Training with landmark_relative_padding_mode: " << options.landmark_relative_padding_mode << std::endl;
std::cout << "Training with feature pool size: " << options.feature_pool_size << std::endl; std::cout << "Training with feature pool size: " << options.feature_pool_size << std::endl;
std::cout << "Training with feature pool region padding: " << options.feature_pool_region_padding << std::endl; std::cout << "Training with feature pool region padding: " << options.feature_pool_region_padding << std::endl;
std::cout << "Training with " << options.num_threads << " threads." << std::endl; std::cout << "Training with " << options.num_threads << " threads." << std::endl;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment