Commit dd19ce84 authored by Patrick Snape's avatar Patrick Snape

Update the interface to be more Pythonic

This is the biggest change so far. Now, there are two different
classes of interface. One where you pass ONLY file paths,
and one where you pass ONLY Python objects.

The file paths are maintained to keep a matching interface with
the C++ examples of dlib. So shape predicition and object
detection can be trained using the dlib XML file paths and then
serialize the detectors to disk.

Shape prediction and object detection can also be trained using
numpy arrays and in-memory objects. In this case, the predictor
and detector objects are returned from the training functions.
To facilitate serializing these objects, they now have a 'save'
method.

Tetsing follows a similar pattern, in that it can take either XML
files are or in-memory objects. I also added back the concept of
upsampling during testing to make amends for removing the
simple_object_detector_py struct.
parent 8db3f4e5
...@@ -118,10 +118,9 @@ void save_simple_object_detector(const simple_object_detector& detector, const s ...@@ -118,10 +118,9 @@ void save_simple_object_detector(const simple_object_detector& detector, const s
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline void train_simple_object_detector_on_images_py ( inline simple_object_detector train_simple_object_detector_on_images_py (
const object& pyimages, const boost::python::list& pyimages,
const object& pyboxes, const boost::python::list& pyboxes,
const std::string& detector_output_filename,
const simple_object_detector_training_options& options const simple_object_detector_training_options& options
) )
{ {
...@@ -134,13 +133,14 @@ inline void train_simple_object_detector_on_images_py ( ...@@ -134,13 +133,14 @@ inline void train_simple_object_detector_on_images_py (
dlib::array<array2d<rgb_pixel> > images(num_images); dlib::array<array2d<rgb_pixel> > images(num_images);
images_and_nested_params_to_dlib(pyimages, pyboxes, images, boxes); images_and_nested_params_to_dlib(pyimages, pyboxes, images, boxes);
train_simple_object_detector_on_images("", images, boxes, ignore, detector_output_filename, options); return train_simple_object_detector_on_images("", images, boxes, ignore, options);
} }
inline simple_test_results test_simple_object_detector_with_images_py ( inline simple_test_results test_simple_object_detector_with_images_py (
const object& pyimages, const boost::python::list& pyimages,
const object& pyboxes, const boost::python::list& pyboxes,
const std::string& detector_filename simple_object_detector& detector,
const unsigned int unsample_amount
) )
{ {
const unsigned long num_images = len(pyimages); const unsigned long num_images = len(pyimages);
...@@ -152,7 +152,7 @@ inline simple_test_results test_simple_object_detector_with_images_py ( ...@@ -152,7 +152,7 @@ inline simple_test_results test_simple_object_detector_with_images_py (
dlib::array<array2d<rgb_pixel> > images(num_images); dlib::array<array2d<rgb_pixel> > images(num_images);
images_and_nested_params_to_dlib(pyimages, pyboxes, images, boxes); images_and_nested_params_to_dlib(pyimages, pyboxes, images, boxes);
return test_simple_object_detector_with_images(images, boxes, ignore, detector_filename); return test_simple_object_detector_with_images(images, unsample_amount, boxes, ignore, detector);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -165,19 +165,12 @@ void bind_object_detection() ...@@ -165,19 +165,12 @@ void bind_object_detection()
"This object is a container for the options to the train_simple_object_detector() routine.") "This object is a container for the options to the train_simple_object_detector() routine.")
.add_property("be_verbose", &simple_object_detector_training_options::be_verbose, .add_property("be_verbose", &simple_object_detector_training_options::be_verbose,
&simple_object_detector_training_options::be_verbose, &simple_object_detector_training_options::be_verbose,
"If true, train_simple_object_detector() will print out a lot of information to the screen while training." "If true, train_simple_object_detector() will print out a lot of information to the screen while training.")
)
.add_property("add_left_right_image_flips", &simple_object_detector_training_options::add_left_right_image_flips, .add_property("add_left_right_image_flips", &simple_object_detector_training_options::add_left_right_image_flips,
&simple_object_detector_training_options::add_left_right_image_flips, &simple_object_detector_training_options::add_left_right_image_flips,
"if true, train_simple_object_detector() will assume the objects are \n\ "if true, train_simple_object_detector() will assume the objects are \n\
left/right symmetric and add in left right flips of the training \n\ left/right symmetric and add in left right flips of the training \n\
images. This doubles the size of the training dataset." images. This doubles the size of the training dataset.")
/*!
if true, train_simple_object_detector() will assume the objects are
left/right symmetric and add in left right flips of the training
images. This doubles the size of the training dataset.
!*/
)
.add_property("detection_window_size", &simple_object_detector_training_options::detection_window_size, .add_property("detection_window_size", &simple_object_detector_training_options::detection_window_size,
&simple_object_detector_training_options::detection_window_size, &simple_object_detector_training_options::detection_window_size,
"The sliding window used will have about this many pixels inside it.") "The sliding window used will have about this many pixels inside it.")
...@@ -187,46 +180,22 @@ images. This doubles the size of the training dataset." ...@@ -187,46 +180,22 @@ images. This doubles the size of the training dataset."
structural_object_detection_trainer::set_c(). Larger values of C \n\ structural_object_detection_trainer::set_c(). Larger values of C \n\
will encourage the trainer to fit the data better but might lead to \n\ will encourage the trainer to fit the data better but might lead to \n\
overfitting. Therefore, you must determine the proper setting of \n\ overfitting. Therefore, you must determine the proper setting of \n\
this parameter experimentally." this parameter experimentally.")
/*!
C is the usual SVM C regularization parameter. So it is passed to
structural_object_detection_trainer::set_c(). Larger values of C
will encourage the trainer to fit the data better but might lead to
overfitting. Therefore, you must determine the proper setting of
this parameter experimentally.
!*/
)
.add_property("epsilon", &simple_object_detector_training_options::epsilon, .add_property("epsilon", &simple_object_detector_training_options::epsilon,
&simple_object_detector_training_options::epsilon, &simple_object_detector_training_options::epsilon,
"epsilon is the stopping epsilon. Smaller values make the trainer's \n\ "epsilon is the stopping epsilon. Smaller values make the trainer's \n\
solver more accurate but might take longer to train." solver more accurate but might take longer to train.")
/*!
epsilon is the stopping epsilon. Smaller values make the trainer's
solver more accurate but might take longer to train.
!*/
)
.add_property("num_threads", &simple_object_detector_training_options::num_threads, .add_property("num_threads", &simple_object_detector_training_options::num_threads,
&simple_object_detector_training_options::num_threads, &simple_object_detector_training_options::num_threads,
"train_simple_object_detector() will use this many threads of \n\ "train_simple_object_detector() will use this many threads of \n\
execution. Set this to the number of CPU cores on your machine to \n\ execution. Set this to the number of CPU cores on your machine to \n\
obtain the fastest training speed." obtain the fastest training speed.");
/*!
train_simple_object_detector() will use this many threads of
execution. Set this to the number of CPU cores on your machine to
obtain the fastest training speed.
!*/
);
class_<simple_test_results>("simple_test_results") class_<simple_test_results>("simple_test_results")
.add_property("precision", &simple_test_results::precision) .add_property("precision", &simple_test_results::precision)
.add_property("recall", &simple_test_results::recall) .add_property("recall", &simple_test_results::recall)
.add_property("average_precision", &simple_test_results::average_precision) .add_property("average_precision", &simple_test_results::average_precision)
.def("__str__", &::print_simple_test_results); .def("__str__", &::print_simple_test_results);
{ {
typedef rectangle type; typedef rectangle type;
class_<type>("rectangle", "This object represents a rectangular area of an image.") class_<type>("rectangle", "This object represents a rectangular area of an image.")
...@@ -258,25 +227,10 @@ ensures \n\ ...@@ -258,25 +227,10 @@ ensures \n\
preprocessing techniques to the training procedure for simple_object_detector \n\ preprocessing techniques to the training procedure for simple_object_detector \n\
objects. So the point of this function is to provide you with a very easy \n\ objects. So the point of this function is to provide you with a very easy \n\
way to train a basic object detector. \n\ way to train a basic object detector. \n\
- The trained object detector is serialized to the file detector_output_filename." - The trained object detector is serialized to the file detector_output_filename.");
/*!
requires
- options.C > 0
ensures
- Uses the structural_object_detection_trainer to train a
simple_object_detector based on the labeled images in the XML file
dataset_filename. This function assumes the file dataset_filename is in the
XML format produced by dlib's save_image_dataset_metadata() routine.
- This function will apply a reasonable set of default parameters and
preprocessing techniques to the training procedure for simple_object_detector
objects. So the point of this function is to provide you with a very easy
way to train a basic object detector.
- The trained object detector is serialized to the file detector_output_filename.
!*/
);
def("train_simple_object_detector", train_simple_object_detector_on_images_py, def("train_simple_object_detector", train_simple_object_detector_on_images_py,
(arg("images"), arg("boxes"), arg("detector_output_filename"), arg("options")), (arg("images"), arg("boxes"), arg("options")),
"requires \n\ "requires \n\
- options.C > 0 \n\ - options.C > 0 \n\
- len(images) == len(boxes) \n\ - len(images) == len(boxes) \n\
...@@ -289,28 +243,13 @@ ensures \n\ ...@@ -289,28 +243,13 @@ ensures \n\
preprocessing techniques to the training procedure for simple_object_detector \n\ preprocessing techniques to the training procedure for simple_object_detector \n\
objects. So the point of this function is to provide you with a very easy \n\ objects. So the point of this function is to provide you with a very easy \n\
way to train a basic object detector. \n\ way to train a basic object detector. \n\
- The trained object detector is serialized to the file detector_output_filename." - The trained object detector is returned.");
/*!
requires
- options.C > 0
- len(images) == len(boxes)
- images should be a list of numpy matrices that represent images, either RGB or grayscale.
- boxes should be a dlib.rectangles object (i.e. an array of rectangles).
- boxes should be a list of lists of dlib.rectangle object.
ensures
- Uses the structural_object_detection_trainer to train a
simple_object_detector based on the labeled images and bounding boxes.
- This function will apply a reasonable set of default parameters and
preprocessing techniques to the training procedure for simple_object_detector
objects. So the point of this function is to provide you with a very easy
way to train a basic object detector.
- The trained object detector is serialized to the file detector_output_filename.
!*/
);
def("test_simple_object_detector", test_simple_object_detector, def("test_simple_object_detector", test_simple_object_detector,
(arg("dataset_filename"), arg("detector_filename")), (arg("dataset_filename"), arg("detector_filename"), arg("upsample_amount")=0),
"ensures \n\ "requires \n\
- Optionally, take the number of times to upsample the testing images. \n\
ensures \n\
- Loads an image dataset from dataset_filename. We assume dataset_filename is \n\ - Loads an image dataset from dataset_filename. We assume dataset_filename is \n\
a file using the XML format written by save_image_dataset_metadata(). \n\ a file using the XML format written by save_image_dataset_metadata(). \n\
- Loads a simple_object_detector from the file detector_filename. This means \n\ - Loads a simple_object_detector from the file detector_filename. This means \n\
...@@ -322,28 +261,15 @@ ensures \n\ ...@@ -322,28 +261,15 @@ ensures \n\
test_object_detection_function() routine. Therefore, see the documentation \n\ test_object_detection_function() routine. Therefore, see the documentation \n\
for test_object_detection_function() for a detailed definition of these \n\ for test_object_detection_function() for a detailed definition of these \n\
metrics. " metrics. "
/*!
ensures
- Loads an image dataset from dataset_filename. We assume dataset_filename is
a file using the XML format written by save_image_dataset_metadata().
- Loads a simple_object_detector from the file detector_filename. This means
detector_filename should be a file produced by the train_simple_object_detector()
routine.
- This function tests the detector against the dataset and returns the
precision, recall, and average precision of the detector. In fact, The
return value of this function is identical to that of dlib's
test_object_detection_function() routine. Therefore, see the documentation
for test_object_detection_function() for a detailed definition of these
metrics.
!*/
); );
def("test_simple_object_detector", test_simple_object_detector_with_images_py, def("test_simple_object_detector", test_simple_object_detector_with_images_py,
(arg("images"), arg("boxes"), arg("detector_filename")), (arg("images"), arg("boxes"), arg("detector"), arg("upsample_amount")=0),
"requires \n\ "requires \n\
- len(images) == len(boxes) \n\ - len(images) == len(boxes) \n\
- images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\ - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\
- boxes should be a list of lists of dlib.rectangle object. \n\ - boxes should be a list of lists of dlib.rectangle object. \n\
- Optionally, take the number of times to upsample the testing images. \n\
ensures \n\ ensures \n\
- Loads a simple_object_detector from the file detector_filename. This means \n\ - Loads a simple_object_detector from the file detector_filename. This means \n\
detector_filename should be a file produced by the train_simple_object_detector() \n\ detector_filename should be a file produced by the train_simple_object_detector() \n\
...@@ -361,11 +287,7 @@ ensures \n\ ...@@ -361,11 +287,7 @@ ensures \n\
"This object represents a sliding window histogram-of-oriented-gradients based object detector.") "This object represents a sliding window histogram-of-oriented-gradients based object detector.")
.def("__init__", make_constructor(&load_object_from_file<type>), .def("__init__", make_constructor(&load_object_from_file<type>),
"Loads a simple_object_detector from a file that contains the output of the \n\ "Loads a simple_object_detector from a file that contains the output of the \n\
train_simple_object_detector() routine." train_simple_object_detector() routine.")
/*!
Loads a simple_object_detector from a file that contains the output of the
train_simple_object_detector() routine.
!*/)
.def("__call__", run_detector_with_upscale, (arg("image"), arg("upsample_num_times")=0), .def("__call__", run_detector_with_upscale, (arg("image"), arg("upsample_num_times")=0),
"requires \n\ "requires \n\
- image is a numpy ndarray containing either an 8bit grayscale or RGB \n\ - image is a numpy ndarray containing either an 8bit grayscale or RGB \n\
...@@ -377,21 +299,7 @@ ensures \n\ ...@@ -377,21 +299,7 @@ ensures \n\
- Upsamples the image upsample_num_times before running the basic \n\ - Upsamples the image upsample_num_times before running the basic \n\
detector. If you don't know how many times you want to upsample then \n\ detector. If you don't know how many times you want to upsample then \n\
don't provide a value for upsample_num_times and an appropriate \n\ don't provide a value for upsample_num_times and an appropriate \n\
default will be used." default will be used.")
/*!
requires
- image is a numpy ndarray containing either an 8bit grayscale or RGB
image.
- upsample_num_times >= 0
ensures
- This function runs the object detector on the input image and returns
a list of detections.
- Upsamples the image upsample_num_times before running the basic
detector. If you don't know how many times you want to upsample then
don't provide a value for upsample_num_times and an appropriate
default will be used.
!*/
)
.def("save", save_simple_object_detector, (arg("detector_output_filename")), "Save a simple_object_detector to the provided path.") .def("save", save_simple_object_detector, (arg("detector_output_filename")), "Save a simple_object_detector to the provided path.")
.def_pickle(serialize_pickle<type>()); .def_pickle(serialize_pickle<type>());
} }
...@@ -406,5 +314,3 @@ ensures \n\ ...@@ -406,5 +314,3 @@ ensures \n\
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -84,10 +84,9 @@ boost::shared_ptr<full_object_detection> full_obj_det_init(object& pyrect, objec ...@@ -84,10 +84,9 @@ boost::shared_ptr<full_object_detection> full_obj_det_init(object& pyrect, objec
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline void train_shape_predictor_on_images_py ( inline shape_predictor train_shape_predictor_on_images_py (
const object& pyimages, const boost::python::list& pyimages,
const object& pydetections, const boost::python::list& pydetections,
const std::string& predictor_output_filename,
const shape_predictor_training_options& options const shape_predictor_training_options& options
) )
{ {
...@@ -99,15 +98,15 @@ inline void train_shape_predictor_on_images_py ( ...@@ -99,15 +98,15 @@ inline void train_shape_predictor_on_images_py (
dlib::array<array2d<rgb_pixel> > images(num_images); dlib::array<array2d<rgb_pixel> > images(num_images);
images_and_nested_params_to_dlib(pyimages, pydetections, images, detections); images_and_nested_params_to_dlib(pyimages, pydetections, images, detections);
train_shape_predictor_on_images("", images, detections, predictor_output_filename, options); return train_shape_predictor_on_images(images, detections, options);
} }
inline double test_shape_predictor_with_images_py ( inline double test_shape_predictor_with_images_py (
const object& pyimages, const boost::python::list& pyimages,
const object& pydetections, const boost::python::list& pydetections,
const object& pyscales, const boost::python::list& pyscales,
const std::string& predictor_filename const shape_predictor& predictor
) )
{ {
const unsigned long num_images = len(pyimages); const unsigned long num_images = len(pyimages);
...@@ -141,17 +140,17 @@ inline double test_shape_predictor_with_images_py ( ...@@ -141,17 +140,17 @@ inline double test_shape_predictor_with_images_py (
} }
} }
return test_shape_predictor_with_images(images, detections, scales, predictor_filename); return test_shape_predictor_with_images(images, detections, scales, predictor);
} }
inline double test_shape_predictor_with_images_no_scales_py ( inline double test_shape_predictor_with_images_no_scales_py (
const object& pyimages, const boost::python::list& pyimages,
const object& pydetections, const boost::python::list& pydetections,
const std::string& predictor_filename const shape_predictor& predictor
) )
{ {
boost::python::list pyscales; boost::python::list pyscales;
return test_shape_predictor_with_images_py(pyimages, pydetections, pyscales, predictor_filename); return test_shape_predictor_with_images_py(pyimages, pydetections, pyscales, predictor);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -239,7 +238,7 @@ ensures \n\ ...@@ -239,7 +238,7 @@ ensures \n\
} }
{ {
def("train_shape_predictor", train_shape_predictor_on_images_py, def("train_shape_predictor", train_shape_predictor_on_images_py,
(arg("images"), arg("object_detections"), arg("predictor_output_filename"), arg("options")), (arg("images"), arg("object_detections"), arg("options")),
"requires \n\ "requires \n\
- options.lambda > 0 \n\ - options.lambda > 0 \n\
- options.nu > 0 \n\ - options.nu > 0 \n\
...@@ -255,7 +254,7 @@ ensures \n\ ...@@ -255,7 +254,7 @@ ensures \n\
preprocessing techniques to the training procedure for shape_predictors \n\ preprocessing techniques to the training procedure for shape_predictors \n\
objects. So the point of this function is to provide you with a very easy \n\ objects. So the point of this function is to provide you with a very easy \n\
way to train a basic shape predictor. \n\ way to train a basic shape predictor. \n\
- The trained shape predictor is serialized to the file predictor_output_filename."); - The trained shape_predictor is returned");
def("train_shape_predictor", train_shape_predictor, def("train_shape_predictor", train_shape_predictor,
(arg("dataset_filename"), arg("predictor_output_filename"), arg("options")), (arg("dataset_filename"), arg("predictor_output_filename"), arg("options")),
...@@ -289,15 +288,14 @@ ensures \n\ ...@@ -289,15 +288,14 @@ ensures \n\
for shape_predictor_trainer() for a detailed definition of the mean average error."); for shape_predictor_trainer() for a detailed definition of the mean average error.");
def("test_shape_predictor", test_shape_predictor_with_images_no_scales_py, def("test_shape_predictor", test_shape_predictor_with_images_no_scales_py,
(arg("images"), arg("detections"), arg("predictor_filename")), (arg("images"), arg("detections"), arg("shape_predictor")),
"requires \n\ "requires \n\
- len(images) == len(object_detections) \n\ - len(images) == len(object_detections) \n\
- images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\ - images should be a list of numpy matrices that represent images, either RGB or grayscale. \n\
- object_detections should be a list of lists of dlib.full_object_detection objects. \ - object_detections should be a list of lists of dlib.full_object_detection objects. \
Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\ Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\
ensures \n\ ensures \n\
- Loads a shape_predictor from the file predictor_filename. This means \n\ - shape_predictor should be a file produced by the train_shape_predictor() \n\
predictor_filename should be a file produced by the train_shape_predictor() \n\
routine. \n\ routine. \n\
- This function tests the predictor against the dataset and returns the \n\ - This function tests the predictor against the dataset and returns the \n\
mean average error of the detector. In fact, The \n\ mean average error of the detector. In fact, The \n\
...@@ -307,7 +305,7 @@ ensures \n\ ...@@ -307,7 +305,7 @@ ensures \n\
def("test_shape_predictor", test_shape_predictor_with_images_py, def("test_shape_predictor", test_shape_predictor_with_images_py,
(arg("images"), arg("detections"), arg("scales"), arg("predictor_filename")), (arg("images"), arg("detections"), arg("scales"), arg("shape_predictor")),
"requires \n\ "requires \n\
- len(images) == len(object_detections) \n\ - len(images) == len(object_detections) \n\
- len(object_detections) == len(scales) \n\ - len(object_detections) == len(scales) \n\
...@@ -318,8 +316,7 @@ ensures \n\ ...@@ -318,8 +316,7 @@ ensures \n\
- object_detections should be a list of lists of dlib.full_object_detection objects. \ - object_detections should be a list of lists of dlib.full_object_detection objects. \
Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\ Each dlib.full_object_detection contains the bounding box and the lists of points that make up the object parts.\n\
ensures \n\ ensures \n\
- Loads a shape_predictor from the file predictor_filename. This means \n\ - shape_predictor should be a file produced by the train_shape_predictor() \n\
predictor_filename should be a file produced by the train_shape_predictor() \n\
routine. \n\ routine. \n\
- This function tests the predictor against the dataset and returns the \n\ - This function tests the predictor against the dataset and returns the \n\
mean average error of the detector. In fact, The \n\ mean average error of the detector. In fact, The \n\
......
...@@ -65,11 +65,9 @@ namespace dlib ...@@ -65,11 +65,9 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename image_array> template <typename image_array>
inline void train_shape_predictor_on_images ( inline shape_predictor train_shape_predictor_on_images (
const std::string& dataset_filename, // can be "" if it's not applicable
image_array& images, image_array& images,
std::vector<std::vector<full_object_detection> >& detections, std::vector<std::vector<full_object_detection> >& detections,
const std::string& predictor_output_filename,
const shape_predictor_training_options& options const shape_predictor_training_options& options
) )
{ {
...@@ -116,13 +114,7 @@ namespace dlib ...@@ -116,13 +114,7 @@ namespace dlib
shape_predictor predictor = trainer.train(images, detections); shape_predictor predictor = trainer.train(images, detections);
std::ofstream fout(predictor_output_filename.c_str(), std::ios::binary); return predictor;
int version = 1;
serialize(predictor, fout);
serialize(version, fout);
if (options.be_verbose)
std::cout << "Training complete, saved predictor to file " << predictor_output_filename << std::endl;
} }
inline void train_shape_predictor ( inline void train_shape_predictor (
...@@ -135,7 +127,15 @@ namespace dlib ...@@ -135,7 +127,15 @@ namespace dlib
std::vector<std::vector<full_object_detection> > objects; std::vector<std::vector<full_object_detection> > objects;
load_image_dataset(images, objects, dataset_filename); load_image_dataset(images, objects, dataset_filename);
train_shape_predictor_on_images(dataset_filename, images, objects, predictor_output_filename, options); shape_predictor predictor = train_shape_predictor_on_images(images, objects, options);
std::ofstream fout(predictor_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(predictor, fout);
serialize(version, fout);
if (options.be_verbose)
std::cout << "Training complete, saved predictor to file " << predictor_output_filename << std::endl;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -145,7 +145,7 @@ namespace dlib ...@@ -145,7 +145,7 @@ namespace dlib
image_array& images, image_array& images,
std::vector<std::vector<full_object_detection> >& detections, std::vector<std::vector<full_object_detection> >& detections,
std::vector<std::vector<double> >& scales, std::vector<std::vector<double> >& scales,
const std::string& predictor_filename const shape_predictor& predictor
) )
{ {
if (images.size() != detections.size()) if (images.size() != detections.size())
...@@ -153,16 +153,6 @@ namespace dlib ...@@ -153,16 +153,6 @@ namespace dlib
if (scales.size() > 0 && scales.size() != images.size()) if (scales.size() > 0 && scales.size() != images.size())
throw error("The list of scales must have the same length as the list of detections."); throw error("The list of scales must have the same length as the list of detections.");
shape_predictor predictor;
int version = 0;
std::ifstream fin(predictor_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + predictor_filename);
deserialize(predictor, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown shape_predictor format.");
if (scales.size() > 0) if (scales.size() > 0)
return test_shape_predictor(predictor, images, detections, scales); return test_shape_predictor(predictor, images, detections, scales);
else else
...@@ -174,13 +164,25 @@ namespace dlib ...@@ -174,13 +164,25 @@ namespace dlib
const std::string& predictor_filename const std::string& predictor_filename
) )
{ {
// Load the images, no scales can be provided
dlib::array<array2d<rgb_pixel> > images; dlib::array<array2d<rgb_pixel> > images;
// This interface cannot take the scales parameter. // This interface cannot take the scales parameter.
std::vector<std::vector<double> > scales; std::vector<std::vector<double> > scales;
std::vector<std::vector<full_object_detection> > objects; std::vector<std::vector<full_object_detection> > objects;
load_image_dataset(images, objects, dataset_filename); load_image_dataset(images, objects, dataset_filename);
return test_shape_predictor_with_images(images, objects, scales, predictor_filename); // Load the shape predictor
shape_predictor predictor;
int version = 0;
std::ifstream fin(predictor_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + predictor_filename);
deserialize(predictor, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown shape_predictor format.");
return test_shape_predictor_with_images(images, objects, scales, predictor);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -127,12 +127,11 @@ namespace dlib ...@@ -127,12 +127,11 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename image_array> template <typename image_array>
inline void train_simple_object_detector_on_images ( inline simple_object_detector train_simple_object_detector_on_images (
const std::string& dataset_filename, // can be "" if it's not applicable const std::string& dataset_filename, // can be "" if it's not applicable
image_array& images, image_array& images,
std::vector<std::vector<rectangle> >& boxes, std::vector<std::vector<rectangle> >& boxes,
std::vector<std::vector<rectangle> >& ignore, std::vector<std::vector<rectangle> >& ignore,
const std::string& detector_output_filename,
const simple_object_detector_training_options& options const simple_object_detector_training_options& options
) )
{ {
...@@ -169,7 +168,6 @@ namespace dlib ...@@ -169,7 +168,6 @@ namespace dlib
trainer.be_verbose(); trainer.be_verbose();
} }
unsigned long upsample_amount = 0; unsigned long upsample_amount = 0;
// now make sure all the boxes are obtainable by the scanner. We will try and // now make sure all the boxes are obtainable by the scanner. We will try and
...@@ -194,14 +192,9 @@ namespace dlib ...@@ -194,14 +192,9 @@ namespace dlib
simple_object_detector detector = trainer.train(images, boxes, ignore); simple_object_detector detector = trainer.train(images, boxes, ignore);
std::ofstream fout(detector_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(detector, fout);
serialize(version, fout);
if (options.be_verbose) if (options.be_verbose)
{ {
std::cout << "Training complete, saved detector to file " << detector_output_filename << std::endl; std::cout << "Training complete." << std::endl;
std::cout << "Trained with C: " << options.C << std::endl; std::cout << "Trained with C: " << options.C << std::endl;
std::cout << "Training with epsilon: " << options.epsilon << std::endl; std::cout << "Training with epsilon: " << options.epsilon << std::endl;
std::cout << "Trained using " << options.num_threads << " threads."<< std::endl; std::cout << "Trained using " << options.num_threads << " threads."<< std::endl;
...@@ -216,6 +209,8 @@ namespace dlib ...@@ -216,6 +209,8 @@ namespace dlib
if (options.add_left_right_image_flips) if (options.add_left_right_image_flips)
std::cout << "Trained on both left and right flipped versions of images." << std::endl; std::cout << "Trained on both left and right flipped versions of images." << std::endl;
} }
return detector;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -230,7 +225,15 @@ namespace dlib ...@@ -230,7 +225,15 @@ namespace dlib
std::vector<std::vector<rectangle> > boxes, ignore; std::vector<std::vector<rectangle> > boxes, ignore;
ignore = load_image_dataset(images, boxes, dataset_filename); ignore = load_image_dataset(images, boxes, dataset_filename);
train_simple_object_detector_on_images(dataset_filename, images, boxes, ignore, detector_output_filename, options); simple_object_detector detector = train_simple_object_detector_on_images(dataset_filename, images, boxes, ignore, options);
std::ofstream fout(detector_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(detector, fout);
serialize(version, fout);
if (options.be_verbose)
std::cout << "Saved detector to file " << detector_output_filename << std::endl;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -245,21 +248,14 @@ namespace dlib ...@@ -245,21 +248,14 @@ namespace dlib
template <typename image_array> template <typename image_array>
inline const simple_test_results test_simple_object_detector_with_images ( inline const simple_test_results test_simple_object_detector_with_images (
image_array& images, image_array& images,
const unsigned int upsample_amount,
std::vector<std::vector<rectangle> >& boxes, std::vector<std::vector<rectangle> >& boxes,
std::vector<std::vector<rectangle> >& ignore, std::vector<std::vector<rectangle> >& ignore,
const std::string& detector_filename simple_object_detector& detector
) )
{ {
simple_object_detector detector; for (unsigned int i = 0; i < upsample_amount; ++i)
int version = 0; upsample_image_dataset<pyramid_down<2> >(images, boxes);
unsigned int upsample_amount = 0;
std::ifstream fin(detector_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + detector_filename);
deserialize(detector, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown simple_object_detector format.");
matrix<double,1,3> res = test_object_detection_function(detector, images, boxes, ignore); matrix<double,1,3> res = test_object_detection_function(detector, images, boxes, ignore);
simple_test_results ret; simple_test_results ret;
...@@ -271,14 +267,27 @@ namespace dlib ...@@ -271,14 +267,27 @@ namespace dlib
inline const simple_test_results test_simple_object_detector ( inline const simple_test_results test_simple_object_detector (
const std::string& dataset_filename, const std::string& dataset_filename,
const std::string& detector_filename const std::string& detector_filename,
const unsigned int upsample_amount
) )
{ {
// Load all the testing images
dlib::array<array2d<rgb_pixel> > images; dlib::array<array2d<rgb_pixel> > images;
std::vector<std::vector<rectangle> > boxes, ignore; std::vector<std::vector<rectangle> > boxes, ignore;
ignore = load_image_dataset(images, boxes, dataset_filename); ignore = load_image_dataset(images, boxes, dataset_filename);
return test_simple_object_detector_with_images(images, boxes, ignore, detector_filename); // Load the detector off disk
simple_object_detector detector;
int version = 0;
std::ifstream fin(detector_filename.c_str(), std::ios::binary);
if (!fin)
throw error("Unable to open file " + detector_filename);
deserialize(detector, fin);
deserialize(version, fin);
if (version != 1)
throw error("Unknown simple_object_detector format.");
return test_simple_object_detector_with_images(images, upsample_amount, boxes, ignore, detector);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment