Commit 3126372c authored by Davis King's avatar Davis King

Added the translation jittering option to the python API.

parent a9c940b1
......@@ -189,6 +189,8 @@ void bind_shape_predictors(py::module &m)
cause overfitting. The value must be in the range (0, 1].")
.def_readwrite("oversampling_amount", &type::oversampling_amount,
"The number of randomly selected initial starting points sampled for each training example")
.def_readwrite("oversampling_translation_jitter", &type::oversampling_translation_jitter,
"The amount of translation jittering to apply to bounding boxes, a good value is in in the range [0 0.5].")
.def_readwrite("feature_pool_size", &type::feature_pool_size,
"Number of pixels used to generate features for the random trees.")
.def_readwrite("lambda_param", &type::lambda_param,
......
......@@ -25,6 +25,7 @@ namespace dlib
num_trees_per_cascade_level = 500;
nu = 0.1;
oversampling_amount = 20;
oversampling_translation_jitter = 0;
feature_pool_size = 400;
lambda_param = 0.1;
num_test_splits = 20;
......@@ -39,6 +40,7 @@ namespace dlib
unsigned long num_trees_per_cascade_level;
double nu;
unsigned long oversampling_amount;
double oversampling_translation_jitter;
unsigned long feature_pool_size;
double lambda_param;
unsigned long num_test_splits;
......@@ -56,12 +58,14 @@ namespace dlib
{
try
{
serialize("shape_predictor_training_options", out);
serialize(item.be_verbose,out);
serialize(item.cascade_depth,out);
serialize(item.tree_depth,out);
serialize(item.num_trees_per_cascade_level,out);
serialize(item.nu,out);
serialize(item.oversampling_amount,out);
serialize(item.oversampling_translation_jitter,out);
serialize(item.feature_pool_size,out);
serialize(item.lambda_param,out);
serialize(item.num_test_splits,out);
......@@ -81,12 +85,14 @@ namespace dlib
{
try
{
check_serialized_version("shape_predictor_training_options", in);
deserialize(item.be_verbose,in);
deserialize(item.cascade_depth,in);
deserialize(item.tree_depth,in);
deserialize(item.num_trees_per_cascade_level,in);
deserialize(item.nu,in);
deserialize(item.oversampling_amount,in);
deserialize(item.oversampling_translation_jitter,in);
deserialize(item.feature_pool_size,in);
deserialize(item.lambda_param,in);
deserialize(item.num_test_splits,in);
......@@ -109,6 +115,7 @@ namespace dlib
<< "num_trees_per_cascade_level=" << o.num_trees_per_cascade_level << ","
<< "nu=" << o.nu << ","
<< "oversampling_amount=" << o.oversampling_amount << ","
<< "oversampling_translation_jitter=" << o.oversampling_translation_jitter << ","
<< "feature_pool_size=" << o.feature_pool_size << ","
<< "lambda_param=" << o.lambda_param << ","
<< "num_test_splits=" << o.num_test_splits << ","
......@@ -166,6 +173,7 @@ namespace dlib
trainer.set_nu(options.nu);
trainer.set_random_seed(options.random_seed);
trainer.set_oversampling_amount(options.oversampling_amount);
trainer.set_oversampling_translation_jitter(options.oversampling_translation_jitter);
trainer.set_feature_pool_size(options.feature_pool_size);
trainer.set_feature_pool_region_padding(options.feature_pool_region_padding);
trainer.set_lambda(options.lambda_param);
......@@ -180,6 +188,7 @@ namespace dlib
std::cout << "Training with nu: " << options.nu << std::endl;
std::cout << "Training with random seed: " << options.random_seed << std::endl;
std::cout << "Training with oversampling amount: " << options.oversampling_amount << std::endl;
std::cout << "Training with oversampling translation jitter: " << options.oversampling_translation_jitter << std::endl;
std::cout << "Training with feature pool size: " << options.feature_pool_size << std::endl;
std::cout << "Training with feature pool region padding: " << options.feature_pool_region_padding << std::endl;
std::cout << "Training with " << options.num_threads << " threads." << std::endl;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment