Commit dd19ce84 authored by Patrick Snape's avatar Patrick Snape

Update the interface to be more Pythonic

This is the biggest change so far. Now, there are two different
classes of interface. One where you pass ONLY file paths,
and one where you pass ONLY Python objects.

The file paths are maintained to keep a matching interface with
the C++ examples of dlib. So shape prediction and object
detection can be trained using the dlib XML file paths, and the
trained detectors are then serialized to disk.

Shape prediction and object detection can also be trained using
numpy arrays and in-memory objects. In this case, the predictor
and detector objects are returned from the training functions.
To facilitate serializing these objects, they now have a 'save'
method.

Testing follows a similar pattern, in that it can take either XML
files or in-memory objects. I also added back the concept of
upsampling during testing to make amends for removing the
simple_object_detector_py struct.
parent 8db3f4e5
This diff is collapsed.
...@@ -84,10 +84,9 @@ boost::shared_ptr<full_object_detection> full_obj_det_init(object& pyrect, objec ...@@ -84,10 +84,9 @@ boost::shared_ptr<full_object_detection> full_obj_det_init(object& pyrect, objec
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline void train_shape_predictor_on_images_py ( inline shape_predictor train_shape_predictor_on_images_py (
const object& pyimages, const boost::python::list& pyimages,
const object& pydetections, const boost::python::list& pydetections,
const std::string& predictor_output_filename,
const shape_predictor_training_options& options const shape_predictor_training_options& options
) )
{ {
...@@ -99,15 +98,15 @@ inline void train_shape_predictor_on_images_py ( ...@@ -99,15 +98,15 @@ inline void train_shape_predictor_on_images_py (
dlib::array<array2d<rgb_pixel> > images(num_images); dlib::array<array2d<rgb_pixel> > images(num_images);
images_and_nested_params_to_dlib(pyimages, pydetections, images, detections); images_and_nested_params_to_dlib(pyimages, pydetections, images, detections);
train_shape_predictor_on_images("", images, detections, predictor_output_filename, options); return train_shape_predictor_on_images(images, detections, options);
} }
inline double test_shape_predictor_with_images_py ( inline double test_shape_predictor_with_images_py (
const object& pyimages, const boost::python::list& pyimages,
const object& pydetections, const boost::python::list& pydetections,
const object& pyscales, const boost::python::list& pyscales,
const std::string& predictor_filename const shape_predictor& predictor
) )
{ {
const unsigned long num_images = len(pyimages); const unsigned long num_images = len(pyimages);
...@@ -141,17 +140,17 @@ inline double test_shape_predictor_with_images_py ( ...@@ -141,17 +140,17 @@ inline double test_shape_predictor_with_images_py (
} }
} }
return test_shape_predictor_with_images(images, detections, scales, predictor_filename); return test_shape_predictor_with_images(images, detections, scales, predictor);
} }
inline double test_shape_predictor_with_images_no_scales_py ( inline double test_shape_predictor_with_images_no_scales_py (
const object& pyimages, const boost::python::list& pyimages,
const object& pydetections, const boost::python::list& pydetections,
const std::string& predictor_filename const shape_predictor& predictor
) )
{ {
boost::python::list pyscales; boost::python::list pyscales;
return test_shape_predictor_with_images_py(pyimages, pydetections, pyscales, predictor_filename); return test_shape_predictor_with_images_py(pyimages, pydetections, pyscales, predictor);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -239,7 +238,7 @@ ensures \n\ ...@@ -239,7 +238,7 @@ ensures \n\
} }
{ {
def("train_shape_predictor", train_shape_predictor_on_images_py, def("train_shape_predictor", train_shape_predictor_on_images_py,
(arg("images"), arg("object_detections"), arg("predictor_output_filename"), arg("options")), (arg("images"), arg("object_detections"), arg("options")),
"requires \n\ "requires \n\
- options.lambda > 0 \n\ - options.lambda > 0 \n\
- options.nu > 0 \n\ - options.nu > 0 \n\
...@@ -255,7 +254,7 @@ ensures \n\ ...@@ -255,7 +254,7 @@ ensures \n\
preprocessing techniques to the training procedure for shape_predictors \n\ preprocessing techniques to the training procedure for shape_predictors \n\
objects. So the point of this function is to provide you with a very easy \n\ objects. So the point of this function is to provide you with a very easy \n\
way to train a basic shape predictor. \n\ way to train a basic shape predictor. \n\
- The trained shape predictor is serialized to the file predictor_output_filename."); - The trained shape_predictor is returned");
def("train_shape_predictor", train_shape_predictor, def("train_shape_predictor", train_shape_predictor,
(arg("dataset_filename"), arg("predictor_output_filename"), arg("options")), (arg("dataset_filename"), arg("predictor_output_filename"), arg("options")),
...@@ -289,15 +288,14 @@ ensures \n\ ...@@ -289,15 +288,14 @@ ensures \n\
for shape_predictor_trainer() for a detailed definition of the mean average error."); for shape_predictor_trainer() for a detailed definition of the mean average error.");
def("test_shape_predictor", test_shape_predictor_with_images_no_scales_py, def("test_shape_predictor", test_shape_predictor_with_images_no_scales_py,
(arg("images"), arg("detections"), arg("predictor_filename")), (arg("images"), arg("detections"), arg("shape_predictor")),
"requires \n\ "requires \n\
- len(images) == len(object_detections) \n\ - len(images) == len(object_detections) \n\
- images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\ - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\
- object_detections should be a list of lists of dlib.full_object_detection objects. \ - object_detections should be a list of lists of dlib.full_object_detection objects. \
Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\ Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\
ensures \n\ ensures \n\
- Loads a shape_predictor from the file predictor_filename. This means \n\ - shape_predictor should be a file produced by the train_shape_predictor() \n\
predictor_filename should be a file produced by the train_shape_predictor() \n\
routine. \n\ routine. \n\
- This function tests the predictor against the dataset and returns the \n\ - This function tests the predictor against the dataset and returns the \n\
mean average error of the detector. In fact, The \n\ mean average error of the detector. In fact, The \n\
...@@ -307,7 +305,7 @@ ensures \n\ ...@@ -307,7 +305,7 @@ ensures \n\
def("test_shape_predictor", test_shape_predictor_with_images_py, def("test_shape_predictor", test_shape_predictor_with_images_py,
(arg("images"), arg("detections"), arg("scales"), arg("predictor_filename")), (arg("images"), arg("detections"), arg("scales"), arg("shape_predictor")),
"requires \n\ "requires \n\
- len(images) == len(object_detections) \n\ - len(images) == len(object_detections) \n\
- len(object_detections) == len(scales) \n\ - len(object_detections) == len(scales) \n\
...@@ -318,8 +316,7 @@ ensures \n\ ...@@ -318,8 +316,7 @@ ensures \n\
- object_detections should be a list of lists of dlib.full_object_detection objects. \ - object_detections should be a list of lists of dlib.full_object_detection objects. \
Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\ Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\
ensures \n\ ensures \n\
- Loads a shape_predictor from the file predictor_filename. This means \n\ - shape_predictor should be a file produced by the train_shape_predictor() \n\
predictor_filename should be a file produced by the train_shape_predictor() \n\
routine. \n\ routine. \n\
- This function tests the predictor against the dataset and returns the \n\ - This function tests the predictor against the dataset and returns the \n\
mean average error of the detector. In fact, The \n\ mean average error of the detector. In fact, The \n\
......
...@@ -65,11 +65,9 @@ namespace dlib ...@@ -65,11 +65,9 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename image_array> template <typename image_array>
inline void train_shape_predictor_on_images ( inline shape_predictor train_shape_predictor_on_images (
const std::string& dataset_filename, // can be "" if it's not applicable
image_array& images, image_array& images,
std::vector<std::vector<full_object_detection> >& detections, std::vector<std::vector<full_object_detection> >& detections,
const std::string& predictor_output_filename,
const shape_predictor_training_options& options const shape_predictor_training_options& options
) )
{ {
...@@ -116,13 +114,7 @@ namespace dlib ...@@ -116,13 +114,7 @@ namespace dlib
shape_predictor predictor = trainer.train(images, detections); shape_predictor predictor = trainer.train(images, detections);
std::ofstream fout(predictor_output_filename.c_str(), std::ios::binary); return predictor;
int version = 1;
serialize(predictor, fout);
serialize(version, fout);
if (options.be_verbose)
std::cout << "Training complete, saved predictor to file " << predictor_output_filename << std::endl;
} }
inline void train_shape_predictor ( inline void train_shape_predictor (
...@@ -135,7 +127,15 @@ namespace dlib ...@@ -135,7 +127,15 @@ namespace dlib
std::vector<std::vector<full_object_detection> > objects; std::vector<std::vector<full_object_detection> > objects;
load_image_dataset(images, objects, dataset_filename); load_image_dataset(images, objects, dataset_filename);
train_shape_predictor_on_images(dataset_filename, images, objects, predictor_output_filename, options); shape_predictor predictor = train_shape_predictor_on_images(images, objects, options);
std::ofstream fout(predictor_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(predictor, fout);
serialize(version, fout);
if (options.be_verbose)
std::cout << "Training complete, saved predictor to file " << predictor_output_filename << std::endl;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -145,7 +145,7 @@ namespace dlib ...@@ -145,7 +145,7 @@ namespace dlib
image_array& images, image_array& images,
std::vector<std::vector<full_object_detection> >& detections, std::vector<std::vector<full_object_detection> >& detections,
std::vector<std::vector<double> >& scales, std::vector<std::vector<double> >& scales,
const std::string& predictor_filename const shape_predictor& predictor
) )
{ {
if (images.size() != detections.size()) if (images.size() != detections.size())
...@@ -153,16 +153,6 @@ namespace dlib ...@@ -153,16 +153,6 @@ namespace dlib
if (scales.size() > 0 && scales.size() != images.size()) if (scales.size() > 0 && scales.size() != images.size())
throw error("The list of scales must have the same length as the list of detections."); throw error("The list of scales must have the same length as the list of detections.");
shape_predictor predictor;
int version = 0;
std::ifstream fin(predictor_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + predictor_filename);
deserialize(predictor, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown shape_predictor format.");
if (scales.size() > 0) if (scales.size() > 0)
return test_shape_predictor(predictor, images, detections, scales); return test_shape_predictor(predictor, images, detections, scales);
else else
...@@ -174,13 +164,25 @@ namespace dlib ...@@ -174,13 +164,25 @@ namespace dlib
const std::string& predictor_filename const std::string& predictor_filename
) )
{ {
// Load the images, no scales can be provided
dlib::array<array2d<rgb_pixel> > images; dlib::array<array2d<rgb_pixel> > images;
// This interface cannot take the scales parameter. // This interface cannot take the scales parameter.
std::vector<std::vector<double> > scales; std::vector<std::vector<double> > scales;
std::vector<std::vector<full_object_detection> > objects; std::vector<std::vector<full_object_detection> > objects;
load_image_dataset(images, objects, dataset_filename); load_image_dataset(images, objects, dataset_filename);
return test_shape_predictor_with_images(images, objects, scales, predictor_filename); // Load the shape predictor
shape_predictor predictor;
int version = 0;
std::ifstream fin(predictor_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + predictor_filename);
deserialize(predictor, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown shape_predictor format.");
return test_shape_predictor_with_images(images, objects, scales, predictor);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -127,12 +127,11 @@ namespace dlib ...@@ -127,12 +127,11 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename image_array> template <typename image_array>
inline void train_simple_object_detector_on_images ( inline simple_object_detector train_simple_object_detector_on_images (
const std::string& dataset_filename, // can be "" if it's not applicable const std::string& dataset_filename, // can be "" if it's not applicable
image_array& images, image_array& images,
std::vector<std::vector<rectangle> >& boxes, std::vector<std::vector<rectangle> >& boxes,
std::vector<std::vector<rectangle> >& ignore, std::vector<std::vector<rectangle> >& ignore,
const std::string& detector_output_filename,
const simple_object_detector_training_options& options const simple_object_detector_training_options& options
) )
{ {
...@@ -169,7 +168,6 @@ namespace dlib ...@@ -169,7 +168,6 @@ namespace dlib
trainer.be_verbose(); trainer.be_verbose();
} }
unsigned long upsample_amount = 0; unsigned long upsample_amount = 0;
// now make sure all the boxes are obtainable by the scanner. We will try and // now make sure all the boxes are obtainable by the scanner. We will try and
...@@ -194,14 +192,9 @@ namespace dlib ...@@ -194,14 +192,9 @@ namespace dlib
simple_object_detector detector = trainer.train(images, boxes, ignore); simple_object_detector detector = trainer.train(images, boxes, ignore);
std::ofstream fout(detector_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(detector, fout);
serialize(version, fout);
if (options.be_verbose) if (options.be_verbose)
{ {
std::cout << "Training complete, saved detector to file " << detector_output_filename << std::endl; std::cout << "Training complete." << std::endl;
std::cout << "Trained with C: " << options.C << std::endl; std::cout << "Trained with C: " << options.C << std::endl;
std::cout << "Training with epsilon: " << options.epsilon << std::endl; std::cout << "Training with epsilon: " << options.epsilon << std::endl;
std::cout << "Trained using " << options.num_threads << " threads."<< std::endl; std::cout << "Trained using " << options.num_threads << " threads."<< std::endl;
...@@ -216,6 +209,8 @@ namespace dlib ...@@ -216,6 +209,8 @@ namespace dlib
if (options.add_left_right_image_flips) if (options.add_left_right_image_flips)
std::cout << "Trained on both left and right flipped versions of images." << std::endl; std::cout << "Trained on both left and right flipped versions of images." << std::endl;
} }
return detector;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -230,7 +225,15 @@ namespace dlib ...@@ -230,7 +225,15 @@ namespace dlib
std::vector<std::vector<rectangle> > boxes, ignore; std::vector<std::vector<rectangle> > boxes, ignore;
ignore = load_image_dataset(images, boxes, dataset_filename); ignore = load_image_dataset(images, boxes, dataset_filename);
train_simple_object_detector_on_images(dataset_filename, images, boxes, ignore, detector_output_filename, options); simple_object_detector detector = train_simple_object_detector_on_images(dataset_filename, images, boxes, ignore, options);
std::ofstream fout(detector_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(detector, fout);
serialize(version, fout);
if (options.be_verbose)
std::cout << "Saved detector to file " << detector_output_filename << std::endl;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -245,21 +248,14 @@ namespace dlib ...@@ -245,21 +248,14 @@ namespace dlib
template <typename image_array> template <typename image_array>
inline const simple_test_results test_simple_object_detector_with_images ( inline const simple_test_results test_simple_object_detector_with_images (
image_array& images, image_array& images,
const unsigned int upsample_amount,
std::vector<std::vector<rectangle> >& boxes, std::vector<std::vector<rectangle> >& boxes,
std::vector<std::vector<rectangle> >& ignore, std::vector<std::vector<rectangle> >& ignore,
const std::string& detector_filename simple_object_detector& detector
) )
{ {
simple_object_detector detector; for (unsigned int i = 0; i < upsample_amount; ++i)
int version = 0; upsample_image_dataset<pyramid_down<2> >(images, boxes);
unsigned int upsample_amount = 0;
std::ifstream fin(detector_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + detector_filename);
deserialize(detector, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown simple_object_detector format.");
matrix<double,1,3> res = test_object_detection_function(detector, images, boxes, ignore); matrix<double,1,3> res = test_object_detection_function(detector, images, boxes, ignore);
simple_test_results ret; simple_test_results ret;
...@@ -271,14 +267,27 @@ namespace dlib ...@@ -271,14 +267,27 @@ namespace dlib
inline const simple_test_results test_simple_object_detector ( inline const simple_test_results test_simple_object_detector (
const std::string& dataset_filename, const std::string& dataset_filename,
const std::string& detector_filename const std::string& detector_filename,
const unsigned int upsample_amount
) )
{ {
// Load all the testing images
dlib::array<array2d<rgb_pixel> > images; dlib::array<array2d<rgb_pixel> > images;
std::vector<std::vector<rectangle> > boxes, ignore; std::vector<std::vector<rectangle> > boxes, ignore;
ignore = load_image_dataset(images, boxes, dataset_filename); ignore = load_image_dataset(images, boxes, dataset_filename);
return test_simple_object_detector_with_images(images, boxes, ignore, detector_filename); // Load the detector off disk
simple_object_detector detector;
int version = 0;
std::ifstream fin(detector_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + detector_filename);
deserialize(detector, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown simple_object_detector format.");
return test_simple_object_detector_with_images(images, upsample_amount, boxes, ignore, detector);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment