Commit be8dc926 authored by Davis King

merged

parents c763fafd f685cb42
@@ -2386,6 +2386,106 @@ namespace dlib
     using mult_prev9_ = mult_prev_<tag9>;
     using mult_prev10_ = mult_prev_<tag10>;

+// ----------------------------------------------------------------------------------------
+
+    template <
+        template<typename> class tag
+        >
+    class resize_prev_to_tagged_
+    {
+    public:
+        const static unsigned long id = tag_id<tag>::id;
+
+        resize_prev_to_tagged_()
+        {
+        }
+
+        template <typename SUBNET>
+        void setup (const SUBNET& /*sub*/)
+        {
+        }
+
+        template <typename SUBNET>
+        void forward(const SUBNET& sub, resizable_tensor& output)
+        {
+            auto& prev = sub.get_output();
+            auto& tagged = layer<tag>(sub).get_output();
+
+            DLIB_CASSERT(prev.num_samples() == tagged.num_samples());
+
+            output.set_size(prev.num_samples(),
+                            prev.k(),
+                            tagged.nr(),
+                            tagged.nc());
+
+            if (prev.nr() == tagged.nr() && prev.nc() == tagged.nc())
+            {
+                tt::copy_tensor(false, output, 0, prev, 0, prev.k());
+            }
+            else
+            {
+                tt::resize_bilinear(output, prev);
+            }
+        }
+
+        template <typename SUBNET>
+        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
+        {
+            auto& prev = sub.get_gradient_input();
+
+            DLIB_CASSERT(prev.k() == gradient_input.k());
+            DLIB_CASSERT(prev.num_samples() == gradient_input.num_samples());
+
+            if (prev.nr() == gradient_input.nr() && prev.nc() == gradient_input.nc())
+            {
+                tt::copy_tensor(true, prev, 0, gradient_input, 0, prev.k());
+            }
+            else
+            {
+                tt::resize_bilinear_gradient(prev, gradient_input);
+            }
+        }
+
+        const tensor& get_layer_params() const { return params; }
+        tensor& get_layer_params() { return params; }
+
+        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
+        inline dpoint map_output_to_input (const dpoint& p) const { return p; }
+
+        friend void serialize(const resize_prev_to_tagged_& , std::ostream& out)
+        {
+            serialize("resize_prev_to_tagged_", out);
+        }
+
+        friend void deserialize(resize_prev_to_tagged_& , std::istream& in)
+        {
+            std::string version;
+            deserialize(version, in);
+            if (version != "resize_prev_to_tagged_")
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::resize_prev_to_tagged_.");
+        }
+
+        friend std::ostream& operator<<(std::ostream& out, const resize_prev_to_tagged_& item)
+        {
+            out << "resize_prev_to_tagged"<<id;
+            return out;
+        }
+
+        friend void to_xml(const resize_prev_to_tagged_& item, std::ostream& out)
+        {
+            out << "<resize_prev_to_tagged tag='"<<id<<"'/>\n";
+        }
+
+    private:
+        resizable_tensor params;
+    };
+
+    template <
+        template<typename> class tag,
+        typename SUBNET
+        >
+    using resize_prev_to_tagged = add_layer<resize_prev_to_tagged_<tag>, SUBNET>;
+
 // ----------------------------------------------------------------------------------------

     template <
...
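As a usage note (the alias below is illustrative, not part of the patch): resize_prev_to_tagged is designed to pair with concat_, so that a decoder tensor can be resized to an encoder tensor's resolution and the two joined along the channel axis. That is exactly how the resize_and_concat alias in the semantic segmentation example further down this commit uses it. A minimal sketch, assuming tag1 marks the encoder tensor and tag2 the resized one:

#include <dlib/dnn.h>

using namespace dlib;

// Hypothetical helper for illustration: resize the previous layer's output to
// tag1's rows/columns, tag the resized result as tag2, then concatenate the
// tensors at tag1 and tag2 along the channel (k) axis.
template <typename SUBNET>
using skip_cat = add_layer<concat_<tag1, tag2>,
                 tag2<resize_prev_to_tagged<tag1, SUBNET>>>;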
@@ -2382,6 +2382,56 @@ namespace dlib
     using mult_prev9_ = mult_prev_<tag9>;
     using mult_prev10_ = mult_prev_<tag10>;

+// ----------------------------------------------------------------------------------------
+
+    template <
+        template<typename> class tag
+        >
+    class resize_prev_to_tagged_
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+                defined above.  This layer resizes the output channels of the previous
+                layer to have the same number of rows and columns as the output of the
+                tagged layer.
+
+                This layer uses bilinear interpolation.  If the sizes match already, then
+                it simply copies the data.
+
+                Therefore, you supply a tag via resize_prev_to_tagged's template argument
+                that tells it what layer to use for the target size.
+
+                If tensor PREV is resized to the size of tensor TAGGED, then a tensor OUT
+                is produced such that:
+                    - OUT.num_samples() == PREV.num_samples()
+                    - OUT.k() == PREV.k()
+                    - OUT.nr() == TAGGED.nr()
+                    - OUT.nc() == TAGGED.nc()
+        !*/
+
+    public:
+
+        resize_prev_to_tagged_(
+        );
+
+        template <typename SUBNET> void setup(const SUBNET& sub);
+        template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+        template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+        dpoint map_input_to_output(dpoint p) const;
+        dpoint map_output_to_input(dpoint p) const;
+        const tensor& get_layer_params() const;
+        tensor& get_layer_params();
+        /*!
+            These functions are implemented as described in the
+            EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+        !*/
+    };
+
+    template <
+        template<typename> class tag,
+        typename SUBNET
+        >
+    using resize_prev_to_tagged = add_layer<resize_prev_to_tagged_<tag>, SUBNET>;
+
 // ----------------------------------------------------------------------------------------

     template <
...
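To make the documented shape contract concrete, here is a small runnable sketch (the pooling layer, sizes, and the toy_net name are arbitrary illustrative choices, not part of the patch): the input tagged by tag1 is 64x64, avg_pool halves it, and resize_prev_to_tagged restores the tagged resolution while keeping PREV's channel count.

#include <dlib/dnn.h>
#include <dlib/image_transforms.h>
#include <iostream>

using namespace dlib;

// PREV is the 3-channel 32x32 pooled tensor; TAGGED is the 3-channel 64x64
// input tensor; OUT therefore comes out as 3 channels of 64x64.
using toy_net = resize_prev_to_tagged<tag1,
                avg_pool<2,2,2,2,
                tag1<input<matrix<rgb_pixel>>>>>;

int main()
{
    toy_net net;
    matrix<rgb_pixel> img(64,64);
    assign_all_pixels(img, rgb_pixel(0,0,0));

    const tensor& out = net(img);
    std::cout << out.k() << " x " << out.nr() << " x " << out.nc() << std::endl;  // prints: 3 x 64 x 64
}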
@@ -293,7 +293,7 @@ jpeg_make_d_derived_tbl (j_decompress_ptr cinfo, int isDC, int tblno,
 GLOBAL(int)
 jpeg_fill_bit_buffer (bitread_working_state * state,
-                      bit_buf_type get_buffer, register int bits_left,
+                      bit_buf_type get_buffer, int bits_left,
                       int nbits)
 /* Load up the bit buffer to a depth of at least nbits */
 {
@@ -399,7 +399,7 @@ jpeg_fill_bit_buffer (bitread_working_state * state,
 GLOBAL(int)
 jpeg_huff_decode (bitread_working_state * state,
-                  bit_buf_type get_buffer, register int bits_left,
+                  bit_buf_type get_buffer, int bits_left,
                   d_derived_tbl * htbl, int min_bits)
 {
   int l = min_bits;
...
@@ -1910,7 +1910,7 @@ namespace
     template <typename SUBNET>
     using pres = prelu<add_prev1<bn_con<con<8,3,3,1,1,prelu<bn_con<con<8,3,3,1,1,tag1<SUBNET>>>>>>>>;

-    void test_visit_funcions()
+    void test_visit_functions()
     {
         using net_type2 = loss_multiclass_log<fc<10,
                           avg_pool_everything<
@@ -3243,7 +3243,7 @@ namespace
         test_batch_normalize_conv();
         test_basic_tensor_ops();
         test_layers();
-        test_visit_funcions();
+        test_visit_functions();
         test_copy_tensor_cpu();
         test_copy_tensor_add_to_cpu();
         test_concat();
...
@@ -16,7 +16,7 @@
         ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images

     An alternative to steps 2-4 above is to download a pre-trained network
-    from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn
+    from here: http://dlib.net/files/semantic_segmentation_voc2012net_v2.dnn

     It would be a good idea to become familiar with dlib's DNN tooling before reading this
     example.  So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
@@ -111,16 +111,16 @@ int main(int argc, char** argv) try
         cout << "You call this program like this: " << endl;
         cout << "./dnn_semantic_segmentation_train_ex /path/to/images" << endl;
         cout << endl;
-        cout << "You will also need a trained 'semantic_segmentation_voc2012net.dnn' file." << endl;
+        cout << "You will also need a trained '" << semantic_segmentation_net_filename << "' file." << endl;
         cout << "You can either train it yourself (see example program" << endl;
         cout << "dnn_semantic_segmentation_train_ex), or download a" << endl;
-        cout << "copy from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn" << endl;
+        cout << "copy from here: http://dlib.net/files/" << semantic_segmentation_net_filename << endl;
         return 1;
     }

     // Read the file containing the trained network from the working directory.
     anet_type net;
-    deserialize("semantic_segmentation_voc2012net.dnn") >> net;
+    deserialize(semantic_segmentation_net_filename) >> net;

     // Show inference results in a window.
     image_window win;
...
@@ -23,7 +23,7 @@
         ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images

     An alternative to steps 2-4 above is to download a pre-trained network
-    from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn
+    from here: http://dlib.net/files/semantic_segmentation_voc2012net_v2.dnn

     It would be a good idea to become familiar with dlib's DNN tooling before reading this
     example.  So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
@@ -116,13 +116,13 @@ const Voc2012class& find_voc2012_class(Predicate predicate)

 // Introduce the building blocks used to define the segmentation network.
 // The network first does residual downsampling (similar to the dnn_imagenet_(train_)ex
-// example program), and then residual upsampling.  The network could be improved e.g.
-// by introducing skip connections from the input image, and/or the first layers, to the
-// last layer(s).  (See Long et al., Fully Convolutional Networks for Semantic Segmentation,
-// https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf)
+// example program), and then residual upsampling.  In addition, U-Net style skip
+// connections are used, so that not every simple detail needs to be represented on the
+// low levels.  (See Ronneberger et al. (2015), U-Net: Convolutional Networks for
+// Biomedical Image Segmentation, https://arxiv.org/pdf/1505.04597.pdf)
 template <int N, template <typename> class BN, int stride, typename SUBNET>
-using block = BN<dlib::con<N,3,3,1,1, dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;
+using block = BN<dlib::con<N,3,3,1,1,dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;

 template <int N, template <typename> class BN, int stride, typename SUBNET>
 using blockt = BN<dlib::cont<N,3,3,1,1,dlib::relu<BN<dlib::cont<N,3,3,stride,stride,SUBNET>>>>>;
@@ -145,55 +145,98 @@ template <int N, typename SUBNET> using ares_up = dlib::relu<residual_up<block

 // ----------------------------------------------------------------------------------------

-template <typename SUBNET> using res512 = res<512, SUBNET>;
-template <typename SUBNET> using res256 = res<256, SUBNET>;
-template <typename SUBNET> using res128 = res<128, SUBNET>;
-template <typename SUBNET> using res64 = res<64, SUBNET>;
-template <typename SUBNET> using ares512 = ares<512, SUBNET>;
-template <typename SUBNET> using ares256 = ares<256, SUBNET>;
-template <typename SUBNET> using ares128 = ares<128, SUBNET>;
-template <typename SUBNET> using ares64 = ares<64, SUBNET>;
+template <typename SUBNET> using res64 = res<64,SUBNET>;
+template <typename SUBNET> using res128 = res<128,SUBNET>;
+template <typename SUBNET> using res256 = res<256,SUBNET>;
+template <typename SUBNET> using res512 = res<512,SUBNET>;
+template <typename SUBNET> using ares64 = ares<64,SUBNET>;
+template <typename SUBNET> using ares128 = ares<128,SUBNET>;
+template <typename SUBNET> using ares256 = ares<256,SUBNET>;
+template <typename SUBNET> using ares512 = ares<512,SUBNET>;
+
+template <typename SUBNET> using level1 = dlib::repeat<2,res64,res<64,SUBNET>>;
+template <typename SUBNET> using level2 = dlib::repeat<2,res128,res_down<128,SUBNET>>;
+template <typename SUBNET> using level3 = dlib::repeat<2,res256,res_down<256,SUBNET>>;
+template <typename SUBNET> using level4 = dlib::repeat<2,res512,res_down<512,SUBNET>>;
+
+template <typename SUBNET> using alevel1 = dlib::repeat<2,ares64,ares<64,SUBNET>>;
+template <typename SUBNET> using alevel2 = dlib::repeat<2,ares128,ares_down<128,SUBNET>>;
+template <typename SUBNET> using alevel3 = dlib::repeat<2,ares256,ares_down<256,SUBNET>>;
+template <typename SUBNET> using alevel4 = dlib::repeat<2,ares512,ares_down<512,SUBNET>>;
+
+template <typename SUBNET> using level1t = dlib::repeat<2,res64,res_up<64,SUBNET>>;
+template <typename SUBNET> using level2t = dlib::repeat<2,res128,res_up<128,SUBNET>>;
+template <typename SUBNET> using level3t = dlib::repeat<2,res256,res_up<256,SUBNET>>;
+template <typename SUBNET> using level4t = dlib::repeat<2,res512,res_up<512,SUBNET>>;
+
+template <typename SUBNET> using alevel1t = dlib::repeat<2,ares64,ares_up<64,SUBNET>>;
+template <typename SUBNET> using alevel2t = dlib::repeat<2,ares128,ares_up<128,SUBNET>>;
+template <typename SUBNET> using alevel3t = dlib::repeat<2,ares256,ares_up<256,SUBNET>>;
+template <typename SUBNET> using alevel4t = dlib::repeat<2,ares512,ares_up<512,SUBNET>>;
+
+// ----------------------------------------------------------------------------------------
+
+template <
+    template<typename> class TAGGED,
+    template<typename> class PREV_RESIZED,
+    typename SUBNET
+>
+using resize_and_concat = dlib::add_layer<
+    dlib::concat_<TAGGED,PREV_RESIZED>,
+    PREV_RESIZED<dlib::resize_prev_to_tagged<TAGGED,SUBNET>>>;
+
+template <typename SUBNET> using utag1 = dlib::add_tag_layer<2100+1,SUBNET>;
+template <typename SUBNET> using utag2 = dlib::add_tag_layer<2100+2,SUBNET>;
+template <typename SUBNET> using utag3 = dlib::add_tag_layer<2100+3,SUBNET>;
+template <typename SUBNET> using utag4 = dlib::add_tag_layer<2100+4,SUBNET>;
+
+template <typename SUBNET> using utag1_ = dlib::add_tag_layer<2110+1,SUBNET>;
+template <typename SUBNET> using utag2_ = dlib::add_tag_layer<2110+2,SUBNET>;
+template <typename SUBNET> using utag3_ = dlib::add_tag_layer<2110+3,SUBNET>;
+template <typename SUBNET> using utag4_ = dlib::add_tag_layer<2110+4,SUBNET>;
+
+template <typename SUBNET> using concat_utag1 = resize_and_concat<utag1,utag1_,SUBNET>;
+template <typename SUBNET> using concat_utag2 = resize_and_concat<utag2,utag2_,SUBNET>;
+template <typename SUBNET> using concat_utag3 = resize_and_concat<utag3,utag3_,SUBNET>;
+template <typename SUBNET> using concat_utag4 = resize_and_concat<utag4,utag4_,SUBNET>;

-template <typename SUBNET> using level1 = dlib::repeat<2,res512,res_down<512,SUBNET>>;
-template <typename SUBNET> using level2 = dlib::repeat<2,res256,res_down<256,SUBNET>>;
-template <typename SUBNET> using level3 = dlib::repeat<2,res128,res_down<128,SUBNET>>;
-template <typename SUBNET> using level4 = dlib::repeat<2,res64,res<64,SUBNET>>;
-
-template <typename SUBNET> using alevel1 = dlib::repeat<2,ares512,ares_down<512,SUBNET>>;
-template <typename SUBNET> using alevel2 = dlib::repeat<2,ares256,ares_down<256,SUBNET>>;
-template <typename SUBNET> using alevel3 = dlib::repeat<2,ares128,ares_down<128,SUBNET>>;
-template <typename SUBNET> using alevel4 = dlib::repeat<2,ares64,ares<64,SUBNET>>;
-
-template <typename SUBNET> using level1t = dlib::repeat<2,res512,res_up<512,SUBNET>>;
-template <typename SUBNET> using level2t = dlib::repeat<2,res256,res_up<256,SUBNET>>;
-template <typename SUBNET> using level3t = dlib::repeat<2,res128,res_up<128,SUBNET>>;
-template <typename SUBNET> using level4t = dlib::repeat<2,res64,res_up<64,SUBNET>>;
-
-template <typename SUBNET> using alevel1t = dlib::repeat<2,ares512,ares_up<512,SUBNET>>;
-template <typename SUBNET> using alevel2t = dlib::repeat<2,ares256,ares_up<256,SUBNET>>;
-template <typename SUBNET> using alevel3t = dlib::repeat<2,ares128,ares_up<128,SUBNET>>;
-template <typename SUBNET> using alevel4t = dlib::repeat<2,ares64,ares_up<64,SUBNET>>;
+// ----------------------------------------------------------------------------------------
+
+static const char* semantic_segmentation_net_filename = "semantic_segmentation_voc2012net_v2.dnn";

 // ----------------------------------------------------------------------------------------

 // training network type
-using net_type = dlib::loss_multiclass_log_per_pixel<
-              dlib::cont<class_count,7,7,2,2,
-              level4t<level3t<level2t<level1t<
-              level1<level2<level3<level4<
-              dlib::max_pool<3,3,2,2,dlib::relu<dlib::bn_con<dlib::con<64,7,7,2,2,
+using bnet_type = dlib::loss_multiclass_log_per_pixel<
+              dlib::cont<class_count,1,1,1,1,
+              dlib::relu<dlib::bn_con<dlib::cont<64,7,7,2,2,
+              concat_utag1<level1t<
+              concat_utag2<level2t<
+              concat_utag3<level3t<
+              concat_utag4<level4t<
+              level4<utag4<
+              level3<utag3<
+              level2<utag2<
+              level1<dlib::max_pool<3,3,2,2,utag1<
+              dlib::relu<dlib::bn_con<dlib::con<64,7,7,2,2,
               dlib::input<dlib::matrix<dlib::rgb_pixel>>
-              >>>>>>>>>>>>>>;
+              >>>>>>>>>>>>>>>>>>>>>>>>>;

 // testing network type (replaced batch normalization with fixed affine transforms)
 using anet_type = dlib::loss_multiclass_log_per_pixel<
-              dlib::cont<class_count,7,7,2,2,
-              alevel4t<alevel3t<alevel2t<alevel1t<
-              alevel1<alevel2<alevel3<alevel4<
-              dlib::max_pool<3,3,2,2,dlib::relu<dlib::affine<dlib::con<64,7,7,2,2,
+              dlib::cont<class_count,1,1,1,1,
+              dlib::relu<dlib::affine<dlib::cont<64,7,7,2,2,
+              concat_utag1<alevel1t<
+              concat_utag2<alevel2t<
+              concat_utag3<alevel3t<
+              concat_utag4<alevel4t<
+              alevel4<utag4<
+              alevel3<utag3<
+              alevel2<utag2<
+              alevel1<dlib::max_pool<3,3,2,2,utag1<
+              dlib::relu<dlib::affine<dlib::con<64,7,7,2,2,
               dlib::input<dlib::matrix<dlib::rgb_pixel>>
-              >>>>>>>>>>>>>>;
+              >>>>>>>>>>>>>>>>>>>>>>>>>;

 // ----------------------------------------------------------------------------------------
...
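For orientation (this sketch is not part of the patch): once a network of type anet_type has been trained and serialized under semantic_segmentation_net_filename, inference is just deserialization plus a forward pass. The header include name and the image file name below are assumptions for illustration.

#include <dlib/dnn.h>
#include <dlib/image_io.h>
#include "dnn_semantic_segmentation_ex.h"  // assumed name of the header shown above

int main()
{
    anet_type net;
    dlib::deserialize(semantic_segmentation_net_filename) >> net;

    dlib::matrix<dlib::rgb_pixel> input_image;
    dlib::load_image(input_image, "some_image.jpg");  // illustrative input file

    // loss_multiclass_log_per_pixel returns a label image: a matrix<uint16_t>
    // holding the predicted class index for every pixel.
    const dlib::matrix<uint16_t> label_image = net(input_image);
    return 0;
}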
@@ -41,7 +41,7 @@ struct training_sample

 // ----------------------------------------------------------------------------------------

-rectangle make_random_cropping_rect_resnet(
+rectangle make_random_cropping_rect(
     const matrix<rgb_pixel>& img,
     dlib::rand& rnd
 )
@@ -66,7 +66,7 @@ void randomly_crop_image (
     dlib::rand& rnd
 )
 {
-    const auto rect = make_random_cropping_rect_resnet(input_image, rnd);
+    const auto rect = make_random_cropping_rect(input_image, rnd);

     const chip_details chip_details(rect, chip_dims(227, 227));
@@ -259,12 +259,12 @@ double calculate_accuracy(anet_type& anet, const std::vector<image_info>& datase

 int main(int argc, char** argv) try
 {
-    if (argc != 2)
+    if (argc < 2 || argc > 3)
     {
         cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl;
         cout << endl;
         cout << "You call this program like this: " << endl;
-        cout << "./dnn_semantic_segmentation_train_ex /path/to/VOC2012" << endl;
+        cout << "./dnn_semantic_segmentation_train_ex /path/to/VOC2012 [minibatch-size]" << endl;
         return 1;
     }
@@ -278,13 +278,16 @@ int main(int argc, char** argv) try
         return 1;
     }

+    // a mini-batch smaller than the default can be used with GPUs having less memory
+    const int minibatch_size = argc == 3 ? std::stoi(argv[2]) : 23;
+    cout << "mini-batch size: " << minibatch_size << endl;
+
     const double initial_learning_rate = 0.1;
     const double weight_decay = 0.0001;
     const double momentum = 0.9;

-    net_type net;
-    dnn_trainer<net_type> trainer(net,sgd(weight_decay, momentum));
+    bnet_type bnet;
+    dnn_trainer<bnet_type> trainer(bnet,sgd(weight_decay, momentum));
     trainer.be_verbose();
     trainer.set_learning_rate(initial_learning_rate);
     trainer.set_synchronization_file("pascal_voc2012_trainer_state_file.dat", std::chrono::minutes(10));
@@ -292,7 +295,7 @@ int main(int argc, char** argv) try
     trainer.set_iterations_without_progress_threshold(5000);
     // Since the progress threshold is so large might as well set the batch normalization
     // stats window to something big too.
-    set_all_bn_running_stats_window_sizes(net, 1000);
+    set_all_bn_running_stats_window_sizes(bnet, 1000);

     // Output training parameters.
     cout << endl << trainer << endl;
@@ -345,9 +348,9 @@ int main(int argc, char** argv) try
             samples.clear();
             labels.clear();

-            // make a 30-image mini-batch
+            // make a mini-batch
             training_sample temp;
-            while(samples.size() < 30)
+            while(samples.size() < minibatch_size)
             {
                 data.dequeue(temp);
@@ -369,13 +372,13 @@ int main(int argc, char** argv) try
     // also wait for threaded processing to stop in the trainer.
     trainer.get_net();

-    net.clean();
+    bnet.clean();
     cout << "saving network" << endl;
-    serialize("semantic_segmentation_voc2012net.dnn") << net;
+    serialize(semantic_segmentation_net_filename) << bnet;

     // Make a copy of the network to use it for inference.
-    anet_type anet = net;
+    anet_type anet = bnet;

     cout << "Testing the network..." << endl;
...