Commit 2bf69577 authored by Davis King's avatar Davis King

Added initial version of the learning-to-track interface to the association learning

tools.  So this adds the track_association_function and structural_track_association_trainer
objects and also test_track_association_function() and cross_validate_track_association_trainer()
routines.
parent fb6eca2b
...@@ -45,6 +45,7 @@ ...@@ -45,6 +45,7 @@
#include "svm/structural_svm_problem.h" #include "svm/structural_svm_problem.h"
#include "svm/sequence_labeler.h" #include "svm/sequence_labeler.h"
#include "svm/assignment_function.h" #include "svm/assignment_function.h"
#include "svm/track_association_function.h"
#include "svm/active_learning.h" #include "svm/active_learning.h"
#include "svm/svr_linear_trainer.h" #include "svm/svr_linear_trainer.h"
#include "svm/sequence_segmenter.h" #include "svm/sequence_segmenter.h"
......
// Copyright (C) 2014 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_H__
#define DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_H__
#include "cross_validate_track_association_trainer_abstract.h"
#include "structural_track_association_trainer.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace impl
{
template <
typename track_association_function,
typename detection_type,
typename detection_id_type
>
void test_track_association_function (
const track_association_function& assoc,
const std::vector<std::vector<std::pair<detection_type,detection_id_type> > >& samples,
unsigned long& total_dets,
unsigned long& correctly_associated_dets
)
{
const typename track_association_function::association_function_type& f = assoc.get_assignment_function();
typedef typename detection_type::track_type track_type;
using namespace impl;
std::vector<track_type> tracks;
std::map<detection_id_type,unsigned long> track_idx; // tracks[track_idx[id]] == track with ID id.
for (unsigned long j = 0; j < samples.size(); ++j)
{
total_dets += samples[j].size();
std::vector<long> assignments = f(get_unlabeled_dets(samples[j]), tracks);
std::vector<bool> updated_track(tracks.size(), false);
// now update all the tracks with the detections that associated to them.
const std::vector<std::pair<detection_type,detection_id_type> >& dets = samples[j];
for (unsigned long k = 0; k < assignments.size(); ++k)
{
if (assignments[k] != -1)
{
tracks[assignments[k]].update_track(dets[k].first);
updated_track[assignments[k]] = true;
// if this detection was supposed to go to this track
if (track_idx.count(dets[k].second) && track_idx[dets[k].second]==assignments[k])
++correctly_associated_dets;
track_idx[dets[k].second] = assignments[k];
}
else
{
track_type new_track;
new_track.update_track(dets[k].first);
tracks.push_back(new_track);
// if this detection was supposed to go to a new track
if (track_idx.count(dets[k].second) == 0)
++correctly_associated_dets;
track_idx[dets[k].second] = tracks.size()-1;
}
}
// Now propagate all the tracks that didn't get any detections.
for (unsigned long k = 0; k < updated_track.size(); ++k)
{
if (!updated_track[k])
tracks[k].propagate_track();
}
}
}
}
// ----------------------------------------------------------------------------------------
template <
typename track_association_function,
typename detection_type,
typename detection_id_type
>
double test_track_association_function (
const track_association_function& assoc,
const std::vector<std::vector<std::vector<std::pair<detection_type,detection_id_type> > > >& samples
)
{
unsigned long total_dets = 0;
unsigned long correctly_associated_dets = 0;
for (unsigned long i = 0; i < samples.size(); ++i)
{
impl::test_track_association_function(assoc, samples[i], total_dets, correctly_associated_dets);
}
return (double)correctly_associated_dets/(double)total_dets;
}
// ----------------------------------------------------------------------------------------
template <
typename trainer_type,
typename detection_type,
typename detection_id_type
>
double cross_validate_track_association_trainer (
const trainer_type& trainer,
const std::vector<std::vector<std::vector<std::pair<detection_type,detection_id_type> > > >& samples,
const long folds
)
{
const long num_in_test = samples.size()/folds;
const long num_in_train = samples.size() - num_in_test;
std::vector<std::vector<std::vector<std::pair<detection_type,detection_id_type> > > > samples_train;
long next_test_idx = 0;
unsigned long total_dets = 0;
unsigned long correctly_associated_dets = 0;
for (long i = 0; i < folds; ++i)
{
samples_train.clear();
// load up the training samples
long next = (next_test_idx + num_in_test)%samples.size();
for (long cnt = 0; cnt < num_in_train; ++cnt)
{
samples_train.push_back(samples[next]);
next = (next + 1)%samples.size();
}
const typename trainer_type::trained_function_type& df = trainer.train(samples_train);
for (long cnt = 0; cnt < num_in_test; ++cnt)
{
impl::test_track_association_function(df, samples[next_test_idx], total_dets, correctly_associated_dets);
next_test_idx = (next_test_idx + 1)%samples.size();
}
}
return (double)correctly_associated_dets/(double)total_dets;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_H__
// Copyright (C) 2014 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_ABSTRACT_H__
#ifdef DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_ABSTRACT_H__
#include "structural_track_association_trainer_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename track_association_function,
typename detection_type,
typename detection_id_type
>
double test_track_association_function (
const track_association_function& assoc,
const std::vector<std::vector<std::vector<std::pair<detection_type,detection_id_type> > > >& samples
);
/*!
!*/
// ----------------------------------------------------------------------------------------
template <
typename trainer_type,
typename detection_type,
typename detection_id_type
>
double cross_validate_track_association_trainer (
const trainer_type& trainer,
const std::vector<std::vector<std::vector<std::pair<detection_type,detection_id_type> > > >& samples,
const long folds
);
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_ABSTRACT_H__
This diff is collapsed.
// Copyright (C) 2014 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_ABSTRACT_H__
#ifdef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_ABSTRACT_H__
#include "track_association_function_abstract.h"
#include "structural_assignment_trainer_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename detection_type_,
typename detection_id_type_ = unsigned long
>
class structural_track_association_trainer
{
public:
typedef detection_type_ detection_type;
typedef typename detection_type::track_type track_type;
typedef detection_id_type_ detection_id_type;
typedef std::pair<detection_type, detection_id_type> labeled_detection;
typedef std::vector<labeled_detection> detections_at_single_time_step;
// This type logically represents an entire track history
typedef std::vector<detections_at_single_time_step> sample_type;
typedef track_association_function<detection_type> trained_function_type;
structural_track_association_trainer (
);
/*!
C = 100;
verbose = false;
eps = 0.1;
num_threads = 2;
max_cache_size = 5;
learn_nonnegative_weights = false;
!*/
void set_num_threads (
unsigned long num
);
unsigned long get_num_threads (
) const;
void set_epsilon (
double eps
);
double get_epsilon (
) const;
void set_max_cache_size (
unsigned long max_size
);
unsigned long get_max_cache_size (
) const;
void be_verbose (
);
void be_quiet (
);
void set_oca (
const oca& item
);
const oca get_oca (
) const;
void set_c (
double C
);
double get_c (
) const;
bool learns_nonnegative_weights (
) const;
void set_learns_nonnegative_weights (
bool value
);
const track_association_function<detection_type> train (
const std::vector<sample_type>& samples
) const;
/*!
requires
- is_track_association_problem(samples) == true
ensures
-
!*/
const track_association_function<detection_type> train (
const sample_type& sample
) const;
/*!
requires
- is_track_association_problem(samples) == true
ensures
-
!*/
};
}
#endif // DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_ABSTRACT_H__
// Copyright (C) 2014 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_TRACK_ASSOCiATION_FUNCTION_H__
#define DLIB_TRACK_ASSOCiATION_FUNCTION_H__
#include "track_association_function_abstract.h"
#include <vector>
#include <iostream>
#include "../algs.h"
#include "../serialize.h"
#include "assignment_function.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename detection_type
>
class feature_extractor_track_association
{
public:
typedef typename detection_type::track_type track_type;
typedef typename track_type::feature_vector_type feature_vector_type;
typedef detection_type lhs_element;
typedef track_type rhs_element;
feature_extractor_track_association() : num_dims(0), num_nonnegative(0) {}
explicit feature_extractor_track_association (
unsigned long num_dims_,
unsigned long num_nonnegative_
) : num_dims(num_dims_), num_nonnegative(num_nonnegative_) {}
unsigned long num_features(
) const { return num_dims; }
unsigned long num_nonnegative_weights (
) const { return num_nonnegative; }
void get_features (
const detection_type& det,
const track_type& track,
feature_vector_type& feats
) const
{
track.get_similarity_features(det, feats);
}
friend void serialize (const feature_extractor_track_association& item, std::ostream& out)
{
serialize(item.num_dims, out);
serialize(item.num_nonnegative, out);
}
friend void deserialize (feature_extractor_track_association& item, std::istream& in)
{
deserialize(item.num_dims, in);
deserialize(item.num_nonnegative, in);
}
private:
unsigned long num_dims;
unsigned long num_nonnegative;
};
// ----------------------------------------------------------------------------------------
template <
typename detection_type_
>
class track_association_function
{
public:
typedef detection_type_ detection_type;
typedef typename detection_type::track_type track_type;
typedef assignment_function<feature_extractor_track_association<detection_type> > association_function_type;
track_association_function() {}
track_association_function (
const association_function_type& assoc_
) : assoc(assoc_)
{
}
const association_function_type& get_assignment_function (
) const
{
return assoc;
}
void operator() (
std::vector<track_type>& tracks,
const std::vector<detection_type>& dets
) const
{
std::vector<long> assignments = assoc(dets, tracks);
std::vector<bool> updated_track(tracks.size(), false);
// now update all the tracks with the detections that associated to them.
for (unsigned long i = 0; i < assignments.size(); ++i)
{
if (assignments[i] != -1)
{
tracks[assignments[i]].update_track(dets[i]);
updated_track[assignments[i]] = true;
}
else
{
track_type new_track;
new_track.update_track(dets[i]);
tracks.push_back(new_track);
}
}
// Now propagate all the tracks that didn't get any detections.
for (unsigned long i = 0; i < updated_track.size(); ++i)
{
if (!updated_track[i])
tracks[i].propagate_track();
}
}
friend void serialize (const track_association_function& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.assoc, out);
}
friend void deserialize (track_association_function& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::track_association_function.");
deserialize(item.assoc, in);
}
private:
assignment_function<feature_extractor_track_association<detection_type> > assoc;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_TRACK_ASSOCiATION_FUNCTION_H__
// Copyright (C) 2014 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_TRACK_ASSOCiATION_FUNCTION_ABSTRACT_H__
#ifdef DLIB_TRACK_ASSOCiATION_FUNCTION_ABSTRACT_H__
#include <vector>
#include "assignment_function_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class example_detection
{
/*!
WHAT THIS OBJECT REPRESENTS
!*/
public:
// Each detection object should be designed to work with a specific track object.
// This typedef lets you determine which track type is meant for use with this
// detection object.
typedef struct example_track track_type;
};
// ----------------------------------------------------------------------------------------
class example_track
{
/*!
WHAT THIS OBJECT REPRESENTS
!*/
public:
// This type should be a dlib::matrix capable of storing column vectors
// or an unsorted sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
typedef matrix_or_sparse_vector_type feature_vector_type;
example_track(
);
/*!
ensures
- this object is properly initialized
!*/
void get_similarity_features (
const example_detection& det,
feature_vector_type& feats
) const;
/*!
ensures
- #feats == A feature vector that contains information describing how
likely it is that det is a detection from the object corresponding to
this track. That is, the feature vector should contain information that
lets someone decide if det should be associated to this track.
!*/
void update_track (
const example_detection& det
);
/*!
ensures
- Updates this track with the given detection assuming that det is the most
current observation of the object under track.
!*/
void propagate_track (
);
/*!
ensures
- propagates this track forward in time one time step.
!*/
};
// ----------------------------------------------------------------------------------------
template <
typename detection_type
>
class feature_extractor_track_association
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is an adapter that converts from the detection/track style
interface defined above to the feature extraction interface required by the
association rule learning tools in dlib. Specifically, it converts the
detection/track interface into a form usable by the assignment_function and
its trainer structural_assignment_trainer.
!*/
public:
typedef typename detection_type::track_type track_type;
typedef typename track_type::feature_vector_type feature_vector_type;
typedef detection_type lhs_element;
typedef track_type rhs_element;
unsigned long num_features(
) const;
/*!
ensures
- returns the dimensionality of the feature vectors produced by get_features().
!*/
void get_features (
const detection_type& det,
const track_type& track,
feature_vector_type& feats
) const;
/*!
ensures
- performs: track.get_similarity_features(det, feats);
!*/
};
void serialize (const feature_extractor_track_association& item, std::ostream& out);
void deserialize (feature_extractor_track_association& item, std::istream& in);
/*!
Provides serialization and deserialization support.
!*/
// ----------------------------------------------------------------------------------------
template <
typename detection_type_
>
class track_association_function
{
/*!
WHAT THIS OBJECT REPRESENTS
!*/
public:
typedef detection_type_ detection_type;
typedef typename detection_type::track_type track_type;
typedef assignment_function<feature_extractor_track_association<detection_type> > association_function_type;
track_association_function(
);
/*!
ensures
- #get_assignment_function() will be default initialized.
!*/
track_association_function (
const association_function_type& assoc_
);
/*!
ensures
- #get_assignment_function() == assoc
!*/
const association_function_type& get_assignment_function (
) const;
/*!
ensures
- returns the assignment_function used by this object to assign detections
to tracks.
!*/
void operator() (
std::vector<track_type>& tracks,
const std::vector<detection_type>& dets
) const;
/*!
ensures
- This function uses get_assignment_function() to assign all the detections
in dets to their appropriate track in tracks. Then each track which
associates to a detection is updated by calling update_track() with the
associated detection.
- Detections that don't associate with any of the elements of tracks will
spawn new tracks. For each unassociated detection, this is done by
creating a new track_type object, calling update_track() on it with the
new detection, and then adding the new track into tracks.
- Tracks that don't have a detection associate to them are propagated
forward in time by calling propagate_track() on them. That is, we call
propagate_track() only on tracks that do not get associated with a
detection.
!*/
};
void serialize (const track_association_function& item, std::ostream& out);
void deserialize (track_association_function& item, std::istream& in);
/*!
Provides serialization and deserialization support.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_TRACK_ASSOCiATION_FUNCTION_ABSTRACT_H__
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "svm/structural_svm_assignment_problem.h" #include "svm/structural_svm_assignment_problem.h"
#include "svm/structural_assignment_trainer.h" #include "svm/structural_assignment_trainer.h"
#include "svm/cross_validate_track_association_trainer.h"
#include "svm/structural_track_association_trainer.h"
#include "svm/structural_svm_graph_labeling_problem.h" #include "svm/structural_svm_graph_labeling_problem.h"
#include "svm/structural_graph_labeling_trainer.h" #include "svm/structural_graph_labeling_trainer.h"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment