Commit 19856946
authored Jul 04, 2017 by Davis King
Added tt::resize_bilinear() and tt::resize_bilinear_gradient().
parent 9cd06ce3
Showing 7 changed files with 371 additions and 5 deletions
dlib/dnn/cpu_dlib.cpp      +101  -0
dlib/dnn/cpu_dlib.h        +12   -0
dlib/dnn/cuda_dlib.cu      +110  -0
dlib/dnn/cuda_dlib.h       +21   -5
dlib/dnn/tensor_tools.cpp  +26   -0
dlib/dnn/tensor_tools.h    +35   -0
dlib/test/dnn.cpp          +66   -0
dlib/dnn/cpu_dlib.cpp
...
@@ -1401,6 +1401,107 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
    void resize_bilinear (
        tensor& dest,
        const tensor& src
    )
    {
        DLIB_CASSERT(is_same_object(dest, src)==false);
        DLIB_CASSERT(dest.num_samples() == src.num_samples());
        DLIB_CASSERT(dest.k() == src.k());

        if (dest.size() == 0 || src.size() == 0)
            return;

        const float* s = src.host();
        float* d = dest.host();

        const float x_scale = (src.nc()-1)/(float)std::max<long>((dest.nc()-1),1);
        const float y_scale = (src.nr()-1)/(float)std::max<long>((dest.nr()-1),1);
        for (long samp = 0; samp < dest.num_samples(); ++samp)
        {
            for (long k = 0; k < dest.k(); ++k)
            {
                for (long r = 0; r < dest.nr(); ++r)
                {
                    const float y = r*y_scale;
                    const long top = static_cast<long>(std::floor(y));
                    const long bottom = std::min(top+1, src.nr()-1);
                    const float tb_frac = y - top;
                    for (long c = 0; c < dest.nc(); ++c)
                    {
                        const float x = c*x_scale;
                        const long left = static_cast<long>(std::floor(x));
                        const long right = std::min(left+1, src.nc()-1);
                        const float lr_frac = x - left;

                        float tl = s[top*src.nc()+left];
                        float tr = s[top*src.nc()+right];
                        float bl = s[bottom*src.nc()+left];
                        float br = s[bottom*src.nc()+right];

                        float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
                                     tb_frac*((1-lr_frac)*bl + lr_frac*br);

                        d[r*dest.nc()+c] = temp;
                    }
                }

                d += dest.nr()*dest.nc();
                s += src.nr()*src.nc();
            }
        }
    }
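
Note on the weighting above: each destination pixel (r,c) maps back to the real-valued source location (y,x) = (r*y_scale, c*x_scale), and temp is a convex combination of the four surrounding source pixels with weights built from the fractional offsets tb_frac and lr_frac. A minimal standalone sketch, not part of this commit, that reproduces the computation for one output pixel of a hypothetical 2x2 plane upscaled to 3x3:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
    const float s[2][2] = {{0.f, 1.f},     // hypothetical source plane
                           {2.f, 3.f}};
    const long snr = 2, snc = 2;           // source rows/cols
    const long dnr = 3, dnc = 3;           // destination rows/cols
    const float y_scale = (snr-1)/(float)std::max<long>(dnr-1, 1);
    const float x_scale = (snc-1)/(float)std::max<long>(dnc-1, 1);
    // Interpolate the middle output pixel (r=1, c=1), which lands at (0.5, 0.5):
    const float y = 1*y_scale, x = 1*x_scale;
    const long top = (long)std::floor(y), bottom = std::min(top+1, snr-1);
    const long left = (long)std::floor(x), right = std::min(left+1, snc-1);
    const float tb = y-top, lr = x-left;
    const float v = (1-tb)*((1-lr)*s[top][left] + lr*s[top][right])
                  +    tb *((1-lr)*s[bottom][left] + lr*s[bottom][right]);
    std::printf("%f\n", v);   // prints 1.500000, the mean of the four corners
    return 0;
}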
    void resize_bilinear_gradient (
        tensor& grad,
        const tensor& gradient_input
    )
    {
        DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
        DLIB_CASSERT(gradient_input.num_samples() == grad.num_samples());
        DLIB_CASSERT(gradient_input.k() == grad.k());

        if (gradient_input.size() == 0 || grad.size() == 0)
            return;

        const float* gi = gradient_input.host();
        float* g = grad.host();

        const float x_scale = (grad.nc()-1)/(float)std::max<long>((gradient_input.nc()-1),1);
        const float y_scale = (grad.nr()-1)/(float)std::max<long>((gradient_input.nr()-1),1);
        for (long samp = 0; samp < gradient_input.num_samples(); ++samp)
        {
            for (long k = 0; k < gradient_input.k(); ++k)
            {
                for (long r = 0; r < gradient_input.nr(); ++r)
                {
                    const float y = r*y_scale;
                    const long top = static_cast<long>(std::floor(y));
                    const long bottom = std::min(top+1, grad.nr()-1);
                    const float tb_frac = y - top;
                    for (long c = 0; c < gradient_input.nc(); ++c)
                    {
                        const float x = c*x_scale;
                        const long left = static_cast<long>(std::floor(x));
                        const long right = std::min(left+1, grad.nc()-1);
                        const float lr_frac = x - left;

                        const float tmp = gi[r*gradient_input.nc()+c];

                        g[top*grad.nc()+left]     += tmp*(1-tb_frac)*(1-lr_frac);
                        g[top*grad.nc()+right]    += tmp*(1-tb_frac)*(lr_frac);
                        g[bottom*grad.nc()+left]  += tmp*(tb_frac)*(1-lr_frac);
                        g[bottom*grad.nc()+right] += tmp*(tb_frac)*(lr_frac);
                    }
                }

                g += grad.nr()*grad.nc();
                gi += gradient_input.nr()*gradient_input.nc();
            }
        }
    }
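
Why the four += lines: resize_bilinear() computes dest = A*src for a fixed sparse linear map A (each output row of A holds the four bilinear weights), so the gradient of f(src) = dot(gradient_input, dest) is A^T*gradient_input, and applying A^T means scattering each incoming value back onto the same four source cells with the same weights. A hedged sketch, not part of this commit, that checks the resulting adjoint identity dot(dest, gi) == dot(src, g) with the tt:: wrappers added later in the commit (shapes are illustrative; g must start at zero because the function adds into it):

#include <dlib/dnn.h>
#include <iostream>

int main()
{
    using namespace dlib;
    resizable_tensor src(1,1,4,4), dest(1,1,7,7), gi, g;
    tt::tensor_rand rnd;
    rnd.fill_uniform(src);
    gi.copy_size(dest);
    rnd.fill_uniform(gi);
    g.copy_size(src);
    g = 0;                                // the gradient is *added* into g
    tt::resize_bilinear(dest, src);       // dest = A*src
    tt::resize_bilinear_gradient(g, gi);  // g += A^T*gi
    // The two inner products agree up to float rounding:
    std::cout << dot(dest, gi) << " == " << dot(src, g) << std::endl;
    return 0;
}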
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
...
dlib/dnn/cpu_dlib.h
...
@@ -296,6 +296,18 @@ namespace dlib
        const tensor& gradient_input
    );

// ----------------------------------------------------------------------------------------

    void resize_bilinear (
        tensor& dest,
        const tensor& src
    );

    void resize_bilinear_gradient (
        tensor& grad,
        const tensor& gradient_input
    );

// -----------------------------------------------------------------------------------

    class pooling
...
dlib/dnn/cuda_dlib.cu
...
@@ -1137,6 +1137,116 @@ namespace dlib
param.device(), params_grad.device());
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_resize_bilinear(size_t dsize, size_t dchan_size, size_t dnc, float* d,
size_t schan_size, int snr, int snc, const float* s,
const float x_scale, const float y_scale)
{
for(auto i : grid_stride_range(0, dsize))
{
const int idx = i%dchan_size;
const int channel = i/dchan_size;
const int sidx = channel*schan_size;
const int r = idx/dnc;
const int c = idx%dnc;
const float y = r*y_scale;
const int top = static_cast<int>(::floor(y));
const int bottom = ::min(top+1, snr-1);
const float tb_frac = y - top;
const float x = c*x_scale;
const int left = static_cast<int>(::floor(x));
const int right = ::min(left+1, snc-1);
const float lr_frac = x - left;
float tl = s[sidx+top*snc+left];
float tr = s[sidx+top*snc+right];
float bl = s[sidx+bottom*snc+left];
float br = s[sidx+bottom*snc+right];
float temp = (1-tb_frac)*((1-lr_frac)*tl + lr_frac*tr) +
tb_frac*((1-lr_frac)*bl + lr_frac*br);
d[i] = temp;
}
}
void resize_bilinear (
tensor& dest,
const tensor& src
)
{
DLIB_CASSERT(is_same_object(dest, src)==false);
DLIB_CASSERT(dest.num_samples() == src.num_samples());
DLIB_CASSERT(dest.k() == src.k());
if (dest.size() == 0 || src.size() == 0)
return;
const float x_scale = (src.nc()-1)/(float)std::max<long>((dest.nc()-1),1);
const float y_scale = (src.nr()-1)/(float)std::max<long>((dest.nr()-1),1);
launch_kernel(_cuda_resize_bilinear,
dest.size(), dest.nr()*dest.nc(), dest.nc(), dest.device(),
src.nr()*src.nc(), src.nr(), src.nc(), src.device(),
x_scale, y_scale);
}
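
The kernel above flattens the whole destination tensor into a single index range: i%dchan_size is the offset inside one nr-by-nc plane and i/dchan_size is the plane (sample and channel combined), so one launch covers every sample and channel. grid_stride_range is dlib's grid-stride-loop helper; conceptually it behaves like the toy kernel below (a hedged sketch, not from this commit):

#include <cstdio>

__global__ void double_all(size_t n, float* out)
{
    // Each thread starts at its global id and hops by the total thread count,
    // so all n elements are visited exactly once regardless of launch shape.
    const size_t stride = (size_t)blockDim.x * gridDim.x;
    for (size_t i = (size_t)blockIdx.x*blockDim.x + threadIdx.x; i < n; i += stride)
        out[i] *= 2;
}

int main()
{
    const size_t n = 1000;
    float* d;
    cudaMalloc(&d, n*sizeof(float));
    cudaMemset(d, 0, n*sizeof(float));
    double_all<<<4, 256>>>(n, d);   // the launch shape need not match n
    cudaDeviceSynchronize();
    cudaFree(d);
    return 0;
}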
__global__ void _cuda_resize_bilinear_gradient(size_t dsize, size_t dchan_size, size_t dnc, const float* d,
size_t schan_size, int snr, int snc, float* s,
const float x_scale, const float y_scale)
{
for(auto i : grid_stride_range(0, dsize))
{
const float tmp = d[i];
const int idx = i%dchan_size;
const int channel = i/dchan_size;
const int sidx = channel*schan_size;
const int r = idx/dnc;
const int c = idx%dnc;
const float y = r*y_scale;
const int top = static_cast<int>(::floor(y));
const int bottom = ::min(top+1, snr-1);
const float tb_frac = y - top;
const float x = c*x_scale;
const int left = static_cast<int>(::floor(x));
const int right = ::min(left+1, snc-1);
const float lr_frac = x - left;
atomicAdd(s+sidx+top*snc+left, tmp*(1-tb_frac)*(1-lr_frac));
atomicAdd(s+sidx+top*snc+right, tmp*(1-tb_frac)*(lr_frac));
atomicAdd(s+sidx+bottom*snc+left, tmp*(tb_frac)*(1-lr_frac));
atomicAdd(s+sidx+bottom*snc+right, tmp*(tb_frac)*(lr_frac));
}
}
void resize_bilinear_gradient (
tensor& grad,
const tensor& gradient_input
)
{
DLIB_CASSERT(is_same_object(grad, gradient_input)==false);
DLIB_CASSERT(gradient_input.num_samples() == grad.num_samples());
DLIB_CASSERT(gradient_input.k() == grad.k());
if (grad.size() == 0 || gradient_input.size() == 0)
return;
const float x_scale = (grad.nc()-1)/(float)std::max<long>((gradient_input.nc()-1),1);
const float y_scale = (grad.nr()-1)/(float)std::max<long>((gradient_input.nr()-1),1);
launch_kernel(_cuda_resize_bilinear_gradient,
gradient_input.size(), gradient_input.nr()*gradient_input.nc(), gradient_input.nc(), gradient_input.device(),
grad.nr()*grad.nc(), grad.nr(), grad.nc(), grad.device(),
x_scale, y_scale);
}
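
Unlike the CPU version, where the serial += statements can never conflict, this kernel may have many threads targeting the same cell of grad at once (when upsampling, neighboring output pixels share corner source pixels), so the scatter uses atomicAdd. A toy illustration of the same accumulate-under-contention pattern, a hedged sketch that is not from this commit:

#include <cstdio>

__global__ void accumulate(const int* bins, size_t n, float* counts)
{
    const size_t stride = (size_t)blockDim.x * gridDim.x;
    for (size_t i = (size_t)blockIdx.x*blockDim.x + threadIdx.x; i < n; i += stride)
        atomicAdd(counts + bins[i], 1.0f);  // a plain "+=" would race here
}

int main()
{
    const int h_bins[8] = {0,1,0,2,0,1,0,2};  // several entries hit bin 0
    int* bins;  float* counts;
    cudaMalloc(&bins, sizeof(h_bins));
    cudaMalloc(&counts, 3*sizeof(float));
    cudaMemcpy(bins, h_bins, sizeof(h_bins), cudaMemcpyHostToDevice);
    cudaMemset(counts, 0, 3*sizeof(float));
    accumulate<<<2, 4>>>(bins, 8, counts);
    float out[3];
    cudaMemcpy(out, counts, sizeof(out), cudaMemcpyDeviceToHost);
    std::printf("%g %g %g\n", out[0], out[1], out[2]);  // prints 4 2 2
    cudaFree(bins); cudaFree(counts);
    return 0;
}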
// ----------------------------------------------------------------------------------------
void copy_tensor(
...
dlib/dnn/cuda_dlib.h
...
@@ -346,13 +346,29 @@ namespace dlib
        tensor& params_grad
    );

// ----------------------------------------------------------------------------------------

    void resize_bilinear (
        tensor& dest,
        const tensor& src
    );

    void resize_bilinear_gradient (
        tensor& grad,
        const tensor& gradient_input
    );

// ----------------------------------------------------------------------------------------

    void copy_tensor(
        tensor& dest,
        size_t dest_k_offset,
        const tensor& src,
        size_t src_k_offset,
        size_t count_k
    );
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
...
dlib/dnn/tensor_tools.cpp
...
@@ -838,6 +838,32 @@ namespace dlib { namespace tt
#endif
}
// ----------------------------------------------------------------------------------------
    void resize_bilinear (
        tensor& dest,
        const tensor& src
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::resize_bilinear(dest, src);
#else
        cpu::resize_bilinear(dest, src);
#endif
    }

    void resize_bilinear_gradient (
        tensor& grad,
        const tensor& gradient_input
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::resize_bilinear_gradient(grad, gradient_input);
#else
        cpu::resize_bilinear_gradient(grad, gradient_input);
#endif
    }
// ------------------------------------------------------------------------------------
    void copy_tensor (
...
dlib/dnn/tensor_tools.h
...
@@ -1350,6 +1350,41 @@ namespace dlib { namespace tt
is_same_object(grad, gradient_input)==true
!*/
// ----------------------------------------------------------------------------------------
    void resize_bilinear (
        tensor& dest,
        const tensor& src
    );
/*!
requires
- is_same_object(dest, src)==false
- dest.num_samples() == src.num_samples()
- dest.k() == src.k()
ensures
- for all valid i,k: image_plane(dest,i,k) is a copy of image_plane(src,i,k)
that has been bilinearly interpolated to fit into the shape of
image_plane(dest,i,k).
!*/
    void resize_bilinear_gradient (
        tensor& grad,
        const tensor& gradient_input
    );
/*!
requires
- is_same_object(grad, gradient_input)==false
- gradient_input.num_samples() == grad.num_samples()
- gradient_input.k() == grad.k()
ensures
- Suppose that DEST is the output of resize_bilinear(DEST,SRC) for some SRC
tensor, let f(SRC) == dot(gradient_input,DEST). Then this function computes
the gradient of f() with respect to SRC and adds it to grad. It should be
noted that we don't need to know the contents of DEST to compute this
gradient. All that matters is that gradient_input have the same dimensions
as DEST.
!*/
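
Putting the two specs together, the intended call pattern is: allocate dest with the desired nr() and nc() (num_samples() and k() must match src), call resize_bilinear() to interpolate, then hand a tensor shaped like dest to resize_bilinear_gradient(), remembering that the gradient is added to grad rather than assigned. A hedged usage sketch (shapes are illustrative, not from this commit):

#include <dlib/dnn.h>

int main()
{
    using namespace dlib;
    resizable_tensor feats(1, 3, 8, 8);    // one sample, three 8x8 planes
    tt::tensor_rand rnd;
    rnd.fill_uniform(feats);

    resizable_tensor up(1, 3, 16, 16);     // num_samples() and k() match feats
    tt::resize_bilinear(up, feats);        // interpolate each plane to 16x16

    resizable_tensor gi, grad;
    gi.copy_size(up);                      // gradient_input shaped like dest
    rnd.fill_uniform(gi);
    grad.copy_size(feats);
    grad = 0;                              // the gradient is added, not assigned
    tt::resize_bilinear_gradient(grad, gi);
    return 0;
}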
// ----------------------------------------------------------------------------------------
    class multi_device_tensor_averager
...
dlib/test/dnn.cpp
...
@@ -2303,6 +2303,69 @@ namespace
"Number of correctly classified elements = "
<<
num_correct
<<
", required = "
<<
num_correct_required
);
}
// ----------------------------------------------------------------------------------------
    void test_tensor_resize_bilienar(long samps, long k, long nr, long nc, long onr, long onc)
    {
        resizable_tensor img(samps,k,nr,nc);
        resizable_tensor out(samps,k,onr,onc);
        resizable_tensor out2(samps,k,onr,onc);

        dlib::rand rnd;

        for (int iter = 0; iter < 10; ++iter)
        {
            print_spinner();

            const size_t idx = rnd.get_random_64bit_number()%img.size();

            img = 1;
            img.host()[idx] = 2;
            cpu::resize_bilinear(out, img);
#ifdef DLIB_USE_CUDA
            cuda::resize_bilinear(out2, img);
            DLIB_CASSERT(max(abs(mat(out)-mat(out2))) < 1e-5);
#endif

            resizable_tensor gradient_input;
            gradient_input.copy_size(out);
            tt::tensor_rand rnd;
            rnd.fill_uniform(gradient_input);

            const float h = 1e-2;

            img.host()[idx] = 2;
            cpu::resize_bilinear(out, img);
            float f1 = dot(out, gradient_input);

            img.host()[idx] = 2+h;
            cpu::resize_bilinear(out, img);
            float f2 = dot(out, gradient_input);

            const float numerical_grad = (f2-f1)/h;
            dlog << LINFO << "numerical grad: " << numerical_grad;

            resizable_tensor grad, grad2;
            grad.copy_size(img);
            grad = 0.1;
            grad2.copy_size(img);
            grad2 = 0.1;

            cpu::resize_bilinear_gradient(grad2, gradient_input);
            dlog << LINFO << "analytic grad: " << grad2.host()[idx]-0.1;
            DLIB_CASSERT(std::abs(numerical_grad - grad2.host()[idx]+0.1) < 1e-2,
                std::abs(numerical_grad - grad2.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);

#ifdef DLIB_USE_CUDA
            cuda::resize_bilinear_gradient(grad, gradient_input);
            dlog << LINFO << "analytic grad: " << grad.host()[idx]-0.1;
            DLIB_CASSERT(std::abs(numerical_grad - grad.host()[idx]+0.1) < 1e-2,
                std::abs(numerical_grad - grad.host()[idx]+0.1) << " numerical_grad: " << numerical_grad);
            DLIB_CASSERT(max(abs(mat(grad)-mat(grad2))) < 1e-5);
#endif
        }
    }
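
The test above is a standard forward finite-difference gradient check: with f(img) = dot(resize_bilinear(img), gradient_input), it perturbs one input cell by h = 1e-2 and compares (f2-f1)/h against the matching entry of the analytic gradient, subtracting the 0.1 the grad tensors were pre-filled with to verify that resize_bilinear_gradient() adds rather than overwrites. The same check for any scalar function, as a hedged generic sketch (illustrative helper, not from the commit):

#include <cstdio>
#include <functional>
#include <vector>

// Forward-difference estimate of df/dx[idx]: (f(x + h*e_idx) - f(x)) / h.
float numeric_grad(const std::function<float(const std::vector<float>&)>& f,
                   std::vector<float> x, size_t idx, float h = 1e-2f)
{
    const float f1 = f(x);
    x[idx] += h;              // perturb a single coordinate
    const float f2 = f(x);
    return (f2 - f1) / h;
}

int main()
{
    auto f = [](const std::vector<float>& x) { return x[0]*x[0] + 3*x[1]; };
    // True derivative w.r.t. x[0] at x[0]=2 is 4; forward difference gives ~4.01.
    std::printf("%f\n", numeric_grad(f, {2.f, 5.f}, 0));
    return 0;
}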
// ----------------------------------------------------------------------------------------
    class dnn_tester : public tester
...
@@ -2337,6 +2400,9 @@ namespace
            compare_adam();
            test_copy_tensor_gpu();
#endif

            test_tensor_resize_bilienar(2, 3, 6,6, 11,11);
            test_tensor_resize_bilienar(2, 3, 6,6, 3,4);
            test_tensor_resize_bilienar(2, 3, 5,6, 12,21);

            test_max_pool(1,1,2,3,0,0);
            test_max_pool(3,3,1,1,0,0);
            test_max_pool(3,3,2,2,0,0);
...