Commit 8e1e548a authored by Davis King's avatar Davis King

Updated this example to use the scan_fhog_pyramid version of the object

detector since it is much more user friendly.
parent e79a7648
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
functional command line tool for object detection. This example assumes functional command line tool for object detection. This example assumes
you are familiar with the contents of at least the following example you are familiar with the contents of at least the following example
programs: programs:
- object_detector_ex.cpp - fhog_object_detector_ex.cpp
- compress_stream_ex.cpp - compress_stream_ex.cpp
...@@ -35,7 +35,11 @@ ...@@ -35,7 +35,11 @@
holding the shift key, left clicking, and dragging the mouse will allow you to holding the shift key, left clicking, and dragging the mouse will allow you to
draw boxes around the objects you wish to detect. So next, label all the objects draw boxes around the objects you wish to detect. So next, label all the objects
with boxes. Note that it is important to label all the objects since any object with boxes. Note that it is important to label all the objects since any object
not labeled is implicitly assumed to be not an object we should detect. not labeled is implicitly assumed to be not an object we should detect. If there
are objects you are not sure about you should draw a box around them, then double
click the box and press i. This will cross out the box and mark it as "ignore".
The training code in dlib will then simply ignore detections matching that box.
Once you finish labeling objects go to the file menu, click save, and then close Once you finish labeling objects go to the file menu, click save, and then close
the program. This will save the object boxes back to mydataset.xml. You can verify the program. This will save the object boxes back to mydataset.xml. You can verify
...@@ -53,18 +57,20 @@ ...@@ -53,18 +57,20 @@
This command will display some_image.png in a window and any detected objects will This command will display some_image.png in a window and any detected objects will
be indicated by a red box. be indicated by a red box.
Finally, to make running this example easy dlib includes some training data in the
There are a number of other useful command line options in the current example examples/faces folder. Therefore, you can test this program out by running the
program which you can explore below. following sequence of commands:
./train_object_detector -tv examples/faces/training.xml -u1 --flip
./train_object_detector --test examples/faces/testing.xml -u1
./train_object_detector examples/faces/*.jpg -u1
That will make a face detector that performs perfectly on the test images listed in
testing.xml and then it will show you the detections on all the images.
*/ */
#include <dlib/svm_threaded.h> #include <dlib/svm_threaded.h>
#include <dlib/string.h> #include <dlib/string.h>
#include <dlib/gui_widgets.h> #include <dlib/gui_widgets.h>
#include <dlib/array.h>
#include <dlib/array2d.h>
#include <dlib/image_keypoint.h>
#include <dlib/image_processing.h> #include <dlib/image_processing.h>
#include <dlib/data_io.h> #include <dlib/data_io.h>
#include <dlib/cmd_line_parser.h> #include <dlib/cmd_line_parser.h>
...@@ -79,44 +85,131 @@ using namespace dlib; ...@@ -79,44 +85,131 @@ using namespace dlib;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void pick_best_window_size (
const std::vector<std::vector<rectangle> >& boxes,
unsigned long& width,
unsigned long& height,
const unsigned long target_size
)
/*!
ensures
- Finds the average aspect ratio of the elements of boxes and outputs a width
and height such that the aspect ratio is equal to the average and also the
area is equal to target_size. That is, the following will be approximately true:
- #width*#height == target_size
- #width/#height == the average aspect ratio of the elements of boxes.
!*/
{
// find the average width and height
running_stats<double> avg_width, avg_height;
for (unsigned long i = 0; i < boxes.size(); ++i)
{
for (unsigned long j = 0; j < boxes[i].size(); ++j)
{
avg_width.add(boxes[i][j].width());
avg_height.add(boxes[i][j].height());
}
}
// now adjust the box size so that it is about target_pixels pixels in size
double size = avg_width.mean()*avg_height.mean();
double scale = std::sqrt(target_size/size);
width = (unsigned long)(avg_width.mean()*scale+0.5);
height = (unsigned long)(avg_height.mean()*scale+0.5);
// make sure the width and height never round to zero.
if (width == 0)
width = 1;
if (height == 0)
height = 1;
}
// ----------------------------------------------------------------------------------------
bool contains_any_boxes (
const std::vector<std::vector<rectangle> >& boxes
)
{
for (unsigned long i = 0; i < boxes.size(); ++i)
{
if (boxes[i].size() != 0)
return true;
}
return false;
}
// ----------------------------------------------------------------------------------------
void throw_invalid_box_error_message (
const std::string& dataset_filename,
const std::vector<std::vector<rectangle> >& removed,
const unsigned long target_size
)
{
image_dataset_metadata::dataset data;
load_image_dataset_metadata(data, dataset_filename);
std::ostringstream sout;
sout << "Error! An impossible set of object boxes was given for training. ";
sout << "All the boxes need to have a similar aspect ratio and also not be ";
sout << "smaller than about " << target_size << " pixels in area. ";
sout << "The following images contain invalid boxes:\n";
std::ostringstream sout2;
for (unsigned long i = 0; i < removed.size(); ++i)
{
if (removed[i].size() != 0)
{
const std::string imgname = data.images[i].filename;
sout2 << " " << imgname << "\n";
}
}
throw error("\n"+wrap_string(sout.str()) + "\n" + sout2.str());
}
// ----------------------------------------------------------------------------------------
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
try try
{ {
command_line_parser parser; command_line_parser parser;
parser.add_option("h","Display this help message."); parser.add_option("h","Display this help message.");
parser.add_option("v","Be verbose.");
parser.add_option("t","Train an object detector and save the detector to disk."); parser.add_option("t","Train an object detector and save the detector to disk.");
parser.add_option("cross-validate", parser.add_option("cross-validate",
"Perform cross-validation on an image dataset and print the results."); "Perform cross-validation on an image dataset and print the results.");
parser.add_option("test", "Test a trained detector on an image dataset and print the results.");
parser.add_option("u", "Upsample each input image <arg> times. Each upsampling quadruples the number of pixels in the image (default: 0).", 1);
parser.set_group_name("training/cross-validation sub-options");
parser.add_option("v","Be verbose.");
parser.add_option("folds","When doing cross-validation, do <arg> folds (default: 3).",1); parser.add_option("folds","When doing cross-validation, do <arg> folds (default: 3).",1);
parser.add_option("c","Set the SVM C parameter to <arg> (default: 1.0).",1); parser.add_option("c","Set the SVM C parameter to <arg> (default: 1.0).",1);
parser.add_option("threads", "Use <arg> threads for training <arg> (default: 4).",1); parser.add_option("threads", "Use <arg> threads for training <arg> (default: 4).",1);
parser.add_option("grid-size", "Extract features in a detection window from an <arg> by <arg> grid. (default: 2).",1); parser.add_option("eps", "Set training epsilon to <arg> (default: 0.01).", 1);
parser.add_option("hash-bits", "Use <arg> bits for the feature hashing (default: 10).", 1); parser.add_option("target-size", "Set size of the sliding window to about <arg> pixels in area (default: 80*80).", 1);
parser.add_option("test", "Test a trained detector on an image dataset and print the results."); parser.add_option("flip", "Add left/right flipped copies of the images into the training dataset. Useful when the objects "
parser.add_option("eps", "Set training epsilon to <arg> (default: 0.3).", 1); "you want to detect are left/right symmetric.");
parser.parse(argc, argv); parser.parse(argc, argv);
// Now we do a little command line validation. Each of the following functions // Now we do a little command line validation. Each of the following functions
// checks something and throws an exception if the test fails. // checks something and throws an exception if the test fails.
const char* one_time_opts[] = {"h", "v", "t", "cross-validate", "c", "threads", "grid-size", "hash-bits", const char* one_time_opts[] = {"h", "v", "t", "cross-validate", "c", "threads", "target-size",
"folds", "test", "eps"}; "folds", "test", "eps", "u", "flip"};
parser.check_one_time_options(one_time_opts); // Can't give an option more than once parser.check_one_time_options(one_time_opts); // Can't give an option more than once
// Make sure the arguments to these options are within valid ranges if they are supplied by the user. // Make sure the arguments to these options are within valid ranges if they are supplied by the user.
parser.check_option_arg_range("c", 1e-12, 1e12); parser.check_option_arg_range("c", 1e-12, 1e12);
parser.check_option_arg_range("eps", 1e-5, 1e4); parser.check_option_arg_range("eps", 1e-5, 1e4);
parser.check_option_arg_range("threads", 1, 1000); parser.check_option_arg_range("threads", 1, 1000);
parser.check_option_arg_range("grid-size", 1, 100);
parser.check_option_arg_range("hash-bits", 1, 32);
parser.check_option_arg_range("folds", 2, 100); parser.check_option_arg_range("folds", 2, 100);
parser.check_option_arg_range("u", 0, 8);
parser.check_option_arg_range("target-size", 4*4, 10000*10000);
const char* incompatible[] = {"t", "cross-validate", "test"}; const char* incompatible[] = {"t", "cross-validate", "test"};
parser.check_incompatible_options(incompatible); parser.check_incompatible_options(incompatible);
// You are only allowed to give these training_sub_ops if you also give either -t or --cross-validate. // You are only allowed to give these training_sub_ops if you also give either -t or --cross-validate.
const char* training_ops[] = {"t", "cross-validate"}; const char* training_ops[] = {"t", "cross-validate"};
const char* training_sub_ops[] = {"v", "c", "threads", "grid-size", "hash-bits"}; const char* training_sub_ops[] = {"v", "c", "threads", "target-size", "eps", "flip"};
parser.check_sub_options(training_ops, training_sub_ops); parser.check_sub_options(training_ops, training_sub_ops);
parser.check_sub_option("cross-validate", "folds"); parser.check_sub_option("cross-validate", "folds");
...@@ -130,10 +223,9 @@ int main(int argc, char** argv) ...@@ -130,10 +223,9 @@ int main(int argc, char** argv)
} }
typedef scan_fhog_pyramid<pyramid_down<6> > image_scanner_type;
// Get the upsample option from the user but use 0 if it wasn't given.
typedef hashed_feature_image<hog_image<4,4,1,9,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type; const unsigned long upsample_amount = get_option(parser, "u", 0);
typedef scan_image_pyramid<pyramid_down<3>, feature_extractor_type> image_scanner_type;
if (parser.option("t") || parser.option("cross-validate")) if (parser.option("t") || parser.option("cross-validate"))
{ {
...@@ -145,43 +237,58 @@ int main(int argc, char** argv) ...@@ -145,43 +237,58 @@ int main(int argc, char** argv)
} }
dlib::array<array2d<unsigned char> > images; dlib::array<array2d<unsigned char> > images;
std::vector<std::vector<rectangle> > object_locations; std::vector<std::vector<rectangle> > object_locations, ignore;
cout << "Loading image dataset from metadata file " << parser[0] << endl; cout << "Loading image dataset from metadata file " << parser[0] << endl;
load_image_dataset(images, object_locations, parser[0]); ignore = load_image_dataset(images, object_locations, parser[0]);
cout << "Number of images loaded: " << images.size() << endl; cout << "Number of images loaded: " << images.size() << endl;
// Get the value of the hash-bits option if the user supplied it. Otherwise // Get the options from the user, but use default values if they are not
// use the default value of 10. // supplied.
const int hash_bits = get_option(parser, "hash-bits", 10);
const int grid_size = get_option(parser, "grid-size", 2);
const int threads = get_option(parser, "threads", 4); const int threads = get_option(parser, "threads", 4);
const double C = get_option(parser, "c", 1.0); const double C = get_option(parser, "c", 1.0);
const double eps = get_option(parser, "eps", 0.3); const double eps = get_option(parser, "eps", 0.01);
unsigned int num_folds = get_option(parser, "folds", 3); unsigned int num_folds = get_option(parser, "folds", 3);
const unsigned long target_size = get_option(parser, "target-size", 80*80);
// You can't do more folds than there are images. // You can't do more folds than there are images.
if (num_folds > images.size()) if (num_folds > images.size())
num_folds = images.size(); num_folds = images.size();
// Upsample images if the user asked us to do that.
for (unsigned long i = 0; i < upsample_amount; ++i)
upsample_image_dataset<pyramid_down<2> >(images, object_locations, ignore);
image_scanner_type scanner; image_scanner_type scanner;
setup_grid_detection_templates_verbose(scanner, object_locations, grid_size, grid_size); unsigned long width, height;
setup_hashed_features(scanner, images, hash_bits); pick_best_window_size(object_locations, width, height, target_size);
scanner.set_detection_window_size(width, height);
structural_object_detection_trainer<image_scanner_type> trainer(scanner); structural_object_detection_trainer<image_scanner_type> trainer(scanner);
trainer.set_num_threads(threads); trainer.set_num_threads(threads);
if (parser.option("v")) if (parser.option("v"))
trainer.be_verbose(); trainer.be_verbose();
trainer.set_c(C); trainer.set_c(C);
trainer.set_epsilon(eps); trainer.set_epsilon(eps);
// Now make sure all the boxes are obtainable by the scanner.
std::vector<std::vector<rectangle> > removed;
removed = remove_unobtainable_rectangles(trainer, images, object_locations);
// if we weren't able to get all the boxes to match then throw an error
if (contains_any_boxes(removed))
{
unsigned long scale = upsample_amount+1;
scale = scale*scale;
throw_invalid_box_error_message(parser[0], removed, target_size/scale);
}
if (parser.option("flip"))
add_image_left_right_flips(images, object_locations, ignore);
if (parser.option("t")) if (parser.option("t"))
{ {
// Do the actual training and save the results into the detector object. // Do the actual training and save the results into the detector object.
object_detector<image_scanner_type> detector = trainer.train(images, object_locations); object_detector<image_scanner_type> detector = trainer.train(images, object_locations, ignore);
cout << "Saving trained detector to object_detector.svm" << endl; cout << "Saving trained detector to object_detector.svm" << endl;
ofstream fout("object_detector.svm", ios::binary); ofstream fout("object_detector.svm", ios::binary);
...@@ -197,15 +304,19 @@ int main(int argc, char** argv) ...@@ -197,15 +304,19 @@ int main(int argc, char** argv)
randomize_samples(images, object_locations); randomize_samples(images, object_locations);
cout << num_folds << "-fold cross validation (precision,recall,AP): " cout << num_folds << "-fold cross validation (precision,recall,AP): "
<< cross_validate_object_detection_trainer(trainer, images, object_locations, num_folds) << endl; << cross_validate_object_detection_trainer(trainer, images, object_locations, ignore, num_folds) << endl;
} }
cout << "Parameters used: " << endl; cout << "Parameters used: " << endl;
cout << " hash-bits: "<< hash_bits << endl;
cout << " grid-size: "<< grid_size << endl;
cout << " threads: "<< threads << endl; cout << " threads: "<< threads << endl;
cout << " C: "<< C << endl; cout << " C: "<< C << endl;
cout << " eps: "<< eps << endl; cout << " eps: "<< eps << endl;
cout << " target-size: "<< target_size << endl;
cout << " detection window width: "<< width << endl;
cout << " detection window height: "<< height << endl;
cout << " upsample this many times : "<< upsample_amount << endl;
if (parser.option("flip"))
cout << " trained using left/right flips." << endl;
if (parser.option("cross-validate")) if (parser.option("cross-validate"))
cout << " num_folds: "<< num_folds << endl; cout << " num_folds: "<< num_folds << endl;
cout << endl; cout << endl;
...@@ -215,10 +326,12 @@ int main(int argc, char** argv) ...@@ -215,10 +326,12 @@ int main(int argc, char** argv)
// The rest of the code is devoted to testing out an already trained
// object detector.
// The rest of the code is devoted to testing an already trained object detector.
if (parser.number_of_arguments() == 0) if (parser.number_of_arguments() == 0)
{ {
cout << "You must give an image or an image dataset metadata XML file produced by the imglab tool." << endl; cout << "You must give an image or an image dataset metadata XML file produced by the imglab tool." << endl;
...@@ -243,15 +356,19 @@ int main(int argc, char** argv) ...@@ -243,15 +356,19 @@ int main(int argc, char** argv)
// Check if the command line argument is an XML file // Check if the command line argument is an XML file
if (tolower(right_substr(parser[0],".")) == "xml") if (tolower(right_substr(parser[0],".")) == "xml")
{ {
std::vector<std::vector<rectangle> > object_locations; std::vector<std::vector<rectangle> > object_locations, ignore;
cout << "Loading image dataset from metadata file " << parser[0] << endl; cout << "Loading image dataset from metadata file " << parser[0] << endl;
load_image_dataset(images, object_locations, parser[0]); ignore = load_image_dataset(images, object_locations, parser[0]);
cout << "Number of images loaded: " << images.size() << endl; cout << "Number of images loaded: " << images.size() << endl;
// Upsample images if the user asked us to do that.
for (unsigned long i = 0; i < upsample_amount; ++i)
upsample_image_dataset<pyramid_down<2> >(images, object_locations, ignore);
if (parser.option("test")) if (parser.option("test"))
{ {
cout << "Testing detector on data..." << endl; cout << "Testing detector on data..." << endl;
cout << "Results (precision,recall,AP): " << test_object_detection_function(detector, images, object_locations) << endl; cout << "Results (precision,recall,AP): " << test_object_detection_function(detector, images, object_locations, ignore) << endl;
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
} }
...@@ -262,6 +379,13 @@ int main(int argc, char** argv) ...@@ -262,6 +379,13 @@ int main(int argc, char** argv)
images.resize(parser.number_of_arguments()); images.resize(parser.number_of_arguments());
for (unsigned long i = 0; i < images.size(); ++i) for (unsigned long i = 0; i < images.size(); ++i)
load_image(images[i], parser[i]); load_image(images[i], parser[i]);
// Upsample images if the user asked us to do that.
for (unsigned long i = 0; i < upsample_amount; ++i)
{
for (unsigned long j = 0; j < images.size(); ++j)
pyramid_up(images[j]);
}
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment