Commit 3711b278 authored by Davis King's avatar Davis King

Added unit tests for scan_image_boxes

parent 18c305fc
......@@ -248,6 +248,20 @@ namespace
temp.push_back(centered_rect(point(123,121), 70,70));
fill_rect(images[2],temp.back(),255); // Paint the square white
object_locations.push_back(temp);
// corrupt each image with random noise just to make this a little more
// challenging
dlib::rand rnd;
for (unsigned long i = 0; i < images.size(); ++i)
{
for (long r = 0; r < images[i].nr(); ++r)
{
for (long c = 0; c < images[i].nc(); ++c)
{
images[i][r][c] = put_in_range(0,255,images[i][r][c] + 10*rnd.get_random_gaussian());
}
}
}
}
template <
......@@ -389,6 +403,46 @@ namespace
}
}
// ----------------------------------------------------------------------------------------
void test_1_boxes (
)
{
print_spinner();
dlog << LINFO << "test_1_boxes()";
typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
grayscale_image_array_type images;
std::vector<std::vector<rectangle> > object_locations;
make_simple_test_data(images, object_locations);
typedef hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type;
typedef scan_image_boxes<feature_extractor_type> image_scanner_type;
image_scanner_type scanner;
setup_hashed_features(scanner, images, 9);
structural_object_detection_trainer<image_scanner_type> trainer(scanner);
trainer.set_num_threads(4);
trainer.set_overlap_tester(test_box_overlap(0,0));
object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 3);
{
ostringstream sout;
serialize(detector, sout);
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 3);
validate_some_object_detector_stuff(images, detector);
}
}
// ----------------------------------------------------------------------------------------
void test_1m (
......@@ -615,6 +669,51 @@ namespace
}
}
// ----------------------------------------------------------------------------------------
void test_1_poly_nn_boxes (
)
{
print_spinner();
dlog << LINFO << "test_1_poly_nn_boxes()";
typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
grayscale_image_array_type images;
std::vector<std::vector<rectangle> > object_locations;
make_simple_test_data(images, object_locations);
typedef nearest_neighbor_feature_image<poly_image<5> > feature_extractor_type;
typedef scan_image_boxes<feature_extractor_type> image_scanner_type;
image_scanner_type scanner;
feature_extractor_type nnfe;
pyramid_down pyr_down;
poly_image<5> polyi;
nnfe.set_basis(randomly_sample_image_features(images, pyr_down, polyi, 80));
scanner.copy_configuration(nnfe);
structural_object_detection_trainer<image_scanner_type> trainer(scanner);
trainer.set_num_threads(4);
object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 3);
{
ostringstream sout;
serialize(detector, sout);
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 3);
validate_some_object_detector_stuff(images, detector);
}
}
// ----------------------------------------------------------------------------------------
void test_2 (
......@@ -749,6 +848,74 @@ namespace
}
}
// ----------------------------------------------------------------------------------------
class funny_box_generator
{
public:
template <typename image_type>
void operator() (
const image_type& img,
std::vector<rectangle>& rects
) const
{
rects.clear();
find_candidate_object_locations(img.img, rects);
dlog << LINFO << "funny_box_generator, rects.size(): "<< rects.size();
}
};
inline void serialize(const funny_box_generator&, std::ostream& ) {}
inline void deserialize(funny_box_generator&, std::istream& ) {}
// make sure everything works even when the image isn't a dlib::array2d.
// So test with funny_image.
void test_3_boxes (
)
{
print_spinner();
dlog << LINFO << "test_3_boxes()";
typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
typedef dlib::array<funny_image> funny_image_array_type;
grayscale_image_array_type images_temp;
funny_image_array_type images;
std::vector<std::vector<rectangle> > object_locations;
make_simple_test_data(images_temp, object_locations);
images.resize(images_temp.size());
for (unsigned long i = 0; i < images_temp.size(); ++i)
{
images[i].img.swap(images_temp[i]);
}
typedef scan_image_boxes<very_simple_feature_extractor, funny_box_generator> image_scanner_type;
image_scanner_type scanner;
structural_object_detection_trainer<image_scanner_type> trainer(scanner);
trainer.set_num_threads(4);
object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
matrix<double> res = test_object_detection_function(detector, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 3);
res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3);
dlog << LINFO << "3-fold cross validation (precision,recall): " << res;
DLIB_TEST(sum(res) == 3);
{
ostringstream sout;
serialize(detector, sout);
istringstream sin(sout.str());
object_detector<image_scanner_type> d2;
deserialize(d2, sin);
matrix<double> res = test_object_detection_function(d2, images, object_locations);
dlog << LINFO << "Test detector (precision,recall): " << res;
DLIB_TEST(sum(res) == 3);
}
}
// ----------------------------------------------------------------------------------------
class object_detector_tester : public tester
......@@ -763,6 +930,10 @@ namespace
void perform_test (
)
{
test_1_boxes();
test_1_poly_nn_boxes();
test_3_boxes();
test_1();
test_1m();
test_1_fine_hog();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment