// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt /* This is an example illustrating the use of the deep learning tools from the dlib C++ Library. In it, we will show how to use the loss_metric layer to do metric learning. */ #include <dlib/dnn.h> #include <iostream> using namespace std; using namespace dlib; int main() try { using net_type = loss_metric<fc<2,input<matrix<double,0,1>>>>; net_type net; dnn_trainer<net_type> trainer(net); trainer.set_learning_rate(0.1); trainer.set_min_learning_rate(0.00001); trainer.set_mini_batch_size(128); trainer.be_verbose(); trainer.set_iterations_without_progress_threshold(100); std::vector<matrix<double,0,1>> samples; std::vector<unsigned long> labels; samples.push_back({1,0,0,0,0,0,0,0}); labels.push_back(1); samples.push_back({0,1,0,0,0,0,0,0}); labels.push_back(1); samples.push_back({0,0,1,0,0,0,0,0}); labels.push_back(2); samples.push_back({0,0,0,1,0,0,0,0}); labels.push_back(2); samples.push_back({0,0,0,0,1,0,0,0}); labels.push_back(3); samples.push_back({0,0,0,0,0,1,0,0}); labels.push_back(3); samples.push_back({0,0,0,0,0,0,1,0}); labels.push_back(4); samples.push_back({0,0,0,0,0,0,0,1}); labels.push_back(4); trainer.train(samples, labels); // Run all the images through the network to get their vector embeddings. std::vector<matrix<float,0,1>> embedded = net(images); for (size_t i = 0; i < embedded.size(); ++i) cout << "label: " << labels[i] << "\t" << trans(embedded[i]); // Now, check if the embedding puts things with the same labels near each other and // things with different labels far apart. int num_right = 0; int num_wrong = 0; for (size_t i = 0; i < embedded.size(); ++i) { for (size_t j = i+1; j < embedded.size(); ++j) { if (labels[i] == labels[j]) { // The loss_metric layer will cause things with the same label to be less // than net.loss_details().get_distance_threshold() distance from each // other. So we can use that distance value as our testing threshold. if (length(embedded[i]-embedded[j]) < net.loss_details().get_distance_threshold()) ++num_right; else ++num_wrong; } else { if (length(embedded[i]-embedded[j]) >= net.loss_details().get_distance_threshold()) ++num_right; else ++num_wrong; } } } cout << "num_right: "<< num_right << endl; cout << "num_wrong: "<< num_wrong << endl; } catch(std::exception& e) { cout << e.what() << endl; }