Commit 67d0ef02 authored by Davis King's avatar Davis King

Added momentum_filter and rect_filter as well as find_optimal_momentum_filter()

and find_optimal_rect_filter()
parent cf5e25a9
......@@ -213,6 +213,7 @@ if (NOT TARGET dlib)
data_io/image_dataset_metadata.cpp
data_io/mnist.cpp
global_optimization/global_function_search.cpp
filtering/kalman_filter.cpp
test_for_odr_violations.cpp
)
......
......@@ -20,6 +20,7 @@
#include "../tokenizer/tokenizer_kernel_1.cpp"
#include "../unicode/unicode.cpp"
#include "../test_for_odr_violations.cpp"
#include "../filtering/kalman_filter.cpp"
......
// Copyright (C) 2018 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_KALMAN_FiLTER_CPp_
#define DLIB_KALMAN_FiLTER_CPp_
#include "kalman_filter.h"
#include "../global_optimization.h"
#include "../statistics.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
momentum_filter find_optimal_momentum_filter (
const std::vector<std::vector<double>>& sequences,
const double smoothness
)
{
DLIB_CASSERT(sequences.size() != 0);
for (auto& vals : sequences)
DLIB_CASSERT(vals.size() > 4);
DLIB_CASSERT(smoothness >= 0);
// define the objective function we optimize to find the best filter
auto obj = [&](double measurement_noise, double typical_acceleration, double max_measurement_deviation)
{
running_stats<double> rs;
for (auto& vals : sequences)
{
momentum_filter filt(measurement_noise, typical_acceleration, max_measurement_deviation);
double prev_filt = 0;
for (size_t i = 0; i < vals.size(); ++i)
{
// we care about smoothness and fitting the data.
if (i > 0)
{
// the filter should fit the data
rs.add(std::abs(vals[i]-filt.get_predicted_next_state()));
}
double next_filt = filt(vals[i]);
if (i > 0)
{
// the filter should also output a smooth trajectory
rs.add(smoothness*std::abs(next_filt-prev_filt));
}
prev_filt = next_filt;
}
}
return rs.mean();
};
running_stats<double> avgdiff;
for (auto& vals : sequences)
{
for (size_t i = 1; i < vals.size(); ++i)
avgdiff.add(vals[i]-vals[i-1]);
}
const double scale = avgdiff.stddev();
function_evaluation opt = find_min_global(obj, {scale*0.01, scale*0.0001, 0.00001}, {scale*10, scale*10, 10}, max_function_calls(400));
momentum_filter filt(opt.x(0), opt.x(1), opt.x(2));
return filt;
}
// ----------------------------------------------------------------------------------------
momentum_filter find_optimal_momentum_filter (
const std::vector<double>& sequence,
const double smoothness
)
{
return find_optimal_momentum_filter({1,sequence}, smoothness);
}
// ----------------------------------------------------------------------------------------
rect_filter find_optimal_rect_filter (
const std::vector<rectangle>& rects,
const double smoothness
)
{
DLIB_CASSERT(rects.size() > 4);
DLIB_CASSERT(smoothness >= 0);
std::vector<std::vector<double>> vals(4);
for (auto& r : rects)
{
vals[0].push_back(r.left());
vals[1].push_back(r.top());
vals[2].push_back(r.right());
vals[3].push_back(r.bottom());
}
return rect_filter(find_optimal_momentum_filter(vals, smoothness));
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_KALMAN_FiLTER_CPp_
......@@ -5,6 +5,7 @@
#include "kalman_filter_abstract.h"
#include "../matrix.h"
#include "../geometry.h"
namespace dlib
{
......@@ -161,6 +162,207 @@ namespace dlib
};
// ----------------------------------------------------------------------------------------
class momentum_filter
{
public:
momentum_filter(
double meas_noise,
double acc,
double max_meas_dev
) : measurement_noise(meas_noise),
typical_acceleration(acc),
max_measurement_deviation(max_meas_dev)
{
kal.set_observation_model({1, 0});
kal.set_transition_model( {1, 1,
0, 1});
kal.set_process_noise({0, 0,
0, typical_acceleration*typical_acceleration});
kal.set_measurement_noise({measurement_noise*measurement_noise});
}
momentum_filter() = default;
double get_measurement_noise (
) const { return measurement_noise; }
double get_typical_acceleration (
) const { return typical_acceleration; }
double get_max_measurement_deviation (
) const { return max_measurement_deviation; }
void reset()
{
*this = momentum_filter(measurement_noise, typical_acceleration, max_measurement_deviation);
}
double get_predicted_next_state(
) const
{
return kal.get_predicted_next_state()(0);
}
double operator()(
const double val
)
{
auto x = kal.get_predicted_next_state();
const auto max_deviation = max_measurement_deviation*measurement_noise;
// Check if val has suddenly jumped in value by a whole lot. This could happen if
// the velocity term experiences a much larger than normal acceleration, e.g.
// because the underlying object is doing a maneuver. If this happens then we
// clamp the state so that the predicted next value is no more than
// max_deviation away from val at all times.
if (x(0) > val + max_deviation)
{
x(0) = val + max_deviation;
kal.set_state(x);
}
else if (x(0) < val - max_deviation)
{
x(0) = val - max_deviation;
kal.set_state(x);
}
kal.update({val});
return kal.get_current_state()(0);
}
friend std::ostream& operator << (std::ostream& out, const momentum_filter& item)
{
out << "measurement_noise: " << item.measurement_noise << "\n";
out << "typical_acceleration: " << item.typical_acceleration << "\n";
out << "max_measurement_deviation: " << item.max_measurement_deviation;
return out;
}
friend void serialize(const momentum_filter& item, std::ostream& out)
{
int version = 15;
serialize(version, out);
serialize(item.measurement_noise, out);
serialize(item.typical_acceleration, out);
serialize(item.max_measurement_deviation, out);
serialize(item.kal, out);
}
friend void deserialize(momentum_filter& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 15)
throw serialization_error("Unexpected version found while deserializing momentum_filter.");
deserialize(item.measurement_noise, in);
deserialize(item.typical_acceleration, in);
deserialize(item.max_measurement_deviation, in);
deserialize(item.kal, in);
}
private:
double measurement_noise = 2;
double typical_acceleration = 0.1;
double max_measurement_deviation = 3; // nominally number of standard deviations
kalman_filter<2,1> kal;
};
// ----------------------------------------------------------------------------------------
momentum_filter find_optimal_momentum_filter (
const std::vector<std::vector<double>>& sequences,
const double smoothness = 1
);
// ----------------------------------------------------------------------------------------
momentum_filter find_optimal_momentum_filter (
const std::vector<double>& sequence,
const double smoothness = 1
);
// ----------------------------------------------------------------------------------------
class rect_filter
{
public:
rect_filter() = default;
rect_filter(
const momentum_filter& filt
) :
left(filt),
top(filt),
right(filt),
bottom(filt)
{
}
drectangle operator()(const drectangle& r)
{
return drectangle(left(r.left()),
top(r.top()),
right(r.right()),
bottom(r.bottom()));
}
drectangle operator()(const rectangle& r)
{
return drectangle(left(r.left()),
top(r.top()),
right(r.right()),
bottom(r.bottom()));
}
const momentum_filter& get_left () const { return left; }
momentum_filter& get_left () { return left; }
const momentum_filter& get_top () const { return top; }
momentum_filter& get_top () { return top; }
const momentum_filter& get_right () const { return right; }
momentum_filter& get_right () { return right; }
const momentum_filter& get_bottom () const { return bottom; }
momentum_filter& get_bottom () { return bottom; }
friend void serialize(const rect_filter& item, std::ostream& out)
{
int version = 123;
serialize(version, out);
serialize(item.left, out);
serialize(item.top, out);
serialize(item.right, out);
serialize(item.bottom, out);
}
friend void deserialize(rect_filter& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 123)
throw dlib::serialization_error("Unknown version number found while deserializing rect_filter object.");
deserialize(item.left, in);
deserialize(item.top, in);
deserialize(item.right, in);
deserialize(item.bottom, in);
}
private:
momentum_filter left, top, right, bottom;
};
// ----------------------------------------------------------------------------------------
rect_filter find_optimal_rect_filter (
const std::vector<rectangle>& rects,
const double smoothness = 1
);
// ----------------------------------------------------------------------------------------
}
......
......@@ -211,6 +211,120 @@ namespace dlib
provides deserialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class momentum_filter
{
/*!
WHAT THIS OBJECT REPRESENTS
!*/
public:
momentum_filter(
double meas_noise,
double acc,
double max_meas_dev
);
momentum_filter() = default;
double get_measurement_noise (
) const;
double get_typical_acceleration (
) const;
double get_max_measurement_deviation (
) const;
void reset(
);
double get_predicted_next_state(
) const;
double operator()(
const double val
);
};
std::ostream& operator << (std::ostream& out, const momentum_filter& item);
void serialize(const momentum_filter& item, std::ostream& out);
void deserialize(momentum_filter& item, std::istream& in);
// ----------------------------------------------------------------------------------------
momentum_filter find_optimal_momentum_filter (
const std::vector<std::vector<double>>& sequences,
const double smoothness = 1
);
/*!
requires
- sequences.size() != 0
- for all valid i: sequences[i].size() > 4
- smoothness >= 0
!*/
// ----------------------------------------------------------------------------------------
momentum_filter find_optimal_momentum_filter (
const std::vector<double>& sequence,
const double smoothness = 1
);
/*!
requires
- sequence.size() > 4
- smoothness >= 0
ensures
- performs: find_optimal_momentum_filter({1,sequence}, smoothness);
!*/
// ----------------------------------------------------------------------------------------
class rect_filter
{
public:
rect_filter() = default;
rect_filter(
const momentum_filter& filt
);
drectangle operator()(
const drectangle& r
);
drectangle operator()(
const rectangle& r
);
const momentum_filter& get_left() const;
momentum_filter& get_left();
const momentum_filter& get_top() const;
momentum_filter& get_top();
const momentum_filter& get_right() const;
momentum_filter& get_right();
const momentum_filter& get_bottom() const;
momentum_filter& get_bottom();
};
void serialize(const rect_filter& item, std::ostream& out);
void deserialize(rect_filter& item, std::istream& in);
// ----------------------------------------------------------------------------------------
rect_filter find_optimal_rect_filter (
const std::vector<rectangle>& rects,
const double smoothness = 1
);
/*!
requires
- rects.size() > 4
- smoothness >= 0
!*/
// ----------------------------------------------------------------------------------------
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment