Commit 517858ad authored by Davis King

changed to run on image net

parent fe70bd12
...@@ -41,7 +41,9 @@ void randomly_crop_image ( ...@@ -41,7 +41,9 @@ void randomly_crop_image (
) )
{ {
// figure out what rectangle we want to crop from the image // figure out what rectangle we want to crop from the image
auto scale = 1-rnd.get_random_double()*0.2; //auto scale = 1-rnd.get_random_double()*0.2;
double mins = 0.466666666, maxs = 0.875;
auto scale = mins + rnd.get_random_double()*(maxs-mins);
auto size = scale*std::min(img.nr(), img.nc()); auto size = scale*std::min(img.nr(), img.nc());
rectangle rect(size, size); rectangle rect(size, size);
// randomly shift the box around // randomly shift the box around
...@@ -49,8 +51,8 @@ void randomly_crop_image ( ...@@ -49,8 +51,8 @@ void randomly_crop_image (
rnd.get_random_32bit_number()%(img.nr()-rect.height())); rnd.get_random_32bit_number()%(img.nr()-rect.height()));
rect = move_rect(rect, offset); rect = move_rect(rect, offset);
// now crop it out as a 250x250 image. // now crop it out as a 224x224 image.
extract_image_chip(img, chip_details(rect, chip_dims(250,250)), crop); extract_image_chip(img, chip_details(rect, chip_dims(224,224)), crop);
// Also randomly flip the image // Also randomly flip the image
if (rnd.get_random_double() > 0.5) if (rnd.get_random_double() > 0.5)
...@@ -71,7 +73,9 @@ void randomly_crop_images ( ...@@ -71,7 +73,9 @@ void randomly_crop_images (
for (long i = 0; i < num_crops; ++i) for (long i = 0; i < num_crops; ++i)
{ {
// figure out what rectangle we want to crop from the image // figure out what rectangle we want to crop from the image
auto scale = 1-rnd.get_random_double()*0.2; //auto scale = 1-rnd.get_random_double()*0.2;
double mins = 0.466666666, maxs = 0.875;
auto scale = mins + rnd.get_random_double()*(maxs-mins);
auto size = scale*std::min(img.nr(), img.nc()); auto size = scale*std::min(img.nr(), img.nc());
rectangle rect(size, size); rectangle rect(size, size);
// randomly shift the box around // randomly shift the box around
...@@ -79,7 +83,7 @@ void randomly_crop_images ( ...@@ -79,7 +83,7 @@ void randomly_crop_images (
rnd.get_random_32bit_number()%(img.nr()-rect.height())); rnd.get_random_32bit_number()%(img.nr()-rect.height()));
rect = move_rect(rect, offset); rect = move_rect(rect, offset);
dets.push_back(chip_details(rect, chip_dims(250,250))); dets.push_back(chip_details(rect, chip_dims(224,224)));
} }
extract_image_chips(img, dets, crops); extract_image_chips(img, dets, crops);
...@@ -104,7 +108,7 @@ struct image_info ...@@ -104,7 +108,7 @@ struct image_info
unsigned long numeric_label; unsigned long numeric_label;
}; };
std::vector<image_info> get_mit67_listing( std::vector<image_info> get_imagenet_listing(
const std::string& images_folder const std::string& images_folder
) )
{ {
...@@ -147,9 +151,10 @@ int main(int argc, char** argv) try ...@@ -147,9 +151,10 @@ int main(int argc, char** argv) try
return 1; return 1;
} }
auto listing = get_mit67_listing(argv[1]); auto listing = get_imagenet_listing(argv[1]);
cout << "images in dataset: " << listing.size() << endl; cout << "images in dataset: " << listing.size() << endl;
if (listing.size() == 0 || listing.back().numeric_label != 66) const auto number_of_classes = listing.back().numeric_label+1;
if (listing.size() == 0 || number_of_classes != 1000)
{ {
cout << "Didn't find the MIT 67 scene dataset. Are you sure you gave the correct folder?" << endl; cout << "Didn't find the MIT 67 scene dataset. Are you sure you gave the correct folder?" << endl;
cout << "Give the Images folder as an argument to this program." << endl; cout << "Give the Images folder as an argument to this program." << endl;
...@@ -161,21 +166,21 @@ int main(int argc, char** argv) try ...@@ -161,21 +166,21 @@ int main(int argc, char** argv) try
const double weight_decay = sa = argv[2]; const double weight_decay = sa = argv[2];
typedef loss_multiclass_log<fc<avg_pool< typedef loss_multiclass_log<fc<avg_pool<
res<res< res<res<res<
res<res< res<res<res<res<res<res<
res<res< res<res<res<res<
res<res< res<res<res<
max_pool<relu<bn<con< max_pool<relu<bn<con<
input<matrix<rgb_pixel> input<matrix<rgb_pixel>
>>>>>>>>>>>>>>>> net_type; >>>>>>>>>>>>>>>>>>>>>>>> net_type;
net_type net(fc_(67), net_type net(fc_(number_of_classes),
avg_pool_(1000,1000,1000,1000), avg_pool_(1000,1000,1000,1000),
res_(512),res_(512,2), res_(512),res_(512),res_(512,2),
res_(256),res_(256,2), res_(256),res_(256),res_(256),res_(256),res_(256),res_(256,2),
res_(128),res_(128,2), res_(128),res_(128),res_(128),res_(128,2),
res_(64), res_(64), res_(64), res_(64), res_(64),
max_pool_(3,3,2,2), relu_(), bn_(CONV_MODE), con_(64,7,7,2,2) max_pool_(3,3,2,2), relu_(), bn_(CONV_MODE), con_(64,7,7,2,2)
); );
...@@ -185,12 +190,13 @@ int main(int argc, char** argv) try ...@@ -185,12 +190,13 @@ int main(int argc, char** argv) try
dnn_trainer<net_type> trainer(net,sgd(initial_step_size, weight_decay)); dnn_trainer<net_type> trainer(net,sgd(initial_step_size, weight_decay));
trainer.be_verbose(); trainer.be_verbose();
trainer.set_synchronization_file("mit67_sync3_"+cast_to_string(weight_decay), std::chrono::minutes(5)); trainer.set_synchronization_file("sync_imagenet_full_training_set_40000_minstep_"+cast_to_string(weight_decay), std::chrono::minutes(5));
trainer.set_iterations_between_step_size_adjust(40000);
std::vector<matrix<rgb_pixel>> samples; std::vector<matrix<rgb_pixel>> samples;
std::vector<unsigned long> labels; std::vector<unsigned long> labels;
randomize_samples(listing); randomize_samples(listing);
const size_t training_part = listing.size()*0.7; const size_t training_part = listing.size()*1.0;
dlib::rand rnd; dlib::rand rnd;
...@@ -198,14 +204,14 @@ int main(int argc, char** argv) try ...@@ -198,14 +204,14 @@ int main(int argc, char** argv) try
const bool do_training = true; const bool do_training = true;
if (do_training) if (do_training)
{ {
while(trainer.get_step_size() >= 1e-4) while(trainer.get_step_size() >= 1e-3)
{ {
samples.clear(); samples.clear();
labels.clear(); labels.clear();
// make a 64 image mini-batch // make a 128 image mini-batch
matrix<rgb_pixel> img, crop; matrix<rgb_pixel> img, crop;
while(samples.size() < 64) while(samples.size() < 128)
{ {
auto l = listing[rnd.get_random_32bit_number()%training_part]; auto l = listing[rnd.get_random_32bit_number()%training_part];
load_image(img, l.filename); load_image(img, l.filename);
...@@ -222,25 +228,25 @@ int main(int argc, char** argv) try ...@@ -222,25 +228,25 @@ int main(int argc, char** argv) try
net.clean(); net.clean();
cout << "saving network" << endl; cout << "saving network" << endl;
serialize("mit67_network3_"+cast_to_string(weight_decay)+".dat") << net; serialize("imagenet_full_training_set_40000_minstep_"+cast_to_string(weight_decay)+".dat") << net;
} }
const bool test_network = true; const bool test_network = false;
if (test_network) if (test_network)
{ {
typedef loss_multiclass_log<fc<avg_pool< typedef loss_multiclass_log<fc<avg_pool<
ares<ares< ares<ares<ares<
ares<ares< ares<ares<ares<ares<ares<ares<
ares<ares< ares<ares<ares<ares<
ares<ares< ares<ares<ares<
max_pool<relu<affine<con< max_pool<relu<affine<con<
input<matrix<rgb_pixel> input<matrix<rgb_pixel>
>>>>>>>>>>>>>>>> anet_type; >>>>>>>>>>>>>>>>>>>>>>>> anet_type;
anet_type net; anet_type net;
deserialize("mit67_network3_"+cast_to_string(weight_decay)+".dat") >> net; deserialize("imagenet_network3_"+cast_to_string(weight_decay)+".dat") >> net;
dlib::array<matrix<rgb_pixel>> images; dlib::array<matrix<rgb_pixel>> images;
std::vector<unsigned long> labels; std::vector<unsigned long> labels;
...@@ -249,6 +255,7 @@ int main(int argc, char** argv) try ...@@ -249,6 +255,7 @@ int main(int argc, char** argv) try
int num_right = 0; int num_right = 0;
int num_wrong = 0; int num_wrong = 0;
console_progress_indicator pbar(training_part); console_progress_indicator pbar(training_part);
/*
for (size_t i = 0; i < training_part; ++i) for (size_t i = 0; i < training_part; ++i)
{ {
pbar.print_status(i); pbar.print_status(i);
...@@ -261,6 +268,7 @@ int main(int argc, char** argv) try ...@@ -261,6 +268,7 @@ int main(int argc, char** argv) try
else else
++num_wrong; ++num_wrong;
} }
*/
cout << "\ntraining num_right: " << num_right << endl; cout << "\ntraining num_right: " << num_right << endl;
cout << "training num_wrong: " << num_wrong << endl; cout << "training num_wrong: " << num_wrong << endl;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment