Commit 8db3f4e5 authored by Patrick Snape's avatar Patrick Snape

Add a save method to detectors and predictors

Also, removed the saving of the upsample which I missed from
before (since I'm not using the struct now). I understand why
the upsample was being saved, but I don't necessarily agree it
is particularly useful as you should really be upsampling on
a case by case basis at test time.
parent 5b485a62
......@@ -108,6 +108,14 @@ std::vector<rectangle> run_detector_with_upscale (
}
}
void save_simple_object_detector(const simple_object_detector& detector, const std::string& detector_output_filename)
{
std::ofstream fout(detector_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(detector, fout);
serialize(version, fout);
}
// ----------------------------------------------------------------------------------------
inline void train_simple_object_detector_on_images_py (
......@@ -384,6 +392,7 @@ ensures \n\
default will be used.
!*/
)
.def("save", save_simple_object_detector, (arg("detector_output_filename")), "Save a simple_object_detector to the provided path.")
.def_pickle(serialize_pickle<type>());
}
{
......
......@@ -35,6 +35,14 @@ full_object_detection run_predictor (
}
}
void save_shape_predictor(const shape_predictor& predictor, const std::string& predictor_output_filename)
{
std::ofstream fout(predictor_output_filename.c_str(), std::ios::binary);
int version = 1;
serialize(predictor, fout);
serialize(version, fout);
}
// ----------------------------------------------------------------------------------------
rectangle full_obj_det_get_rect (const full_object_detection& detection)
......@@ -226,6 +234,7 @@ train_shape_predictor() routine.")
ensures \n\
- This function runs the shape predictor on the input image and returns \n\
a single full object detection.")
.def("save", save_shape_predictor, (arg("predictor_output_filename")), "Save a shape_predictor to the provided path.")
.def_pickle(serialize_pickle<type>());
}
{
......
......@@ -176,24 +176,15 @@ namespace dlib
// upsample the images at most two times to help make the boxes obtainable.
std::vector<std::vector<rectangle> > temp(boxes), removed;
removed = remove_unobtainable_rectangles(trainer, images, temp);
if (impl::contains_any_boxes(removed))
{
++upsample_amount;
if (options.be_verbose)
std::cout << "upsample images..." << std::endl;
upsample_image_dataset<pyramid_down<2> >(images, boxes, ignore);
temp = boxes;
removed = remove_unobtainable_rectangles(trainer, images, temp);
if (impl::contains_any_boxes(removed))
while (impl::contains_any_boxes(removed) && upsample_amount < 2)
{
++upsample_amount;
if (options.be_verbose)
std::cout << "upsample images..." << std::endl;
std::cout << "Upsample images..." << std::endl;
upsample_image_dataset<pyramid_down<2> >(images, boxes, ignore);
temp = boxes;
removed = remove_unobtainable_rectangles(trainer, images, temp);
}
}
// if we weren't able to get all the boxes to match then throw an error
if (impl::contains_any_boxes(removed))
impl::throw_invalid_box_error_message(dataset_filename, removed, options);
......@@ -207,7 +198,6 @@ namespace dlib
int version = 1;
serialize(detector, fout);
serialize(version, fout);
serialize(upsample_amount, fout);
if (options.be_verbose)
{
......@@ -218,10 +208,10 @@ namespace dlib
std::cout << "Trained with sliding window " << width << " pixels wide by " << height << " pixels tall." << std::endl;
if (upsample_amount != 0)
{
if (upsample_amount == 1)
std::cout << "Upsampled images " << upsample_amount << " time to allow detection of small boxes." << std::endl;
else
std::cout << "Upsampled images " << upsample_amount << " times to allow detection of small boxes." << std::endl;
// Unsampled images # time(s) to allow detection of small boxes
std::cout << "Upsampled images " << upsample_amount;
std::cout << (upsample_amount == 1) ? " time" : " times";
std::cout << " to allow detection of small boxes." << std::endl;
}
if (options.add_left_right_image_flips)
std::cout << "Trained on both left and right flipped versions of images." << std::endl;
......@@ -270,10 +260,6 @@ namespace dlib
deserialize(version, fin);
if (version != 1)
throw error("Unknown simple_object_detector format.");
deserialize(upsample_amount, fin);
for (unsigned int i = 0; i < upsample_amount; ++i)
upsample_image_dataset<pyramid_down<2> >(images, boxes);
matrix<double,1,3> res = test_object_detection_function(detector, images, boxes, ignore);
simple_test_results ret;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment