Commit dd19ce84 authored by Patrick Snape's avatar Patrick Snape

Update the interface to be more Pythonic

This is the biggest change so far. Now, there are two different
classes of interface. One where you pass ONLY file paths,
and one where you pass ONLY Python objects.

The file paths are maintained to keep a matching interface with
the C++ examples of dlib. So shape prediction and object
detection can be trained using the dlib XML file paths, and the
resulting detectors are then serialized to disk.

Shape prediction and object detection can also be trained using
numpy arrays and in-memory objects. In this case, the predictor
and detector objects are returned from the training functions.
To facilitate serializing these objects, they now have a 'save'
method.

Testing follows a similar pattern, in that it can take either XML
files or in-memory objects. I also added back the concept of
upsampling during testing to make amends for removing the
simple_object_detector_py struct.
parent 8db3f4e5
This diff is collapsed.
......@@ -84,10 +84,9 @@ boost::shared_ptr<full_object_detection> full_obj_det_init(object& pyrect, objec
// ----------------------------------------------------------------------------------------
inline void train_shape_predictor_on_images_py (
const object& pyimages,
const object& pydetections,
const std::string& predictor_output_filename,
inline shape_predictor train_shape_predictor_on_images_py (
const boost::python::list& pyimages,
const boost::python::list& pydetections,
const shape_predictor_training_options& options
)
{
......@@ -99,15 +98,15 @@ inline void train_shape_predictor_on_images_py (
dlib::array<array2d<rgb_pixel> > images(num_images);
images_and_nested_params_to_dlib(pyimages, pydetections, images, detections);
train_shape_predictor_on_images("", images, detections, predictor_output_filename, options);
return train_shape_predictor_on_images(images, detections, options);
}
inline double test_shape_predictor_with_images_py (
const object& pyimages,
const object& pydetections,
const object& pyscales,
const std::string& predictor_filename
const boost::python::list& pyimages,
const boost::python::list& pydetections,
const boost::python::list& pyscales,
const shape_predictor& predictor
)
{
const unsigned long num_images = len(pyimages);
......@@ -141,17 +140,17 @@ inline double test_shape_predictor_with_images_py (
}
}
return test_shape_predictor_with_images(images, detections, scales, predictor_filename);
return test_shape_predictor_with_images(images, detections, scales, predictor);
}
inline double test_shape_predictor_with_images_no_scales_py (
const object& pyimages,
const object& pydetections,
const std::string& predictor_filename
const boost::python::list& pyimages,
const boost::python::list& pydetections,
const shape_predictor& predictor
)
{
boost::python::list pyscales;
return test_shape_predictor_with_images_py(pyimages, pydetections, pyscales, predictor_filename);
return test_shape_predictor_with_images_py(pyimages, pydetections, pyscales, predictor);
}
// ----------------------------------------------------------------------------------------
......@@ -239,7 +238,7 @@ ensures \n\
}
{
def("train_shape_predictor", train_shape_predictor_on_images_py,
(arg("images"), arg("object_detections"), arg("predictor_output_filename"), arg("options")),
(arg("images"), arg("object_detections"), arg("options")),
"requires \n\
- options.lambda > 0 \n\
- options.nu > 0 \n\
......@@ -255,7 +254,7 @@ ensures \n\
preprocessing techniques to the training procedure for shape_predictors \n\
objects. So the point of this function is to provide you with a very easy \n\
way to train a basic shape predictor. \n\
- The trained shape predictor is serialized to the file predictor_output_filename.");
- The trained shape_predictor is returned");
def("train_shape_predictor", train_shape_predictor,
(arg("dataset_filename"), arg("predictor_output_filename"), arg("options")),
......@@ -289,15 +288,14 @@ ensures \n\
for shape_predictor_trainer() for a detailed definition of the mean average error.");
def("test_shape_predictor", test_shape_predictor_with_images_no_scales_py,
(arg("images"), arg("detections"), arg("predictor_filename")),
(arg("images"), arg("detections"), arg("shape_predictor")),
"requires \n\
- len(images) == len(object_detections) \n\
- images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\
- object_detections should be a list of lists of dlib.full_object_detection objects. \
Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\
ensures \n\
- Loads a shape_predictor from the file predictor_filename. This means \n\
predictor_filename should be a file produced by the train_shape_predictor() \n\
- shape_predictor should be a file produced by the train_shape_predictor() \n\
routine. \n\
- This function tests the predictor against the dataset and returns the \n\
mean average error of the detector. In fact, The \n\
......@@ -307,7 +305,7 @@ ensures \n\
def("test_shape_predictor", test_shape_predictor_with_images_py,
(arg("images"), arg("detections"), arg("scales"), arg("predictor_filename")),
(arg("images"), arg("detections"), arg("scales"), arg("shape_predictor")),
"requires \n\
- len(images) == len(object_detections) \n\
- len(object_detections) == len(scales) \n\
......@@ -318,8 +316,7 @@ ensures \n\
- object_detections should be a list of lists of dlib.full_object_detection objects. \
Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\
ensures \n\
- Loads a shape_predictor from the file predictor_filename. This means \n\
predictor_filename should be a file produced by the train_shape_predictor() \n\
- shape_predictor should be a file produced by the train_shape_predictor() \n\
routine. \n\
- This function tests the predictor against the dataset and returns the \n\
mean average error of the detector. In fact, The \n\
......
......@@ -65,11 +65,9 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <typename image_array>
inline void train_shape_predictor_on_images (
const std::string& dataset_filename, // can be "" if it's not applicable
inline shape_predictor train_shape_predictor_on_images (
image_array& images,
std::vector<std::vector<full_object_detection> >& detections,
const std::string& predictor_output_filename,
const shape_predictor_training_options& options
)
{
......@@ -116,13 +114,7 @@ namespace dlib
shape_predictor predictor = trainer.train(images, detections);
std::ofstream fout(predictor_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(predictor, fout);
serialize(version, fout);
if (options.be_verbose)
std::cout << "Training complete, saved predictor to file " << predictor_output_filename << std::endl;
return predictor;
}
inline void train_shape_predictor (
......@@ -135,7 +127,15 @@ namespace dlib
std::vector<std::vector<full_object_detection> > objects;
load_image_dataset(images, objects, dataset_filename);
train_shape_predictor_on_images(dataset_filename, images, objects, predictor_output_filename, options);
shape_predictor predictor = train_shape_predictor_on_images(images, objects, options);
std::ofstream fout(predictor_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(predictor, fout);
serialize(version, fout);
if (options.be_verbose)
std::cout << "Training complete, saved predictor to file " << predictor_output_filename << std::endl;
}
// ----------------------------------------------------------------------------------------
......@@ -145,7 +145,7 @@ namespace dlib
image_array& images,
std::vector<std::vector<full_object_detection> >& detections,
std::vector<std::vector<double> >& scales,
const std::string& predictor_filename
const shape_predictor& predictor
)
{
if (images.size() != detections.size())
......@@ -153,16 +153,6 @@ namespace dlib
if (scales.size() > 0 && scales.size() != images.size())
throw error("The list of scales must have the same length as the list of detections.");
shape_predictor predictor;
int version = 0;
std::ifstream fin(predictor_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + predictor_filename);
deserialize(predictor, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown shape_predictor format.");
if (scales.size() > 0)
return test_shape_predictor(predictor, images, detections, scales);
else
......@@ -174,13 +164,25 @@ namespace dlib
const std::string& predictor_filename
)
{
// Load the images, no scales can be provided
dlib::array<array2d<rgb_pixel> > images;
// This interface cannot take the scales parameter.
std::vector<std::vector<double> > scales;
std::vector<std::vector<full_object_detection> > objects;
load_image_dataset(images, objects, dataset_filename);
return test_shape_predictor_with_images(images, objects, scales, predictor_filename);
// Load the shape predictor
shape_predictor predictor;
int version = 0;
std::ifstream fin(predictor_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + predictor_filename);
deserialize(predictor, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown shape_predictor format.");
return test_shape_predictor_with_images(images, objects, scales, predictor);
}
// ----------------------------------------------------------------------------------------
......
......@@ -127,12 +127,11 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <typename image_array>
inline void train_simple_object_detector_on_images (
inline simple_object_detector train_simple_object_detector_on_images (
const std::string& dataset_filename, // can be "" if it's not applicable
image_array& images,
std::vector<std::vector<rectangle> >& boxes,
std::vector<std::vector<rectangle> >& ignore,
const std::string& detector_output_filename,
const simple_object_detector_training_options& options
)
{
......@@ -169,7 +168,6 @@ namespace dlib
trainer.be_verbose();
}
unsigned long upsample_amount = 0;
// now make sure all the boxes are obtainable by the scanner. We will try and
......@@ -194,14 +192,9 @@ namespace dlib
simple_object_detector detector = trainer.train(images, boxes, ignore);
std::ofstream fout(detector_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(detector, fout);
serialize(version, fout);
if (options.be_verbose)
{
std::cout << "Training complete, saved detector to file " << detector_output_filename << std::endl;
std::cout << "Training complete." << std::endl;
std::cout << "Trained with C: " << options.C << std::endl;
std::cout << "Training with epsilon: " << options.epsilon << std::endl;
std::cout << "Trained using " << options.num_threads << " threads."<< std::endl;
......@@ -216,6 +209,8 @@ namespace dlib
if (options.add_left_right_image_flips)
std::cout << "Trained on both left and right flipped versions of images." << std::endl;
}
return detector;
}
// ----------------------------------------------------------------------------------------
......@@ -230,7 +225,15 @@ namespace dlib
std::vector<std::vector<rectangle> > boxes, ignore;
ignore = load_image_dataset(images, boxes, dataset_filename);
train_simple_object_detector_on_images(dataset_filename, images, boxes, ignore, detector_output_filename, options);
simple_object_detector detector = train_simple_object_detector_on_images(dataset_filename, images, boxes, ignore, options);
std::ofstream fout(detector_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(detector, fout);
serialize(version, fout);
if (options.be_verbose)
std::cout << "Saved detector to file " << detector_output_filename << std::endl;
}
// ----------------------------------------------------------------------------------------
......@@ -245,21 +248,14 @@ namespace dlib
template <typename image_array>
inline const simple_test_results test_simple_object_detector_with_images (
image_array& images,
const unsigned int upsample_amount,
std::vector<std::vector<rectangle> >& boxes,
std::vector<std::vector<rectangle> >& ignore,
const std::string& detector_filename
simple_object_detector& detector
)
{
simple_object_detector detector;
int version = 0;
unsigned int upsample_amount = 0;
std::ifstream fin(detector_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + detector_filename);
deserialize(detector, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown simple_object_detector format.");
for (unsigned int i = 0; i < upsample_amount; ++i)
upsample_image_dataset<pyramid_down<2> >(images, boxes);
matrix<double,1,3> res = test_object_detection_function(detector, images, boxes, ignore);
simple_test_results ret;
......@@ -271,14 +267,27 @@ namespace dlib
inline const simple_test_results test_simple_object_detector (
const std::string& dataset_filename,
const std::string& detector_filename
const std::string& detector_filename,
const unsigned int upsample_amount
)
{
// Load all the testing images
dlib::array<array2d<rgb_pixel> > images;
std::vector<std::vector<rectangle> > boxes, ignore;
ignore = load_image_dataset(images, boxes, dataset_filename);
return test_simple_object_detector_with_images(images, boxes, ignore, detector_filename);
// Load the detector off disk
simple_object_detector detector;
int version = 0;
std::ifstream fin(detector_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + detector_filename);
deserialize(detector, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown simple_object_detector format.");
return test_simple_object_detector_with_images(images, upsample_amount, boxes, ignore, detector);
}
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment