Commit 47bdf95f authored by Davis King

added more stuff to example

parent bd79b877
@@ -60,8 +60,7 @@ int main(int argc, char** argv) try
     );
-    //dnn_trainer<net_type,adam> trainer(net,adam(0.001));
-    dnn_trainer<net_type> trainer(net,sgd(0.1));
+    dnn_trainer<net_type,adam> trainer(net,adam(0.001));
     trainer.be_verbose();
     trainer.set_synchronization_file("mnist_resnet_sync", std::chrono::seconds(100));
     std::vector<matrix<unsigned char>> mini_batch_samples;
@@ -86,11 +85,29 @@ int main(int argc, char** argv) try
     // wait for threaded processing to stop.
     trainer.get_net();
+    // You can access sub layers of the network like this:
+    net.subnet().subnet().get_output();
+    layer<avg_pool>(net).get_output();
     net.clean();
-    serialize("mnist_network.dat") << net;
+    serialize("mnist_res_network.dat") << net;
+    typedef loss_multiclass_log<fc<avg_pool<
+                                ares<ares<ares<ares<
+                                repeat<10,ares,
+                                ares<
+                                ares<
+                                input<matrix<unsigned char>
+                                >>>>>>>>>>> test_net_type;
+    test_net_type tnet = net;
+    // or you could deserialize the saved network
+    deserialize("mnist_res_network.dat") >> tnet;
     // Run the net on all the data to get predictions
-    std::vector<unsigned long> predicted_labels = net(training_images);
+    std::vector<unsigned long> predicted_labels = tnet(training_images);
     int num_right = 0;
     int num_wrong = 0;
     for (size_t i = 0; i < training_images.size(); ++i)
@@ -105,7 +122,7 @@ int main(int argc, char** argv) try
     cout << "training num_wrong: " << num_wrong << endl;
     cout << "training accuracy: " << num_right/(double)(num_right+num_wrong) << endl;
-    predicted_labels = net(testing_images);
+    predicted_labels = tnet(testing_images);
     num_right = 0;
     num_wrong = 0;
     for (size_t i = 0; i < testing_images.size(); ++i)
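For context on the evaluation code this commit touches (this note is not part of the diff): the two num_right/num_wrong loops compute plain classification accuracy over the labels returned by running the network on a vector of images. Below is a minimal sketch of that pattern factored into a reusable helper; the name classification_accuracy and the template form are assumptions for illustration, not part of the example file.

#include <dlib/dnn.h>
#include <vector>

using namespace dlib;

// Hypothetical helper, not in the example: run a trained classifier over a set
// of images and return the fraction of predictions that match the labels --
// the same num_right/(num_right+num_wrong) value the example prints.
template <typename any_net_type>
double classification_accuracy(
    any_net_type& net,
    const std::vector<matrix<unsigned char>>& images,
    const std::vector<unsigned long>& labels
)
{
    // Running the network on a vector of inputs yields one predicted label
    // per image (unsigned long for loss_multiclass_log networks).
    const std::vector<unsigned long> predicted = net(images);
    int num_right = 0;
    int num_wrong = 0;
    for (size_t i = 0; i < images.size(); ++i)
    {
        if (predicted[i] == labels[i])
            ++num_right;
        else
            ++num_wrong;
    }
    return num_right/(double)(num_right+num_wrong);
}

With such a helper, the training and testing loops in the example would reduce to calls like classification_accuracy(tnet, testing_images, testing_labels).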