// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*

    This is an example showing how you might use dlib to create a reasonably 
    functional command line tool for object detection.  This example assumes 
    you are familiar with the contents of at least the following example 
    programs:
        - fhog_object_detector_ex.cpp
        - compress_stream_ex.cpp

    This program is a command line tool for learning to detect objects in images,
    so training a detector requires a set of annotated training images.  To create
    this annotated data you will need to use the imglab tool
    included with dlib.  It is located in the tools/imglab folder and can be compiled
    using the following commands.  
        cd tools/imglab
        mkdir build
        cd build
        cmake ..
        cmake --build . --config Release
    Note that you may need to install CMake (www.cmake.org) for this to work.  

    Next, let's assume you have a folder of images called /tmp/images.  These images 
    should contain examples of the objects you want to learn to detect.  You will 
    use the imglab tool to label these objects.  Do this by typing the following:
        ./imglab -c mydataset.xml /tmp/images
    This will create a file called mydataset.xml, which simply lists the images in
    /tmp/images.  To annotate them, run
        ./imglab mydataset.xml
    A window will appear showing all the images.  You can use the up and down arrow 
    keys to cycle through the images and the mouse to label objects.  In particular,
    holding the shift key, left clicking, and dragging the mouse will allow you to 
    draw boxes around the objects you wish to detect.  Next, label all the objects
    with boxes.  It is important to label every object, since any object left
    unlabeled is implicitly treated as something the detector should not detect.  If
    there are objects you are not sure about, you should draw a box around them, then double
    click the box and press i.  This will cross out the box and mark it as "ignore".
    The training code in dlib will then simply ignore detections matching that box.
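
    To illustrate what "matching" an ignore box means, the sketch below filters
    detections using plain C++.  It assumes a detection matches an ignore box when
    their intersection-over-union exceeds 0.5; that threshold, and the names Box,
    iou and drop_ignored, are hypothetical choices for this example.  dlib's own
    overlap test is configurable and may use different criteria.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical axis-aligned box with inclusive pixel coordinates,
// in the spirit of dlib::rectangle but NOT that class.
struct Box { long left, top, right, bottom; };

// Number of pixels covered by b (0 for an empty box).
long area(const Box& b)
{
    if (b.right < b.left || b.bottom < b.top) return 0;
    return (b.right - b.left + 1) * (b.bottom - b.top + 1);
}

// Intersection-over-union of two boxes, in [0, 1].
double iou(const Box& a, const Box& b)
{
    const Box inter{std::max(a.left, b.left), std::max(a.top, b.top),
                    std::min(a.right, b.right), std::min(a.bottom, b.bottom)};
    const long i = area(inter);
    const long u = area(a) + area(b) - i;
    return u == 0 ? 0.0 : static_cast<double>(i) / u;
}

// Keep only the detections that do not overlap any ignore box by more
// than iou_thresh (assumed 0.5 here).
std::vector<Box> drop_ignored(const std::vector<Box>& dets,
                              const std::vector<Box>& ignore,
                              double iou_thresh = 0.5)
{
    std::vector<Box> kept;
    for (const Box& d : dets)
    {
        bool matches_ignore = false;
        for (const Box& g : ignore)
            if (iou(d, g) > iou_thresh) { matches_ignore = true; break; }
        if (!matches_ignore) kept.push_back(d);
    }
    return kept;
}
```

    In this sketch a dropped detection counts neither as a correct detection nor
    as a false alarm, which is why marking uncertain objects as "ignore" is safer
    than leaving them unlabeled.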
    

    Once you finish labeling objects, go to the file menu, click save, and then close
    the program. This will save the object boxes back to mydataset.xml.  You can verify 
    this by opening the tool again with
        ./imglab mydataset.xml
    and observing that the boxes are present.

    Returning to the present example program, we can compile it using cmake just as we 
    did with the imglab tool.  Once compiled, we can issue the command 
        ./train_object_detector -tv mydataset.xml
    which will train an object detection model based on our labeled data.  The model 
    will be saved to the file object_detector.svm.  Once this has finished we can use 
    the object detector to locate objects in new images with a command like
        ./train_object_detector some_image.png
    This command will display some_image.png in a window and any detected objects will
    be indicated by a red box.

    Finally, to make running this example easy, dlib includes some training data in the
    examples/faces folder.  Therefore, you can test this program out by running the
    following sequence of commands:
      ./train_object_detector -tv examples/faces/training.xml -u1 --flip
      ./train_object_detector --test examples/faces/testing.xml -u1
      ./train_object_detector examples/faces/*.jpg -u1
    That will make a face detector that performs perfectly on the test images listed in
    testing.xml and then it will show you the detections on all the images.
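
    The --flip option used above adds a mirrored copy of every training image.
    Mirroring an image that is W pixels wide also mirrors each box: a box that
    spans columns [left, right] afterwards spans [W-1-right, W-1-left], while its
    rows are unchanged.  The plain C++ sketch below shows only that coordinate
    arithmetic; Box and mirror_box are hypothetical names, and this is not dlib's
    add_image_left_right_flips implementation.

```cpp
#include <cassert>

// Hypothetical box with inclusive pixel coordinates (not dlib::rectangle).
struct Box { long left, top, right, bottom; };

// Mirror a box horizontally inside an image that is image_width pixels wide.
// Rows (top/bottom) are unchanged; columns are reflected about the center.
Box mirror_box(const Box& b, long image_width)
{
    return Box{image_width - 1 - b.right, b.top,
               image_width - 1 - b.left,  b.bottom};
}
```

    Mirroring twice gives back the original box, and the box width is preserved.
    Because the mirrored copies are treated as extra examples of the same object,
    --flip only makes sense when the objects are left/right symmetric.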
*/


#include <dlib/svm_threaded.h>
#include <dlib/string.h>
#include <dlib/gui_widgets.h>
#include <dlib/image_processing.h>
#include <dlib/data_io.h>
#include <dlib/cmd_line_parser.h>


#include <cstdlib>   // EXIT_SUCCESS, EXIT_FAILURE
#include <iostream>
#include <fstream>
#include <sstream>   // std::ostringstream


using namespace std;
using namespace dlib;

// ----------------------------------------------------------------------------------------

void pick_best_window_size (
    const std::vector<std::vector<rectangle> >& boxes,
    unsigned long& width,
    unsigned long& height,
    const unsigned long target_size
)
/*!
    ensures
        - Computes the average width and average height of all the rectangles in
          boxes and outputs a width and height with the same aspect ratio as that
          average box, scaled so that the area is approximately target_size.  That
          is, the following will be approximately true:
            - #width*#height == target_size
            - #width/#height == (average box width)/(average box height)
!*/
{
    // find the average width and height
    running_stats<double> avg_width, avg_height;
    for (unsigned long i = 0; i < boxes.size(); ++i)
    {
        for (unsigned long j = 0; j < boxes[i].size(); ++j)
        {
            avg_width.add(boxes[i][j].width());
            avg_height.add(boxes[i][j].height());
        }
    }

    // now adjust the box size so that it is about target_size pixels in area
    double size = avg_width.mean()*avg_height.mean();
    double scale = std::sqrt(target_size/size);

    width = (unsigned long)(avg_width.mean()*scale+0.5);
    height = (unsigned long)(avg_height.mean()*scale+0.5);
    // make sure the width and height never round to zero.
    if (width == 0)
        width = 1;
    if (height == 0)
        height = 1;
}
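
/*
    To make the arithmetic in pick_best_window_size concrete, here is a dlib-free
    sketch of the same computation (the function name window_size_for is
    hypothetical).  With an average box of 100x50 pixels and the default target
    size of 80*80 = 6400, the scale factor is sqrt(6400/5000), roughly 1.131,
    giving a 113x57 window whose aspect ratio stays close to 2 and whose area is
    close to 6400.

```cpp
#include <cassert>
#include <cmath>
#include <utility>

// Given the mean box width and height, return a (width, height) pair with
// the same aspect ratio and an area of roughly target_size pixels.
std::pair<unsigned long, unsigned long> window_size_for(
    double mean_width, double mean_height, unsigned long target_size)
{
    const double scale = std::sqrt(target_size / (mean_width * mean_height));
    unsigned long w = static_cast<unsigned long>(mean_width * scale + 0.5);
    unsigned long h = static_cast<unsigned long>(mean_height * scale + 0.5);
    // Never let rounding produce a zero-sized window.
    if (w == 0) w = 1;
    if (h == 0) h = 1;
    return {w, h};
}
```

    The final clamp matters for very elongated boxes: a 1000x1 average box with a
    target size of 16 would otherwise round its height down to 0.
*/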

// ----------------------------------------------------------------------------------------

bool contains_any_boxes (
    const std::vector<std::vector<rectangle> >& boxes
)
{
    for (unsigned long i = 0; i < boxes.size(); ++i)
    {
        if (boxes[i].size() != 0)
            return true;
    }
    return false;
}

// ----------------------------------------------------------------------------------------

void throw_invalid_box_error_message (
    const std::string& dataset_filename,
    const std::vector<std::vector<rectangle> >& removed,
    const unsigned long target_size
)
{
    image_dataset_metadata::dataset data;
    load_image_dataset_metadata(data, dataset_filename);

    std::ostringstream sout;
    sout << "Error!  An impossible set of object boxes was given for training. ";
    sout << "All the boxes need to have a similar aspect ratio and also not be ";
    sout << "smaller than about " << target_size << " pixels in area. ";
    sout << "The following images contain invalid boxes:\n";
    std::ostringstream sout2;
    for (unsigned long i = 0; i < removed.size(); ++i)
    {
        if (removed[i].size() != 0)
        {
            const std::string imgname = data.images[i].filename;
            sout2 << "  " << imgname << "\n";
        }
    }
    throw error("\n"+wrap_string(sout.str()) + "\n" + sout2.str());
}

// ----------------------------------------------------------------------------------------

int main(int argc, char** argv)
{  
    try
    {
        command_line_parser parser;
        parser.add_option("h","Display this help message.");
        parser.add_option("t","Train an object detector and save the detector to disk.");
        parser.add_option("cross-validate",
                          "Perform cross-validation on an image dataset and print the results.");
        parser.add_option("test", "Test a trained detector on an image dataset and print the results.");
        parser.add_option("u", "Upsample each input image <arg> times. Each upsampling quadruples the number of pixels in the image (default: 0).", 1);

        parser.set_group_name("training/cross-validation sub-options");
        parser.add_option("v","Be verbose.");
        parser.add_option("folds","When doing cross-validation, do <arg> folds (default: 3).",1);
        parser.add_option("c","Set the SVM C parameter to <arg> (default: 1.0).",1);
        parser.add_option("threads", "Use <arg> threads for training (default: 4).",1);
        parser.add_option("eps", "Set training epsilon to <arg> (default: 0.01).", 1);
        parser.add_option("target-size", "Set size of the sliding window to about <arg> pixels in area (default: 80*80).", 1);
        parser.add_option("flip", "Add left/right flipped copies of the images into the training dataset.  Useful when the objects "
            "you want to detect are left/right symmetric.");


        parser.parse(argc, argv);

        // Now we do a little command line validation.  Each of the following functions
        // checks something and throws an exception if the test fails.
        const char* one_time_opts[] = {"h", "v", "t", "cross-validate", "c", "threads", "target-size",
                                        "folds", "test", "eps", "u", "flip"};
        parser.check_one_time_options(one_time_opts); // Can't give an option more than once
        // Make sure the arguments to these options are within valid ranges if they are supplied by the user.
        parser.check_option_arg_range("c", 1e-12, 1e12);
        parser.check_option_arg_range("eps", 1e-5, 1e4);
        parser.check_option_arg_range("threads", 1, 1000);
        parser.check_option_arg_range("folds", 2, 100);
        parser.check_option_arg_range("u", 0, 8);
        parser.check_option_arg_range("target-size", 4*4, 10000*10000);
        const char* incompatible[] = {"t", "cross-validate", "test"};
        parser.check_incompatible_options(incompatible);
        // You are only allowed to give these training_sub_ops if you also give either -t or --cross-validate.
        const char* training_ops[] = {"t", "cross-validate"};
        const char* training_sub_ops[] = {"v", "c", "threads", "target-size", "eps", "flip"};
        parser.check_sub_options(training_ops, training_sub_ops); 
        parser.check_sub_option("cross-validate", "folds"); 


        if (parser.option("h"))
        {
            cout << "Usage: train_object_detector [options] <image dataset file|image file>\n";
            parser.print_options(); 
                                       
            return EXIT_SUCCESS;
        }


        typedef scan_fhog_pyramid<pyramid_down<6> > image_scanner_type; 
        // Get the upsample option from the user but use 0 if it wasn't given.
        const unsigned long upsample_amount = get_option(parser, "u", 0);

        if (parser.option("t") || parser.option("cross-validate"))
        {
            if (parser.number_of_arguments() != 1)
            {
                cout << "You must give an image dataset metadata XML file produced by the imglab tool." << endl;
                cout << "\nTry the -h option for more information." << endl;
                return EXIT_FAILURE;
            }

            dlib::array<array2d<unsigned char> > images;
            std::vector<std::vector<rectangle> > object_locations, ignore;

            cout << "Loading image dataset from metadata file " << parser[0] << endl;
            ignore = load_image_dataset(images, object_locations, parser[0]);
            cout << "Number of images loaded: " << images.size() << endl;

            // Get the options from the user, but use default values if they are not
            // supplied.
            const int threads = get_option(parser, "threads", 4);
            const double C   = get_option(parser, "c", 1.0);
            const double eps = get_option(parser, "eps", 0.01);
            unsigned int num_folds = get_option(parser, "folds", 3);
            const unsigned long target_size = get_option(parser, "target-size", 80*80);
            // You can't do more folds than there are images.  
            if (num_folds > images.size())
                num_folds = images.size();

            // Upsample images if the user asked us to do that.
            for (unsigned long i = 0; i < upsample_amount; ++i)
                upsample_image_dataset<pyramid_down<2> >(images, object_locations, ignore);


            image_scanner_type scanner;
            unsigned long width, height;
            pick_best_window_size(object_locations, width, height, target_size);
            scanner.set_detection_window_size(width, height); 

            structural_object_detection_trainer<image_scanner_type> trainer(scanner);
            trainer.set_num_threads(threads);
            if (parser.option("v"))
                trainer.be_verbose();
            trainer.set_c(C);
            trainer.set_epsilon(eps);

            // Now make sure all the boxes are obtainable by the scanner.  
            std::vector<std::vector<rectangle> > removed;
            removed = remove_unobtainable_rectangles(trainer, images, object_locations);
            // if we weren't able to get all the boxes to match then throw an error 
            if (contains_any_boxes(removed))
            {
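                // Each upsampling doubles both image dimensions, so report the
                // minimum box area in terms of the original, un-upsampled pixels.
                const unsigned long scale = 1UL << upsample_amount;
                throw_invalid_box_error_message(parser[0], removed, target_size/(scale*scale));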
            }
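
            /*
                The option help for -u says each upsampling quadruples the number of
                pixels, i.e. doubles the width and the height.  After u upsampling
                steps an area measured in upsampled pixels therefore has to be
                divided by 4^u, which is (2^u)^2, to express it in pixels of the
                original images.  The helper below is a hypothetical illustration of
                that conversion.

```cpp
#include <cassert>

// Convert an area measured in upsampled pixels back to original-image pixels,
// assuming every upsampling step doubles both the width and the height.
unsigned long area_in_original_pixels(unsigned long area, unsigned long upsample_steps)
{
    const unsigned long side_scale = 1UL << upsample_steps;  // 2^u
    return area / (side_scale * side_scale);                 // divide by 4^u
}
```

                For example, with the default target size of 6400 and -u2, boxes
                smaller than about 400 pixels in the original images cannot be
                matched.
            */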

            if (parser.option("flip"))
                add_image_left_right_flips(images, object_locations, ignore);

            if (parser.option("t"))
            {
                // Do the actual training and save the results into the detector object.  
                object_detector<image_scanner_type> detector = trainer.train(images, object_locations, ignore);

                cout << "Saving trained detector to object_detector.svm" << endl;
                serialize("object_detector.svm") << detector;

                cout << "Testing detector on training data..." << endl;
                cout << "Test detector (precision,recall,AP): " << test_object_detection_function(detector, images, object_locations, ignore) << endl;
            }
            else
            {
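                // Shuffle the order of the training images.  The object boxes and
                // the ignore boxes are shuffled with the same permutation so that
                // they stay paired with their images.
                randomize_samples(images, object_locations, ignore);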

                cout << num_folds << "-fold cross validation (precision,recall,AP): "
                     << cross_validate_object_detection_trainer(trainer, images, object_locations, ignore, num_folds) << endl;
            }
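
            /*
                images, object_locations and ignore are parallel arrays: element i of
                each describes the same image.  Any shuffle must therefore apply one
                permutation to all of them, otherwise boxes end up paired with the
                wrong images.  The standard-library sketch below shows the idea; the
                name shuffle_together and the fixed seed are assumptions made for
                illustration, whereas dlib's randomize_samples performs the
                shuffling for whichever containers are passed to it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Shuffle three parallel vectors with the same random permutation so that
// a[i], b[i] and c[i] still refer to the same sample afterwards.
template <typename A, typename B, typename C>
void shuffle_together(std::vector<A>& a, std::vector<B>& b, std::vector<C>& c,
                      unsigned int seed)
{
    assert(a.size() == b.size() && b.size() == c.size());
    std::vector<std::size_t> perm(a.size());
    std::iota(perm.begin(), perm.end(), 0);  // 0, 1, ..., n-1
    std::mt19937 rng(seed);
    std::shuffle(perm.begin(), perm.end(), rng);

    std::vector<A> a2;
    std::vector<B> b2;
    std::vector<C> c2;
    for (std::size_t idx : perm)
    {
        a2.push_back(a[idx]);
        b2.push_back(b[idx]);
        c2.push_back(c[idx]);
    }
    a.swap(a2);
    b.swap(b2);
    c.swap(c2);
}
```
            */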

            cout << "Parameters used: " << endl;
            cout << "  threads:                 "<< threads << endl;
            cout << "  C:                       "<< C << endl;
            cout << "  eps:                     "<< eps << endl;
            cout << "  target-size:             "<< target_size << endl;
            cout << "  detection window width:  "<< width << endl;
            cout << "  detection window height: "<< height << endl;
            cout << "  upsample amount:         "<< upsample_amount << endl;
            if (parser.option("flip"))
                cout << "  trained using left/right flips." << endl;
            if (parser.option("cross-validate"))
                cout << "  num_folds:               "<< num_folds << endl;
            cout << endl;

            return EXIT_SUCCESS;
        }

        // The rest of the code is devoted to testing an already trained object detector.

        if (parser.number_of_arguments() == 0)
        {
            cout << "You must give an image or an image dataset metadata XML file produced by the imglab tool." << endl;
            cout << "\nTry the -h option for more information." << endl;
            return EXIT_FAILURE;
        }

        // load a previously trained object detector and try it out on some data
        ifstream fin("object_detector.svm", ios::binary);
        if (!fin)
        {
            cout << "Can't find a trained object detector file object_detector.svm. " << endl;
            cout << "You need to train one using the -t option." << endl;
            cout << "\nTry the -h option for more information." << endl;
            return EXIT_FAILURE;

        }
        object_detector<image_scanner_type> detector;
        deserialize(detector, fin);

        dlib::array<array2d<unsigned char> > images;
        // Check if the command line argument is an XML file
        if (tolower(right_substr(parser[0],".")) == "xml")
        {
            std::vector<std::vector<rectangle> > object_locations, ignore;
            cout << "Loading image dataset from metadata file " << parser[0] << endl;
            ignore = load_image_dataset(images, object_locations, parser[0]);
            cout << "Number of images loaded: " << images.size() << endl;

            // Upsample images if the user asked us to do that.
            for (unsigned long i = 0; i < upsample_amount; ++i)
                upsample_image_dataset<pyramid_down<2> >(images, object_locations, ignore);

            if (parser.option("test"))
            {
                cout << "Testing detector on data..." << endl;
                cout << "Results (precision,recall,AP): " << test_object_detection_function(detector, images, object_locations, ignore) << endl;
                return EXIT_SUCCESS;
            }
        }
        else
        {
            // In this case, the user should have given some image files.  So just
            // load them.
            images.resize(parser.number_of_arguments());
            for (unsigned long i = 0; i < images.size(); ++i)
                load_image(images[i], parser[i]);

            // Upsample images if the user asked us to do that.
            for (unsigned long i = 0; i < upsample_amount; ++i)
            {
                for (unsigned long j = 0; j < images.size(); ++j)
                    pyramid_up(images[j]);
            }
        }
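
        /*
            The XML check above roughly amounts to lower-casing the text after the
            last '.' in the argument and comparing it with "xml".  The hypothetical
            helper below reproduces that test with only the standard library; it
            illustrates the logic and is not a description of dlib's right_substr
            and tolower functions.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>

// Return true when filename ends in ".xml", ignoring case.
bool has_xml_extension(const std::string& filename)
{
    const std::string::size_type dot = filename.rfind('.');
    if (dot == std::string::npos)
        return false;  // no extension at all
    std::string ext = filename.substr(dot + 1);
    std::transform(ext.begin(), ext.end(), ext.begin(),
                   [](unsigned char ch) { return static_cast<char>(std::tolower(ch)); });
    return ext == "xml";
}
```
        */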


        // Test the detector on the images we loaded and display the results
        // in a window.
        image_window win;
        for (unsigned long i = 0; i < images.size(); ++i)
        {
            // Run the detector on images[i] 
            const std::vector<rectangle> rects = detector(images[i]);
            cout << "Number of detections: "<< rects.size() << endl;

            // Put the image and detections into the window.
            win.clear_overlay();
            win.set_image(images[i]);
            win.add_overlay(rects, rgb_pixel(255,0,0));

            cout << "Hit enter to see the next image.";
            cin.get();
        }


    }
    catch (const exception& e)
    {
        cout << "\nexception thrown!" << endl;
        cout << e.what() << endl;
        cout << "\nTry the -h option for more information." << endl;
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}

// ----------------------------------------------------------------------------------------