Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
d7647e83
Commit
d7647e83
authored
Mar 08, 2018
by
Davis King
Browse files
Options
Browse Files
Download
Plain Diff
merged
parents
306cd1a2
49ec319c
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
448 additions
and
4 deletions
+448
-4
CMakeLists.txt
dlib/CMakeLists.txt
+1
-0
source.cpp
dlib/all/source.cpp
+1
-0
kalman_filter.cpp
dlib/filtering/kalman_filter.cpp
+104
-0
kalman_filter.h
dlib/filtering/kalman_filter.h
+213
-0
kalman_filter_abstract.h
dlib/filtering/kalman_filter_abstract.h
+0
-0
rectangles.cpp
tools/python/src/rectangles.cpp
+129
-4
No files found.
dlib/CMakeLists.txt
View file @
d7647e83
...
...
@@ -213,6 +213,7 @@ if (NOT TARGET dlib)
data_io/image_dataset_metadata.cpp
data_io/mnist.cpp
global_optimization/global_function_search.cpp
filtering/kalman_filter.cpp
test_for_odr_violations.cpp
)
...
...
dlib/all/source.cpp
View file @
d7647e83
...
...
@@ -89,6 +89,7 @@
#include "../data_io/image_dataset_metadata.cpp"
#include "../data_io/mnist.cpp"
#include "../global_optimization/global_function_search.cpp"
#include "../filtering/kalman_filter.cpp"
#define DLIB_ALL_SOURCE_END
...
...
dlib/filtering/kalman_filter.cpp
0 → 100644
View file @
d7647e83
// Copyright (C) 2018 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_KALMAN_FiLTER_CPp_
#define DLIB_KALMAN_FiLTER_CPp_
#include "kalman_filter.h"
#include "../global_optimization.h"
#include "../statistics.h"
namespace
dlib
{
// ----------------------------------------------------------------------------------------
momentum_filter
find_optimal_momentum_filter
(
const
std
::
vector
<
std
::
vector
<
double
>>&
sequences
,
const
double
smoothness
)
{
DLIB_CASSERT
(
sequences
.
size
()
!=
0
);
for
(
auto
&
vals
:
sequences
)
DLIB_CASSERT
(
vals
.
size
()
>
4
);
DLIB_CASSERT
(
smoothness
>=
0
);
// define the objective function we optimize to find the best filter
auto
obj
=
[
&
](
double
measurement_noise
,
double
typical_acceleration
,
double
max_measurement_deviation
)
{
running_stats
<
double
>
rs
;
for
(
auto
&
vals
:
sequences
)
{
momentum_filter
filt
(
measurement_noise
,
typical_acceleration
,
max_measurement_deviation
);
double
prev_filt
=
0
;
for
(
size_t
i
=
0
;
i
<
vals
.
size
();
++
i
)
{
// we care about smoothness and fitting the data.
if
(
i
>
0
)
{
// the filter should fit the data
rs
.
add
(
std
::
abs
(
vals
[
i
]
-
filt
.
get_predicted_next_position
()));
}
double
next_filt
=
filt
(
vals
[
i
]);
if
(
i
>
0
)
{
// the filter should also output a smooth trajectory
rs
.
add
(
smoothness
*
std
::
abs
(
next_filt
-
prev_filt
));
}
prev_filt
=
next_filt
;
}
}
return
rs
.
mean
();
};
running_stats
<
double
>
avgdiff
;
for
(
auto
&
vals
:
sequences
)
{
for
(
size_t
i
=
1
;
i
<
vals
.
size
();
++
i
)
avgdiff
.
add
(
vals
[
i
]
-
vals
[
i
-
1
]);
}
const
double
scale
=
avgdiff
.
stddev
();
function_evaluation
opt
=
find_min_global
(
obj
,
{
scale
*
0.01
,
scale
*
0.0001
,
0.00001
},
{
scale
*
10
,
scale
*
10
,
10
},
max_function_calls
(
400
));
momentum_filter
filt
(
opt
.
x
(
0
),
opt
.
x
(
1
),
opt
.
x
(
2
));
return
filt
;
}
// ----------------------------------------------------------------------------------------
momentum_filter
find_optimal_momentum_filter
(
const
std
::
vector
<
double
>&
sequence
,
const
double
smoothness
)
{
return
find_optimal_momentum_filter
({
1
,
sequence
},
smoothness
);
}
// ----------------------------------------------------------------------------------------
rect_filter
find_optimal_rect_filter
(
const
std
::
vector
<
rectangle
>&
rects
,
const
double
smoothness
)
{
DLIB_CASSERT
(
rects
.
size
()
>
4
);
DLIB_CASSERT
(
smoothness
>=
0
);
std
::
vector
<
std
::
vector
<
double
>>
vals
(
4
);
for
(
auto
&
r
:
rects
)
{
vals
[
0
].
push_back
(
r
.
left
());
vals
[
1
].
push_back
(
r
.
top
());
vals
[
2
].
push_back
(
r
.
right
());
vals
[
3
].
push_back
(
r
.
bottom
());
}
return
rect_filter
(
find_optimal_momentum_filter
(
vals
,
smoothness
));
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_KALMAN_FiLTER_CPp_
dlib/filtering/kalman_filter.h
View file @
d7647e83
...
...
@@ -5,6 +5,7 @@
#include "kalman_filter_abstract.h"
#include "../matrix.h"
#include "../geometry.h"
namespace
dlib
{
...
...
@@ -161,6 +162,218 @@ namespace dlib
};
// ----------------------------------------------------------------------------------------
class
momentum_filter
{
public
:
momentum_filter
(
double
meas_noise
,
double
acc
,
double
max_meas_dev
)
:
measurement_noise
(
meas_noise
),
typical_acceleration
(
acc
),
max_measurement_deviation
(
max_meas_dev
)
{
DLIB_CASSERT
(
meas_noise
>=
0
);
DLIB_CASSERT
(
acc
>=
0
);
DLIB_CASSERT
(
max_meas_dev
>=
0
);
kal
.
set_observation_model
({
1
,
0
});
kal
.
set_transition_model
(
{
1
,
1
,
0
,
1
});
kal
.
set_process_noise
({
0
,
0
,
0
,
typical_acceleration
*
typical_acceleration
});
kal
.
set_measurement_noise
({
measurement_noise
*
measurement_noise
});
}
momentum_filter
()
=
default
;
double
get_measurement_noise
(
)
const
{
return
measurement_noise
;
}
double
get_typical_acceleration
(
)
const
{
return
typical_acceleration
;
}
double
get_max_measurement_deviation
(
)
const
{
return
max_measurement_deviation
;
}
void
reset
()
{
*
this
=
momentum_filter
(
measurement_noise
,
typical_acceleration
,
max_measurement_deviation
);
}
double
get_predicted_next_position
(
)
const
{
return
kal
.
get_predicted_next_state
()(
0
);
}
double
operator
()(
const
double
measured_position
)
{
auto
x
=
kal
.
get_predicted_next_state
();
const
auto
max_deviation
=
max_measurement_deviation
*
measurement_noise
;
// Check if measured_position has suddenly jumped in value by a whole lot. This
// could happen if the velocity term experiences a much larger than normal
// acceleration, e.g. because the underlying object is doing a maneuver. If
// this happens then we clamp the state so that the predicted next value is no
// more than max_deviation away from measured_position at all times.
if
(
x
(
0
)
>
measured_position
+
max_deviation
)
{
x
(
0
)
=
measured_position
+
max_deviation
;
kal
.
set_state
(
x
);
}
else
if
(
x
(
0
)
<
measured_position
-
max_deviation
)
{
x
(
0
)
=
measured_position
-
max_deviation
;
kal
.
set_state
(
x
);
}
kal
.
update
({
measured_position
});
return
kal
.
get_current_state
()(
0
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
momentum_filter
&
item
)
{
out
<<
"measurement_noise: "
<<
item
.
measurement_noise
<<
"
\n
"
;
out
<<
"typical_acceleration: "
<<
item
.
typical_acceleration
<<
"
\n
"
;
out
<<
"max_measurement_deviation: "
<<
item
.
max_measurement_deviation
;
return
out
;
}
friend
void
serialize
(
const
momentum_filter
&
item
,
std
::
ostream
&
out
)
{
int
version
=
15
;
serialize
(
version
,
out
);
serialize
(
item
.
measurement_noise
,
out
);
serialize
(
item
.
typical_acceleration
,
out
);
serialize
(
item
.
max_measurement_deviation
,
out
);
serialize
(
item
.
kal
,
out
);
}
friend
void
deserialize
(
momentum_filter
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
15
)
throw
serialization_error
(
"Unexpected version found while deserializing momentum_filter."
);
deserialize
(
item
.
measurement_noise
,
in
);
deserialize
(
item
.
typical_acceleration
,
in
);
deserialize
(
item
.
max_measurement_deviation
,
in
);
deserialize
(
item
.
kal
,
in
);
}
private
:
double
measurement_noise
=
2
;
double
typical_acceleration
=
0
.
1
;
double
max_measurement_deviation
=
3
;
// nominally number of standard deviations
kalman_filter
<
2
,
1
>
kal
;
};
// ----------------------------------------------------------------------------------------
momentum_filter
find_optimal_momentum_filter
(
const
std
::
vector
<
std
::
vector
<
double
>>&
sequences
,
const
double
smoothness
=
1
);
// ----------------------------------------------------------------------------------------
momentum_filter
find_optimal_momentum_filter
(
const
std
::
vector
<
double
>&
sequence
,
const
double
smoothness
=
1
);
// ----------------------------------------------------------------------------------------
class
rect_filter
{
public
:
rect_filter
()
=
default
;
rect_filter
(
double
meas_noise
,
double
acc
,
double
max_meas_dev
)
:
rect_filter
(
momentum_filter
(
meas_noise
,
acc
,
max_meas_dev
))
{}
rect_filter
(
const
momentum_filter
&
filt
)
:
left
(
filt
),
top
(
filt
),
right
(
filt
),
bottom
(
filt
)
{
}
drectangle
operator
()(
const
drectangle
&
r
)
{
return
drectangle
(
left
(
r
.
left
()),
top
(
r
.
top
()),
right
(
r
.
right
()),
bottom
(
r
.
bottom
()));
}
drectangle
operator
()(
const
rectangle
&
r
)
{
return
drectangle
(
left
(
r
.
left
()),
top
(
r
.
top
()),
right
(
r
.
right
()),
bottom
(
r
.
bottom
()));
}
const
momentum_filter
&
get_left
()
const
{
return
left
;
}
momentum_filter
&
get_left
()
{
return
left
;
}
const
momentum_filter
&
get_top
()
const
{
return
top
;
}
momentum_filter
&
get_top
()
{
return
top
;
}
const
momentum_filter
&
get_right
()
const
{
return
right
;
}
momentum_filter
&
get_right
()
{
return
right
;
}
const
momentum_filter
&
get_bottom
()
const
{
return
bottom
;
}
momentum_filter
&
get_bottom
()
{
return
bottom
;
}
friend
void
serialize
(
const
rect_filter
&
item
,
std
::
ostream
&
out
)
{
int
version
=
123
;
serialize
(
version
,
out
);
serialize
(
item
.
left
,
out
);
serialize
(
item
.
top
,
out
);
serialize
(
item
.
right
,
out
);
serialize
(
item
.
bottom
,
out
);
}
friend
void
deserialize
(
rect_filter
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
123
)
throw
dlib
::
serialization_error
(
"Unknown version number found while deserializing rect_filter object."
);
deserialize
(
item
.
left
,
in
);
deserialize
(
item
.
top
,
in
);
deserialize
(
item
.
right
,
in
);
deserialize
(
item
.
bottom
,
in
);
}
private
:
momentum_filter
left
,
top
,
right
,
bottom
;
};
// ----------------------------------------------------------------------------------------
rect_filter
find_optimal_rect_filter
(
const
std
::
vector
<
rectangle
>&
rects
,
const
double
smoothness
=
1
);
// ----------------------------------------------------------------------------------------
}
...
...
dlib/filtering/kalman_filter_abstract.h
View file @
d7647e83
This diff is collapsed.
Click to expand it.
tools/python/src/rectangles.cpp
View file @
d7647e83
...
...
@@ -6,6 +6,7 @@
#include <pybind11/stl_bind.h>
#include "indexing.h"
#include "opaque_types.h"
#include <dlib/filtering.h>
using
namespace
dlib
;
using
namespace
std
;
...
...
@@ -60,14 +61,56 @@ string print_rectangle_str(const rect_type& r)
return
sout
.
str
();
}
template
<
typename
rect_type
>
string
print_rectangle_repr
(
const
rect_type
&
r
)
string
print_rectangle_repr
(
const
rectangle
&
r
)
{
std
::
ostringstream
sout
;
sout
<<
"rectangle("
<<
r
.
left
()
<<
","
<<
r
.
top
()
<<
","
<<
r
.
right
()
<<
","
<<
r
.
bottom
()
<<
")"
;
return
sout
.
str
();
}
string
print_drectangle_repr
(
const
drectangle
&
r
)
{
std
::
ostringstream
sout
;
sout
<<
"drectangle("
<<
r
.
left
()
<<
","
<<
r
.
top
()
<<
","
<<
r
.
right
()
<<
","
<<
r
.
bottom
()
<<
")"
;
return
sout
.
str
();
}
string
print_rect_filter
(
const
rect_filter
&
r
)
{
std
::
ostringstream
sout
;
sout
<<
"rect_filter("
;
sout
<<
"measurement_noise="
<<
r
.
get_left
().
get_measurement_noise
();
sout
<<
", typical_acceleration="
<<
r
.
get_left
().
get_typical_acceleration
();
sout
<<
", max_measurement_deviation="
<<
r
.
get_left
().
get_max_measurement_deviation
();
sout
<<
")"
;
return
sout
.
str
();
}
rectangle
add_point_to_rect
(
const
rectangle
&
r
,
const
point
&
p
)
{
return
r
+
p
;
}
rectangle
add_rect_to_rect
(
const
rectangle
&
r
,
const
rectangle
&
p
)
{
return
r
+
p
;
}
rectangle
&
iadd_point_to_rect
(
rectangle
&
r
,
const
point
&
p
)
{
r
+=
p
;
return
r
;
}
rectangle
&
iadd_rect_to_rect
(
rectangle
&
r
,
const
rectangle
&
p
)
{
r
+=
p
;
return
r
;
}
// ----------------------------------------------------------------------------------------
void
bind_rectangles
(
py
::
module
&
m
)
...
...
@@ -76,6 +119,7 @@ void bind_rectangles(py::module& m)
typedef
rectangle
type
;
py
::
class_
<
type
>
(
m
,
"rectangle"
,
"This object represents a rectangular area of an image."
)
.
def
(
py
::
init
<
long
,
long
,
long
,
long
>
(),
py
::
arg
(
"left"
),
py
::
arg
(
"top"
),
py
::
arg
(
"right"
),
py
::
arg
(
"bottom"
))
.
def
(
py
::
init
())
.
def
(
"area"
,
&::
area
)
.
def
(
"left"
,
&::
left
)
.
def
(
"top"
,
&::
top
)
...
...
@@ -91,7 +135,11 @@ void bind_rectangles(py::module& m)
.
def
(
"contains"
,
&::
contains_rec
<
type
>
,
py
::
arg
(
"rectangle"
))
.
def
(
"intersect"
,
&::
intersect
<
type
>
,
py
::
arg
(
"rectangle"
))
.
def
(
"__str__"
,
&::
print_rectangle_str
<
type
>
)
.
def
(
"__repr__"
,
&::
print_rectangle_repr
<
type
>
)
.
def
(
"__repr__"
,
&::
print_rectangle_repr
)
.
def
(
"__add__"
,
&::
add_point_to_rect
)
.
def
(
"__add__"
,
&::
add_rect_to_rect
)
.
def
(
"__iadd__"
,
&::
iadd_point_to_rect
)
.
def
(
"__iadd__"
,
&::
iadd_rect_to_rect
)
.
def
(
py
::
self
==
py
::
self
)
.
def
(
py
::
self
!=
py
::
self
)
.
def
(
py
::
pickle
(
&
getstate
<
type
>
,
&
setstate
<
type
>
));
...
...
@@ -115,12 +163,89 @@ void bind_rectangles(py::module& m)
.
def
(
"contains"
,
&::
contains_rec
<
type
>
,
py
::
arg
(
"rectangle"
))
.
def
(
"intersect"
,
&::
intersect
<
type
>
,
py
::
arg
(
"rectangle"
))
.
def
(
"__str__"
,
&::
print_rectangle_str
<
type
>
)
.
def
(
"__repr__"
,
&::
print_
rectangle_repr
<
type
>
)
.
def
(
"__repr__"
,
&::
print_
drectangle_repr
)
.
def
(
py
::
self
==
py
::
self
)
.
def
(
py
::
self
!=
py
::
self
)
.
def
(
py
::
pickle
(
&
getstate
<
type
>
,
&
setstate
<
type
>
));
}
{
typedef
rect_filter
type
;
py
::
class_
<
type
>
(
m
,
"rect_filter"
,
R"asdf(
This object is a simple tool for filtering a rectangle that
measures the location of a moving object that has some non-trivial
momentum. Importantly, the measurements are noisy and the object can
experience sudden unpredictable accelerations. To accomplish this
filtering we use a simple Kalman filter with a state transition model of:
position_{i+1} = position_{i} + velocity_{i}
velocity_{i+1} = velocity_{i} + some_unpredictable_acceleration
and a measurement model of:
measured_position_{i} = position_{i} + measurement_noise
Where some_unpredictable_acceleration and measurement_noise are 0 mean Gaussian
noise sources with standard deviations of typical_acceleration and
measurement_noise respectively.
To allow for really sudden and large but infrequent accelerations, at each
step we check if the current measured position deviates from the predicted
filtered position by more than max_measurement_deviation*measurement_noise
and if so we adjust the filter's state to keep it within these bounds.
This allows the moving object to undergo large unmodeled accelerations, far
in excess of what would be suggested by typical_acceleration, without
then experiencing a long lag time where the Kalman filter has to "catches
up" to the new position. )asdf"
)
.
def
(
py
::
init
<
double
,
double
,
double
>
(),
py
::
arg
(
"measurement_noise"
),
py
::
arg
(
"typical_acceleration"
),
py
::
arg
(
"max_measurement_deviation"
))
.
def
(
"measurement_noise"
,
[](
const
rect_filter
&
a
){
return
a
.
get_left
().
get_measurement_noise
();})
.
def
(
"typical_acceleration"
,
[](
const
rect_filter
&
a
){
return
a
.
get_left
().
get_typical_acceleration
();})
.
def
(
"max_measurement_deviation"
,
[](
const
rect_filter
&
a
){
return
a
.
get_left
().
get_max_measurement_deviation
();})
.
def
(
"__call__"
,
[](
rect_filter
&
f
,
const
dlib
::
rectangle
&
r
){
return
rectangle
(
f
(
r
));
},
py
::
arg
(
"rect"
))
.
def
(
"__repr__"
,
print_rect_filter
)
.
def
(
py
::
pickle
(
&
getstate
<
type
>
,
&
setstate
<
type
>
));
}
m
.
def
(
"find_optimal_rect_filter"
,
[](
const
std
::
vector
<
rectangle
>&
rects
,
const
double
smoothness
)
{
return
find_optimal_rect_filter
(
rects
,
smoothness
);
},
py
::
arg
(
"rects"
),
py
::
arg
(
"smoothness"
)
=
1
,
"requires
\n
\
- rects.size() > 4
\n
\
- smoothness >= 0
\n
\
ensures
\n
\
- This function finds the
\"
optimal
\"
settings of a rect_filter based on recorded
\n
\
measurement data stored in rects. Here we assume that rects is a complete
\n
\
track history of some object's measured positions. Essentially, what we do
\n
\
is find the rect_filter that minimizes the following objective function:
\n
\
sum of abs(predicted_location[i] - measured_location[i]) + smoothness*abs(filtered_location[i]-filtered_location[i-1])
\n
\
Where i is a time index.
\n
\
The sum runs over all the data in rects. So what we do is find the
\n
\
filter settings that produce smooth filtered trajectories but also produce
\n
\
filtered outputs that are as close to the measured positions as possible.
\n
\
The larger the value of smoothness the less jittery the filter outputs will
\n
\
be, but they might become biased or laggy if smoothness is set really high. "
/*!
requires
- rects.size() > 4
- smoothness >= 0
ensures
- This function finds the "optimal" settings of a rect_filter based on recorded
measurement data stored in rects. Here we assume that rects is a complete
track history of some object's measured positions. Essentially, what we do
is find the rect_filter that minimizes the following objective function:
sum of abs(predicted_location[i] - measured_location[i]) + smoothness*abs(filtered_location[i]-filtered_location[i-1])
Where i is a time index.
The sum runs over all the data in rects. So what we do is find the
filter settings that produce smooth filtered trajectories but also produce
filtered outputs that are as close to the measured positions as possible.
The larger the value of smoothness the less jittery the filter outputs will
be, but they might become biased or laggy if smoothness is set really high.
!*/
);
{
typedef
std
::
vector
<
rectangle
>
type
;
py
::
bind_vector
<
type
>
(
m
,
"rectangles"
,
"An array of rectangle objects."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment