dlib commit 483e6ab4, authored Nov 15, 2017 by Davis King.
Commit message: "merged". Merge commit with parents 391c11ed and e48125c2.

Showing 8 changed files with 853 additions and 10 deletions (+853 / -10).
Changed files:

    dlib/gui_widgets/widgets.cpp                       +8    -0
    dlib/gui_widgets/widgets.h                         +3    -0
    dlib/image_transforms/interpolation.h              +40   -6
    dlib/image_transforms/interpolation_abstract.h     +39   -4
    examples/CMakeLists.txt                            +10   -0
    examples/dnn_semantic_segmentation_ex.cpp          +172  -0
    examples/dnn_semantic_segmentation_ex.h            +191  -0
    examples/dnn_semantic_segmentation_train_ex.cpp    +390  -0
dlib/gui_widgets/widgets.cpp

@@ -1465,6 +1465,14 @@ namespace dlib
 // ----------------------------------------------------------------------------------------

+    unsigned long tabbed_display::
+        selected_tab (
+        ) const
+    {
+        auto_mutex M(m);
+        return selected_tab_;
+    }
+
     unsigned long tabbed_display::
         number_of_tabs (
         ) const
dlib/gui_widgets/widgets.h

@@ -1190,6 +1190,9 @@ namespace dlib
             unsigned long num
         );

+        unsigned long selected_tab (
+        ) const;
+
         unsigned long number_of_tabs (
         ) const;
dlib/image_transforms/interpolation.h

@@ -1856,12 +1856,14 @@ namespace dlib
     template <
         typename image_type1,
-        typename image_type2
+        typename image_type2,
+        typename interpolation_type
         >
     void extract_image_chips (
         const image_type1& img,
         const std::vector<chip_details>& chip_locations,
-        dlib::array<image_type2>& chips
+        dlib::array<image_type2>& chips,
+        const interpolation_type& interp
     )
     {
         // make sure requires clause is not broken

@@ -1957,9 +1959,9 @@ namespace dlib
             // now extract the actual chip
             if (level == -1)
-                transform_image(sub_image(img,bounding_box), chips[i], interpolate_bilinear(), trns);
+                transform_image(sub_image(img,bounding_box), chips[i], interp, trns);
             else
-                transform_image(levels[level], chips[i], interpolate_bilinear(), trns);
+                transform_image(levels[level], chips[i], interp, trns);
             }
         }
     }

@@ -1970,10 +1972,27 @@ namespace dlib
         typename image_type1,
         typename image_type2
         >
+    void extract_image_chips (
+        const image_type1& img,
+        const std::vector<chip_details>& chip_locations,
+        dlib::array<image_type2>& chips
+    )
+    {
+        extract_image_chips(img, chip_locations, chips, interpolate_bilinear());
+    }
+
+// ----------------------------------------------------------------------------------------
+
+    template <
+        typename image_type1,
+        typename image_type2,
+        typename interpolation_type
+        >
     void extract_image_chip (
         const image_type1& img,
         const chip_details& location,
-        image_type2& chip
+        image_type2& chip,
+        const interpolation_type& interp
     )
     {
         // If the chip doesn't have any rotation or scaling then use the basic version of

@@ -1988,11 +2007,26 @@ namespace dlib
         {
             std::vector<chip_details> chip_locations(1, location);
             dlib::array<image_type2> chips;
-            extract_image_chips(img, chip_locations, chips);
+            extract_image_chips(img, chip_locations, chips, interp);
             swap(chips[0], chip);
         }
     }

// ----------------------------------------------------------------------------------------

+    template <
+        typename image_type1,
+        typename image_type2
+        >
+    void extract_image_chip (
+        const image_type1& img,
+        const chip_details& location,
+        image_type2& chip
+    )
+    {
+        extract_image_chip(img, location, chip, interpolate_bilinear());
+    }
+
+// ----------------------------------------------------------------------------------------
+
     inline chip_details get_face_chip_details (
dlib/image_transforms/interpolation_abstract.h

@@ -1163,12 +1163,14 @@ namespace dlib
     template <
         typename image_type1,
-        typename image_type2
+        typename image_type2,
+        typename interpolation_type
         >
     void extract_image_chips (
         const image_type1& img,
         const std::vector<chip_details>& chip_locations,
-        dlib::array<image_type2>& chips
+        dlib::array<image_type2>& chips,
+        const interpolation_type& interp
     );
     /*!
         requires

@@ -1185,6 +1187,7 @@ namespace dlib
               rectangular sub-windows (i.e. chips) within an image and extracts those
               sub-windows, storing each into its own image.  It also scales and rotates the
               image chips according to the instructions inside each chip_details object.
+              It uses the interpolation method supplied as a parameter.
             - #chips == the extracted image chips
             - #chips.size() == chip_locations.size()
             - for all valid i:

@@ -1198,16 +1201,33 @@ namespace dlib
             - Any pixels in an image chip that go outside img are set to 0 (i.e. black).
     !*/

+    template <
+        typename image_type1,
+        typename image_type2
+        >
+    void extract_image_chips (
+        const image_type1& img,
+        const std::vector<chip_details>& chip_locations,
+        dlib::array<image_type2>& chips
+    );
+    /*!
+        ensures
+            - This function is a simple convenience / compatibility wrapper that calls the
+              above-defined extract_image_chips function using bilinear interpolation.
+    !*/
+
 // ----------------------------------------------------------------------------------------

     template <
         typename image_type1,
-        typename image_type2
+        typename image_type2,
+        typename interpolation_type
         >
     void extract_image_chip (
         const image_type1& img,
         const chip_details& chip_location,
-        image_type2& chip
+        image_type2& chip,
+        const interpolation_type& interp
     );
     /*!
         ensures

@@ -1215,6 +1235,21 @@ namespace dlib
               and stores the single output chip into #chip.
     !*/

+    template <
+        typename image_type1,
+        typename image_type2
+        >
+    void extract_image_chip (
+        const image_type1& img,
+        const chip_details& chip_location,
+        image_type2& chip
+    );
+    /*!
+        ensures
+            - This function is a simple convenience / compatibility wrapper that calls the
+              above-defined extract_image_chip function using bilinear interpolation.
+    !*/
+
 // ----------------------------------------------------------------------------------------

     template <
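To make the documented behavior concrete, here is a minimal usage sketch (illustrative only, not part of the commit): it extracts the same chip twice, once through the new overload where the caller supplies the interpolation object, and once through the convenience / compatibility wrapper that defaults to bilinear interpolation. The image here is synthetic and the chip size and position are arbitrary.

    #include <dlib/geometry.h>
    #include <dlib/image_transforms.h>
    #include <dlib/matrix.h>

    int main()
    {
        dlib::matrix<dlib::rgb_pixel> img(200, 200), chip1, chip2;
        dlib::assign_all_pixels(img, dlib::rgb_pixel(100, 150, 200));

        // A 100x100 region around the image center, resampled to 64x64.
        const dlib::chip_details details(
            dlib::centered_rect(dlib::center(dlib::get_rect(img)), 100, 100),
            dlib::chip_dims(64, 64));

        // New overload: the caller picks the interpolation method.
        dlib::extract_image_chip(img, details, chip1, dlib::interpolate_nearest_neighbor());

        // Convenience / compatibility wrapper: same signature as before this commit, uses bilinear.
        dlib::extract_image_chip(img, details, chip2);

        return 0;
    }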
examples/CMakeLists.txt

@@ -124,12 +124,22 @@ if (NOT USING_OLD_VISUAL_STUDIO_COMPILER)
    add_gui_example(dnn_mmod_find_cars_ex)
    add_gui_example(dnn_mmod_find_cars2_ex)
    add_example(dnn_mmod_train_find_cars_ex)
+   add_gui_example(dnn_semantic_segmentation_ex)

    if (NOT MSVC)
       # Don't try to compile these programs using Visual Studio since it causes the
       # compiler to run out of RAM and to crash.  Maybe someday Visual Studio
       # won't be broken :(
+      # (NB: While the 32-bit VC++ compiler launched by the Visual Studio IDE will
+      #  run out of memory, running a 64-bit MSBuild.exe on the Command Prompt
+      #  seems to work fine.  So you can try something like this:
+      #    "C:\Program Files (x86)\MSBuild\14.0\Bin\amd64\MSBuild.exe" C:\path\to\examples.sln /p:Configuration=Release /p:Platform=x64 /t:dnn_imagenet_train_ex
+      #  It does take quite a while to build these examples, though!
+      #  Note that you may additionally need to set Debug Information Format to
+      #  C7 compatible (/Z7), in case you get compiler error "cannot update
+      #  program database".)
       add_example(dnn_imagenet_train_ex)
       add_example(dnn_metric_learning_on_images_ex)
+      add_example(dnn_semantic_segmentation_train_ex)
    endif()
 endif()
examples/dnn_semantic_segmentation_ex.cpp (new file, 0 → 100644)

// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    This example shows how to do semantic segmentation on an image using net pretrained
    on the PASCAL VOC2012 dataset.  For an introduction to what segmentation is, see the
    accompanying header file dnn_semantic_segmentation_ex.h.

    Instructions how to run the example:
    1. Download the PASCAL VOC2012 data, and untar it somewhere.
       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    2. Build the dnn_semantic_segmentation_train_ex example program.
    3. Run:
       ./dnn_semantic_segmentation_train_ex /path/to/VOC2012
    4. Wait while the network is being trained.
    5. Build the dnn_semantic_segmentation_ex example program.
    6. Run:
       ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images

    An alternative to steps 2-4 above is to download a pre-trained network
    from here: http://dlib.net/files/voc2012net.dnn

    It would be a good idea to become familiar with dlib's DNN tooling before reading this
    example.  So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
    before reading this example program.
*/

#include "dnn_semantic_segmentation_ex.h"

#include <iostream>
#include <dlib/data_io.h>
#include <dlib/gui_widgets.h>

using namespace std;
using namespace dlib;

// ----------------------------------------------------------------------------------------

// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background.  Each class
// is represented using an RGB color value.  We associate each class also to an index in the
// range [0, 20], used internally by the network.  To generate nice RGB representations of
// inference results, we need to be able to convert the index values to the corresponding
// RGB values.

// Given an index in the range [0, 20], find the corresponding PASCAL VOC2012 class
// (e.g., 'dog').
const Voc2012class& find_voc2012_class(const uint16_t& index_label)
{
    return find_voc2012_class(
        [&index_label](const Voc2012class& voc2012class)
        {
            return index_label == voc2012class.index;
        }
    );
}

// Convert an index in the range [0, 20] to a corresponding RGB class label.
inline rgb_pixel index_label_to_rgb_label(uint16_t index_label)
{
    return find_voc2012_class(index_label).rgb_label;
}

// Convert an image containing indexes in the range [0, 20] to a corresponding
// image containing RGB class labels.
void index_label_image_to_rgb_label_image(
    const matrix<uint16_t>& index_label_image,
    matrix<rgb_pixel>& rgb_label_image
)
{
    const long nr = index_label_image.nr();
    const long nc = index_label_image.nc();

    rgb_label_image.set_size(nr, nc);

    for (long r = 0; r < nr; ++r)
    {
        for (long c = 0; c < nc; ++c)
        {
            rgb_label_image(r, c) = index_label_to_rgb_label(index_label_image(r, c));
        }
    }
}

// Find the most prominent class label from amongst the per-pixel predictions.
std::string get_most_prominent_non_background_classlabel(const matrix<uint16_t>& index_label_image)
{
    const long nr = index_label_image.nr();
    const long nc = index_label_image.nc();

    std::vector<unsigned int> counters(class_count);

    for (long r = 0; r < nr; ++r)
    {
        for (long c = 0; c < nc; ++c)
        {
            const uint16_t label = index_label_image(r, c);
            ++counters[label];
        }
    }

    const auto max_element = std::max_element(counters.begin() + 1, counters.end());
    const uint16_t most_prominent_index_label = max_element - counters.begin();

    return find_voc2012_class(most_prominent_index_label).classlabel;
}

// ----------------------------------------------------------------------------------------

int main(int argc, char** argv) try
{
    if (argc != 2)
    {
        cout << "You call this program like this: " << endl;
        cout << "./dnn_semantic_segmentation_train_ex /path/to/images" << endl;
        cout << endl;
        cout << "You will also need a trained 'voc2012net.dnn' file." << endl;
        cout << "You can either train it yourself (see example program" << endl;
        cout << "dnn_semantic_segmentation_train_ex), or download a" << endl;
        cout << "copy from here: http://dlib.net/files/voc2012net.dnn" << endl;
        return 1;
    }

    // Read the file containing the trained network from the working directory.
    anet_type net;
    deserialize("voc2012net.dnn") >> net;

    // Show inference results in a window.
    image_window win;

    matrix<rgb_pixel> input_image;
    matrix<uint16_t> index_label_image;
    matrix<rgb_pixel> rgb_label_image;

    // Find supported image files.
    const std::vector<file> files = dlib::get_files_in_directory_tree(argv[1],
        dlib::match_endings(".jpeg .jpg .png"));

    cout << "Found " << files.size() << " images, processing..." << endl;

    for (const file& file : files)
    {
        // Load the input image.
        load_image(input_image, file.full_name());

        // Create predictions for each pixel. At this point, the type of each prediction
        // is an index (a value between 0 and 20). Note that the net may return an image
        // that is not exactly the same size as the input.
        const matrix<uint16_t> temp = net(input_image);

        // Crop the returned image to be exactly the same size as the input.
        const chip_details chip_details(
            centered_rect(temp.nc()/2, temp.nr()/2, input_image.nc(), input_image.nr()),
            chip_dims(input_image.nr(), input_image.nc())
        );
        extract_image_chip(temp, chip_details, index_label_image, interpolate_nearest_neighbor());

        // Convert the indexes to RGB values.
        index_label_image_to_rgb_label_image(index_label_image, rgb_label_image);

        // Show the input image on the left, and the predicted RGB labels on the right.
        win.set_image(join_rows(input_image, rgb_label_image));

        // Find the most prominent class label from amongst the per-pixel predictions.
        const std::string classlabel = get_most_prominent_non_background_classlabel(index_label_image);

        cout << file.name() << " : " << classlabel << " - hit enter to process the next image";
        cin.get();
    }
}
catch (std::exception& e)
{
    cout << e.what() << endl;
}
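The "crop the returned image to be exactly the same size as the input" step above is easy to isolate. The self-contained sketch below (illustrative only; the matrix sizes are made up) cuts a centered chip with the reference dimensions out of a label matrix, using nearest-neighbor interpolation so that class indexes are never blended:

    #include <cstdint>
    #include <dlib/geometry.h>
    #include <dlib/image_transforms.h>
    #include <dlib/matrix.h>
    #include <iostream>

    int main()
    {
        dlib::matrix<uint16_t> net_output(120, 130);   // stand-in for the network's output
        net_output = 0;
        const long nr = 100, nc = 110;                 // stand-in for the input image size

        const dlib::chip_details details(
            dlib::centered_rect(net_output.nc()/2, net_output.nr()/2, nc, nr),
            dlib::chip_dims(nr, nc));

        dlib::matrix<uint16_t> resized;
        dlib::extract_image_chip(net_output, details, resized, dlib::interpolate_nearest_neighbor());

        std::cout << resized.nr() << "x" << resized.nc() << std::endl;   // prints 100x110
        return 0;
    }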
examples/dnn_semantic_segmentation_ex.h (new file, 0 → 100644)

// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    Semantic segmentation using the PASCAL VOC2012 dataset.

    In segmentation, the task is to assign each pixel of an input image
    a label - for example, 'dog'.  Then, the idea is that neighboring
    pixels having the same label can be connected together to form a
    larger region, representing a complete (or partially occluded) dog.
    So technically, segmentation can be viewed as classification of
    individual pixels (using the relevant context in the input images),
    however the goal usually is to identify meaningful regions that
    represent complete entities of interest (such as dogs).

    Instructions how to run the example:
    1. Download the PASCAL VOC2012 data, and untar it somewhere.
       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    2. Build the dnn_semantic_segmentation_train_ex example program.
    3. Run:
       ./dnn_semantic_segmentation_train_ex /path/to/VOC2012
    4. Wait while the network is being trained.
    5. Build the dnn_semantic_segmentation_ex example program.
    6. Run:
       ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images

    An alternative to steps 2-4 above is to download a pre-trained network
    from here: http://dlib.net/files/voc2012net.dnn

    It would be a good idea to become familiar with dlib's DNN tooling before reading this
    example.  So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
    before reading this example program.
*/

#ifndef DLIB_DNn_SEMANTIC_SEGMENTATION_EX_H_
#define DLIB_DNn_SEMANTIC_SEGMENTATION_EX_H_

#include <dlib/dnn.h>

// ----------------------------------------------------------------------------------------

inline bool operator == (const dlib::rgb_pixel& a, const dlib::rgb_pixel& b)
{
    return a.red == b.red && a.green == b.green && a.blue == b.blue;
}

// ----------------------------------------------------------------------------------------

// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background.  Each class
// is represented using an RGB color value.  We associate each class also to an index in the
// range [0, 20], used internally by the network.

struct Voc2012class {
    Voc2012class(uint16_t index, const dlib::rgb_pixel& rgb_label, const std::string& classlabel)
        : index(index), rgb_label(rgb_label), classlabel(classlabel)
    {}

    // The index of the class. In the PASCAL VOC 2012 dataset, indexes from 0 to 20 are valid.
    const uint16_t index = 0;

    // The corresponding RGB representation of the class.
    const dlib::rgb_pixel rgb_label;

    // The label of the class in plain text.
    const std::string classlabel;
};

namespace {
    constexpr int class_count = 21; // background + 20 classes

    const std::vector<Voc2012class> classes = {
        Voc2012class(0, dlib::rgb_pixel(0, 0, 0), ""), // background

        // The cream-colored `void' label is used in border regions and to mask difficult objects
        // (see http://host.robots.ox.ac.uk/pascal/VOC/voc2012/htmldoc/devkit_doc.html)
        Voc2012class(dlib::loss_multiclass_log_per_pixel_::label_to_ignore,
            dlib::rgb_pixel(224, 224, 192), "border"),

        Voc2012class(1,  dlib::rgb_pixel(128,   0,   0), "aeroplane"),
        Voc2012class(2,  dlib::rgb_pixel(  0, 128,   0), "bicycle"),
        Voc2012class(3,  dlib::rgb_pixel(128, 128,   0), "bird"),
        Voc2012class(4,  dlib::rgb_pixel(  0,   0, 128), "boat"),
        Voc2012class(5,  dlib::rgb_pixel(128,   0, 128), "bottle"),
        Voc2012class(6,  dlib::rgb_pixel(  0, 128, 128), "bus"),
        Voc2012class(7,  dlib::rgb_pixel(128, 128, 128), "car"),
        Voc2012class(8,  dlib::rgb_pixel( 64,   0,   0), "cat"),
        Voc2012class(9,  dlib::rgb_pixel(192,   0,   0), "chair"),
        Voc2012class(10, dlib::rgb_pixel( 64, 128,   0), "cow"),
        Voc2012class(11, dlib::rgb_pixel(192, 128,   0), "diningtable"),
        Voc2012class(12, dlib::rgb_pixel( 64,   0, 128), "dog"),
        Voc2012class(13, dlib::rgb_pixel(192,   0, 128), "horse"),
        Voc2012class(14, dlib::rgb_pixel( 64, 128, 128), "motorbike"),
        Voc2012class(15, dlib::rgb_pixel(192, 128, 128), "person"),
        Voc2012class(16, dlib::rgb_pixel(  0,  64,   0), "pottedplant"),
        Voc2012class(17, dlib::rgb_pixel(128,  64,   0), "sheep"),
        Voc2012class(18, dlib::rgb_pixel(  0, 192,   0), "sofa"),
        Voc2012class(19, dlib::rgb_pixel(128, 192,   0), "train"),
        Voc2012class(20, dlib::rgb_pixel(  0,  64, 128), "tvmonitor"),
    };
}

template <typename Predicate>
const Voc2012class& find_voc2012_class(Predicate predicate)
{
    const auto i = std::find_if(classes.begin(), classes.end(), predicate);

    if (i != classes.end())
    {
        return *i;
    }
    else
    {
        throw std::runtime_error("Unable to find a matching VOC2012 class");
    }
}

// ----------------------------------------------------------------------------------------

// Introduce the building blocks used to define the segmentation network.
// The network first does residual downsampling (similar to the dnn_imagenet_(train_)ex
// example program), and then residual upsampling. The network could be improved e.g.
// by introducing skip connections from the input image, and/or the first layers, to the
// last layer(s).  (See Long et al., Fully Convolutional Networks for Semantic Segmentation,
// https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf)

template <int N, template <typename> class BN, int stride, typename SUBNET>
using block = BN<dlib::con<N,3,3,1,1,dlib::relu<BN<dlib::con<N,3,3,stride,stride,SUBNET>>>>>;

template <int N, template <typename> class BN, int stride, typename SUBNET>
using blockt = BN<dlib::cont<N,3,3,1,1,dlib::relu<BN<dlib::cont<N,3,3,stride,stride,SUBNET>>>>>;

template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual = dlib::add_prev1<block<N,BN,1,dlib::tag1<SUBNET>>>;

template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual_down = dlib::add_prev2<dlib::avg_pool<2,2,2,2,dlib::skip1<dlib::tag2<block<N,BN,2,dlib::tag1<SUBNET>>>>>>;

template <template <int,template<typename>class,int,typename> class block, int N, template<typename>class BN, typename SUBNET>
using residual_up = dlib::add_prev2<dlib::cont<N,2,2,2,2,dlib::skip1<dlib::tag2<blockt<N,BN,2,dlib::tag1<SUBNET>>>>>>;

template <int N, typename SUBNET> using res       = dlib::relu<residual<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares      = dlib::relu<residual<block,N,dlib::affine,SUBNET>>;
template <int N, typename SUBNET> using res_down  = dlib::relu<residual_down<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares_down = dlib::relu<residual_down<block,N,dlib::affine,SUBNET>>;
template <int N, typename SUBNET> using res_up    = dlib::relu<residual_up<block,N,dlib::bn_con,SUBNET>>;
template <int N, typename SUBNET> using ares_up   = dlib::relu<residual_up<block,N,dlib::affine,SUBNET>>;

// ----------------------------------------------------------------------------------------

template <typename SUBNET> using level1 = res<512,res<512,res_down<512,SUBNET>>>;
template <typename SUBNET> using level2 = res<256,res<256,res_down<256,SUBNET>>>;
template <typename SUBNET> using level3 = res<128,res<128,res_down<128,SUBNET>>>;
template <typename SUBNET> using level4 = res<64,res<64,res<64,SUBNET>>>;

template <typename SUBNET> using alevel1 = ares<512,ares<512,ares_down<512,SUBNET>>>;
template <typename SUBNET> using alevel2 = ares<256,ares<256,ares_down<256,SUBNET>>>;
template <typename SUBNET> using alevel3 = ares<128,ares<128,ares_down<128,SUBNET>>>;
template <typename SUBNET> using alevel4 = ares<64,ares<64,ares<64,SUBNET>>>;

template <typename SUBNET> using level1t = res<512,res<512,res_up<512,SUBNET>>>;
template <typename SUBNET> using level2t = res<256,res<256,res_up<256,SUBNET>>>;
template <typename SUBNET> using level3t = res<128,res<128,res_up<128,SUBNET>>>;
template <typename SUBNET> using level4t = res<64,res<64,res_up<64,SUBNET>>>;

template <typename SUBNET> using alevel1t = ares<512,ares<512,ares_up<512,SUBNET>>>;
template <typename SUBNET> using alevel2t = ares<256,ares<256,ares_up<256,SUBNET>>>;
template <typename SUBNET> using alevel3t = ares<128,ares<128,ares_up<128,SUBNET>>>;
template <typename SUBNET> using alevel4t = ares<64,ares<64,ares_up<64,SUBNET>>>;

// ----------------------------------------------------------------------------------------

// training network type
using net_type = dlib::loss_multiclass_log_per_pixel<
                            dlib::cont<class_count,7,7,2,2,
                            level4t<level3t<level2t<level1t<
                            level1<level2<level3<level4<
                            dlib::max_pool<3,3,2,2,dlib::relu<dlib::bn_con<dlib::con<64,7,7,2,2,
                            dlib::input<dlib::matrix<dlib::rgb_pixel>>
                            >>>>>>>>>>>>>>;

// testing network type (replaced batch normalization with fixed affine transforms)
using anet_type = dlib::loss_multiclass_log_per_pixel<
                            dlib::cont<class_count,7,7,2,2,
                            alevel4t<alevel3t<alevel2t<alevel1t<
                            alevel1<alevel2<alevel3<alevel4<
                            dlib::max_pool<3,3,2,2,dlib::relu<dlib::affine<dlib::con<64,7,7,2,2,
                            dlib::input<dlib::matrix<dlib::rgb_pixel>>
                            >>>>>>>>>>>>>>;

// ----------------------------------------------------------------------------------------

#endif // DLIB_DNn_SEMANTIC_SEGMENTATION_EX_H_
\ No newline at end of file
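The comment above about "residual downsampling ... and then residual upsampling" can be checked with a little stride arithmetic. The snippet below (an illustrative aside, not part of the commit) just multiplies the strides out: the encoder shrinks the resolution by a factor of 32 and the decoder grows it back by the same factor, which is why the example programs only need to crop a small border to match the network output to the input size.

    // Illustrative aside: the total striding of the encoder and decoder halves of
    // net_type / anet_type defined above cancels out.
    constexpr int encoder_stride =
        2               // dlib::con<64,7,7,2,2>  (initial 7x7 convolution, stride 2)
      * 2               // dlib::max_pool<3,3,2,2>
      * 2 * 2 * 2;      // res_down blocks inside level3, level2 and level1
    constexpr int decoder_stride =
        2 * 2 * 2 * 2   // res_up blocks inside level1t ... level4t
      * 2;              // final dlib::cont<class_count,7,7,2,2> (transposed convolution, stride 2)
    static_assert(encoder_stride == decoder_stride,
                  "both factors are 32, so the output is roughly input-sized");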
examples/dnn_semantic_segmentation_train_ex.cpp (new file, 0 → 100644)

// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    This example shows how to train a semantic segmentation net using the PASCAL VOC2012
    dataset.  For an introduction to what segmentation is, see the accompanying header file
    dnn_semantic_segmentation_ex.h.

    Instructions how to run the example:
    1. Download the PASCAL VOC2012 data, and untar it somewhere.
       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    2. Build the dnn_semantic_segmentation_train_ex example program.
    3. Run:
       ./dnn_semantic_segmentation_train_ex /path/to/VOC2012
    4. Wait while the network is being trained.
    5. Build the dnn_semantic_segmentation_ex example program.
    6. Run:
       ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images

    It would be a good idea to become familiar with dlib's DNN tooling before reading this
    example.  So you should read dnn_introduction_ex.cpp and dnn_introduction2_ex.cpp
    before reading this example program.
*/

#include "dnn_semantic_segmentation_ex.h"

#include <iostream>
#include <dlib/data_io.h>
#include <dlib/image_transforms.h>
#include <dlib/dir_nav.h>
#include <iterator>
#include <thread>

using namespace std;
using namespace dlib;

// A single training sample. A mini-batch comprises many of these.
struct training_sample
{
    matrix<rgb_pixel> input_image;
    matrix<uint16_t> label_image; // The ground-truth label of each pixel.
};

// ----------------------------------------------------------------------------------------

rectangle make_random_cropping_rect_resnet(
    const matrix<rgb_pixel>& img,
    dlib::rand& rnd
)
{
    // figure out what rectangle we want to crop from the image
    double mins = 0.466666666, maxs = 0.875;
    auto scale = mins + rnd.get_random_double()*(maxs - mins);
    auto size = scale*std::min(img.nr(), img.nc());
    rectangle rect(size, size);
    // randomly shift the box around
    point offset(rnd.get_random_32bit_number()%(img.nc()-rect.width()),
                 rnd.get_random_32bit_number()%(img.nr()-rect.height()));
    return move_rect(rect, offset);
}

// ----------------------------------------------------------------------------------------

void randomly_crop_image(
    const matrix<rgb_pixel>& input_image,
    const matrix<uint16_t>& label_image,
    training_sample& crop,
    dlib::rand& rnd
)
{
    const auto rect = make_random_cropping_rect_resnet(input_image, rnd);

    const chip_details chip_details(rect, chip_dims(227, 227));

    // Crop the input image.
    extract_image_chip(input_image, chip_details, crop.input_image, interpolate_bilinear());

    // Crop the labels correspondingly. However, note that here bilinear
    // interpolation would make absolutely no sense - you wouldn't say that
    // a bicycle is half-way between an aeroplane and a bird, would you?
    extract_image_chip(label_image, chip_details, crop.label_image, interpolate_nearest_neighbor());

    // Also randomly flip the input image and the labels.
    if (rnd.get_random_double() > 0.5)
    {
        crop.input_image = fliplr(crop.input_image);
        crop.label_image = fliplr(crop.label_image);
    }

    // And then randomly adjust the colors.
    apply_random_color_offset(crop.input_image, rnd);
}

// ----------------------------------------------------------------------------------------

// The names of the input image and the associated RGB label image in the PASCAL VOC 2012
// data set.
struct image_info
{
    string image_filename;
    string label_filename;
};

// Read the list of image files belonging to either the "train", "trainval", or "val" set
// of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_listing(
    const std::string& voc2012_folder,
    const std::string& file = "train" // "train", "trainval", or "val"
)
{
    std::ifstream in(voc2012_folder + "/ImageSets/Segmentation/" + file + ".txt");

    std::vector<image_info> results;

    while (in)
    {
        std::string basename;
        in >> basename;

        if (!basename.empty())
        {
            image_info image_info;
            image_info.image_filename = voc2012_folder + "/JPEGImages/" + basename + ".jpg";
            image_info.label_filename = voc2012_folder + "/SegmentationClass/" + basename + ".png";
            results.push_back(image_info);
        }
    }

    return results;
}

// Read the list of image files belong to the "train" set of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_train_listing(
    const std::string& voc2012_folder
)
{
    return get_pascal_voc2012_listing(voc2012_folder, "train");
}

// Read the list of image files belong to the "val" set of the PASCAL VOC2012 data.
std::vector<image_info> get_pascal_voc2012_val_listing(
    const std::string& voc2012_folder
)
{
    return get_pascal_voc2012_listing(voc2012_folder, "val");
}

// ----------------------------------------------------------------------------------------

// The PASCAL VOC2012 dataset contains 20 ground-truth classes + background.  Each class
// is represented using an RGB color value.  We associate each class also to an index in the
// range [0, 20], used internally by the network.  To convert the ground-truth data to
// something that the network can efficiently digest, we need to be able to map the RGB
// values to the corresponding indexes.

// Given an RGB representation, find the corresponding PASCAL VOC2012 class
// (e.g., 'dog').
const Voc2012class& find_voc2012_class(const dlib::rgb_pixel& rgb_label)
{
    return find_voc2012_class(
        [&rgb_label](const Voc2012class& voc2012class)
        {
            return rgb_label == voc2012class.rgb_label;
        }
    );
}

// Convert an RGB class label to an index in the range [0, 20].
inline uint16_t rgb_label_to_index_label(const dlib::rgb_pixel& rgb_label)
{
    return find_voc2012_class(rgb_label).index;
}

// Convert an image containing RGB class labels to a corresponding
// image containing indexes in the range [0, 20].
void rgb_label_image_to_index_label_image(
    const dlib::matrix<dlib::rgb_pixel>& rgb_label_image,
    dlib::matrix<uint16_t>& index_label_image
)
{
    const long nr = rgb_label_image.nr();
    const long nc = rgb_label_image.nc();

    index_label_image.set_size(nr, nc);

    for (long r = 0; r < nr; ++r)
    {
        for (long c = 0; c < nc; ++c)
        {
            index_label_image(r, c) = rgb_label_to_index_label(rgb_label_image(r, c));
        }
    }
}

// ----------------------------------------------------------------------------------------

// Calculate the per-pixel accuracy on a dataset whose file names are supplied as a parameter.
double calculate_accuracy(anet_type& anet, const std::vector<image_info>& dataset)
{
    int num_right = 0;
    int num_wrong = 0;

    matrix<rgb_pixel> input_image;
    matrix<rgb_pixel> rgb_label_image;
    matrix<uint16_t> index_label_image;
    matrix<uint16_t> net_output;

    for (const auto& image_info : dataset)
    {
        // Load the input image.
        load_image(input_image, image_info.image_filename);

        // Load the ground-truth (RGB) labels.
        load_image(rgb_label_image, image_info.label_filename);

        // Create predictions for each pixel. At this point, the type of each prediction
        // is an index (a value between 0 and 20). Note that the net may return an image
        // that is not exactly the same size as the input.
        const matrix<uint16_t> temp = anet(input_image);

        // Convert the indexes to RGB values.
        rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);

        // Crop the net output to be exactly the same size as the input.
        const chip_details chip_details(
            centered_rect(temp.nc()/2, temp.nr()/2, input_image.nc(), input_image.nr()),
            chip_dims(input_image.nr(), input_image.nc())
        );
        extract_image_chip(temp, chip_details, net_output, interpolate_nearest_neighbor());

        const long nr = index_label_image.nr();
        const long nc = index_label_image.nc();

        // Compare the predicted values to the ground-truth values.
        for (long r = 0; r < nr; ++r)
        {
            for (long c = 0; c < nc; ++c)
            {
                const uint16_t truth = index_label_image(r, c);
                if (truth != dlib::loss_multiclass_log_per_pixel_::label_to_ignore)
                {
                    const uint16_t prediction = net_output(r, c);
                    if (prediction == truth)
                    {
                        ++num_right;
                    }
                    else
                    {
                        ++num_wrong;
                    }
                }
            }
        }
    }

    // Return the accuracy estimate.
    return num_right / static_cast<double>(num_right + num_wrong);
}

// ----------------------------------------------------------------------------------------

int main(int argc, char** argv) try
{
    if (argc != 2)
    {
        cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl;
        cout << endl;
        cout << "You call this program like this: " << endl;
        cout << "./dnn_semantic_segmentation_train_ex /path/to/VOC2012" << endl;
        return 1;
    }

    cout << "\nSCANNING PASCAL VOC2012 DATASET\n" << endl;

    const auto listing = get_pascal_voc2012_train_listing(argv[1]);
    cout << "images in dataset: " << listing.size() << endl;
    if (listing.size() == 0)
    {
        cout << "Didn't find the VOC2012 dataset. " << endl;
        return 1;
    }

    const double initial_learning_rate = 0.1;
    const double weight_decay = 0.0001;
    const double momentum = 0.9;

    net_type net;
    dnn_trainer<net_type> trainer(net, sgd(weight_decay, momentum));
    trainer.be_verbose();
    trainer.set_learning_rate(initial_learning_rate);
    trainer.set_synchronization_file("pascal_voc2012_trainer_state_file.dat", std::chrono::minutes(10));
    // This threshold is probably excessively large.
    trainer.set_iterations_without_progress_threshold(5000);
    // Since the progress threshold is so large might as well set the batch normalization
    // stats window to something big too.
    set_all_bn_running_stats_window_sizes(net, 1000);

    // Output training parameters.
    cout << endl << trainer << endl;

    std::vector<matrix<rgb_pixel>> samples;
    std::vector<matrix<uint16_t>> labels;

    // Start a bunch of threads that read images from disk and pull out random crops.  It's
    // important to be sure to feed the GPU fast enough to keep it busy.  Using multiple
    // thread for this kind of data preparation helps us do that.  Each thread puts the
    // crops into the data queue.
    dlib::pipe<training_sample> data(200);
    auto f = [&data, &listing](time_t seed)
    {
        dlib::rand rnd(time(0)+seed);
        matrix<rgb_pixel> input_image;
        matrix<rgb_pixel> rgb_label_image;
        matrix<uint16_t> index_label_image;
        training_sample temp;
        while (data.is_enabled())
        {
            // Pick a random input image.
            const image_info& image_info = listing[rnd.get_random_32bit_number()%listing.size()];

            // Load the input image.
            load_image(input_image, image_info.image_filename);

            // Load the ground-truth (RGB) labels.
            load_image(rgb_label_image, image_info.label_filename);

            // Convert the indexes to RGB values.
            rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);

            // Randomly pick a part of the image.
            randomly_crop_image(input_image, index_label_image, temp, rnd);

            // Push the result to be used by the trainer.
            data.enqueue(temp);
        }
    };
    std::thread data_loader1([f](){ f(1); });
    std::thread data_loader2([f](){ f(2); });
    std::thread data_loader3([f](){ f(3); });
    std::thread data_loader4([f](){ f(4); });

    // The main training loop.  Keep making mini-batches and giving them to the trainer.
    // We will run until the learning rate has dropped by a factor of 1e-4.
    while (trainer.get_learning_rate() >= 1e-4)
    {
        samples.clear();
        labels.clear();

        // make a 30-image mini-batch
        training_sample temp;
        while (samples.size() < 30)
        {
            data.dequeue(temp);

            samples.push_back(std::move(temp.input_image));
            labels.push_back(std::move(temp.label_image));
        }

        trainer.train_one_step(samples, labels);
    }

    // Training done, tell threads to stop and make sure to wait for them to finish before
    // moving on.
    data.disable();
    data_loader1.join();
    data_loader2.join();
    data_loader3.join();
    data_loader4.join();

    // also wait for threaded processing to stop in the trainer.
    trainer.get_net();

    net.clean();
    cout << "saving network" << endl;
    serialize("voc2012net.dnn") << net;

    // Make a copy of the network to use it for inference.
    anet_type anet = net;

    cout << "Testing the network..." << endl;

    // Find the accuracy of the newly trained network on both the training and the validation sets.
    cout << "train accuracy : " << calculate_accuracy(anet, get_pascal_voc2012_train_listing(argv[1])) << endl;
    cout << "val accuracy : " << calculate_accuracy(anet, get_pascal_voc2012_val_listing(argv[1])) << endl;
}
catch (std::exception& e)
{
    cout << e.what() << endl;
}
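The data-loading comment above ("Start a bunch of threads that read images from disk and pull out random crops ... Each thread puts the crops into the data queue") boils down to dlib::pipe acting as a bounded, thread-safe queue between producer threads and the training loop. Below is a stripped-down, self-contained sketch of that pattern (illustrative only; it passes ints around instead of training_samples):

    #include <dlib/pipe.h>
    #include <iostream>
    #include <thread>

    int main()
    {
        dlib::pipe<int> data(200);                 // bounded queue; enqueue blocks when it is full

        std::thread producer([&data]() {
            int sample = 0;
            while (data.is_enabled())
            {
                ++sample;                          // the real program loads an image and crops it here
                data.enqueue(sample);              // hands the sample over to the consumer
            }
        });

        int sample = 0;
        for (int i = 0; i < 1000; ++i)             // the training loop plays the consumer role
            data.dequeue(sample);

        std::cout << "dequeued 1000 samples" << std::endl;

        data.disable();                            // unblocks and stops the producer
        producer.join();
        return 0;
    }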