Commit 71d4306e authored by Davis King's avatar Davis King

Added a bias term to the assignment_function's model so the user doesn't need

to remember, or even understand, that they should add it themselves.  However,
this change breaks backwards compatibility with the previous serialization
format for assignment_function objects.
parent 55631944
...@@ -33,13 +33,16 @@ namespace dlib ...@@ -33,13 +33,16 @@ namespace dlib
{ {
weights.set_size(fe.num_features()); weights.set_size(fe.num_features());
weights = 0; weights = 0;
bias = 0;
force_assignment = false; force_assignment = false;
} }
explicit assignment_function( explicit assignment_function(
const matrix<double,0,1>& weights_ const matrix<double,0,1>& weights_,
double bias_
) : ) :
weights(weights_), weights(weights_),
bias(bias_),
force_assignment(false) force_assignment(false)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -55,10 +58,12 @@ namespace dlib ...@@ -55,10 +58,12 @@ namespace dlib
assignment_function( assignment_function(
const matrix<double,0,1>& weights_, const matrix<double,0,1>& weights_,
double bias_,
const feature_extractor& fe_ const feature_extractor& fe_
) : ) :
fe(fe_), fe(fe_),
weights(weights_), weights(weights_),
bias(bias_),
force_assignment(false) force_assignment(false)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -73,11 +78,13 @@ namespace dlib ...@@ -73,11 +78,13 @@ namespace dlib
assignment_function( assignment_function(
const matrix<double,0,1>& weights_, const matrix<double,0,1>& weights_,
double bias_,
const feature_extractor& fe_, const feature_extractor& fe_,
bool force_assignment_ bool force_assignment_
) : ) :
fe(fe_), fe(fe_),
weights(weights_), weights(weights_),
bias(bias_),
force_assignment(force_assignment_) force_assignment(force_assignment_)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -96,6 +103,9 @@ namespace dlib ...@@ -96,6 +103,9 @@ namespace dlib
const matrix<double,0,1>& get_weights ( const matrix<double,0,1>& get_weights (
) const { return weights; } ) const { return weights; }
double get_bias (
) const { return bias; }
bool forces_assignment ( bool forces_assignment (
) const { return force_assignment; } ) const { return force_assignment; }
...@@ -130,7 +140,7 @@ namespace dlib ...@@ -130,7 +140,7 @@ namespace dlib
if (r < (long)lhs.size() && c < (long)rhs.size()) if (r < (long)lhs.size() && c < (long)rhs.size())
{ {
fe.get_features(lhs[r], rhs[c], feats); fe.get_features(lhs[r], rhs[c], feats);
cost(r,c) = dot(weights, feats); cost(r,c) = dot(weights, feats) + bias;
} }
else else
{ {
...@@ -188,6 +198,7 @@ namespace dlib ...@@ -188,6 +198,7 @@ namespace dlib
feature_extractor fe; feature_extractor fe;
matrix<double,0,1> weights; matrix<double,0,1> weights;
double bias;
bool force_assignment; bool force_assignment;
}; };
...@@ -201,8 +212,11 @@ namespace dlib ...@@ -201,8 +212,11 @@ namespace dlib
std::ostream& out std::ostream& out
) )
{ {
int version = 2;
serialize(version, out);
serialize(item.get_feature_extractor(), out); serialize(item.get_feature_extractor(), out);
serialize(item.get_weights(), out); serialize(item.get_weights(), out);
serialize(item.get_bias(), out);
serialize(item.forces_assignment(), out); serialize(item.forces_assignment(), out);
} }
...@@ -218,13 +232,19 @@ namespace dlib ...@@ -218,13 +232,19 @@ namespace dlib
{ {
feature_extractor fe; feature_extractor fe;
matrix<double,0,1> weights; matrix<double,0,1> weights;
double bias;
bool force_assignment; bool force_assignment;
int version = 0;
deserialize(version, in);
if (version != 2)
throw serialization_error("Unexpected version found while deserializing dlib::assignment_function.");
deserialize(fe, in); deserialize(fe, in);
deserialize(weights, in); deserialize(weights, in);
deserialize(bias, in);
deserialize(force_assignment, in); deserialize(force_assignment, in);
item = assignment_function<feature_extractor>(weights, fe, force_assignment); item = assignment_function<feature_extractor>(weights, bias, fe, force_assignment);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -29,9 +29,9 @@ namespace dlib ...@@ -29,9 +29,9 @@ namespace dlib
case it is excluded from the sum. case it is excluded from the sum.
Finally, match_score() is defined as: Finally, match_score() is defined as:
match_score(l,r) == dot(w, PSI(l,r)) match_score(l,r) == dot(w, PSI(l,r)) + bias
where l is an element of LHS, r is an element of RHS, and where l is an element of LHS, r is an element of RHS, w is a parameter
w is a parameter vector. vector and bias is a scalar valued parameter.
Therefore, a feature extractor defines how the PSI() feature vector Therefore, a feature extractor defines how the PSI() feature vector
is calculated. In particular, PSI() is defined by the get_features() is calculated. In particular, PSI() is defined by the get_features()
...@@ -140,9 +140,10 @@ namespace dlib ...@@ -140,9 +140,10 @@ namespace dlib
case it is excluded from the sum. case it is excluded from the sum.
Finally, this object supports match_score() functions of the form: Finally, this object supports match_score() functions of the form:
match_score(l,r) == dot(w, PSI(l,r)) match_score(l,r) == dot(w, PSI(l,r)) + bias
where l is an element of LHS, r is an element of RHS, w is a parameter where l is an element of LHS, r is an element of RHS, w is a parameter
vector, and PSI() is defined by the feature_extractor template argument. vector, bias is a scalar valued parameter, and PSI() is defined by the
feature_extractor template argument.
THREAD SAFETY THREAD SAFETY
It is always safe to use distinct instances of this object in different It is always safe to use distinct instances of this object in different
...@@ -170,11 +171,13 @@ namespace dlib ...@@ -170,11 +171,13 @@ namespace dlib
(i.e. it will have its default value) (i.e. it will have its default value)
- #get_weights().size() == #get_feature_extractor().num_features() - #get_weights().size() == #get_feature_extractor().num_features()
- #get_weights() == 0 - #get_weights() == 0
- #get_bias() == 0
- #forces_assignment() == false - #forces_assignment() == false
!*/ !*/
explicit assignment_function( explicit assignment_function(
const matrix<double,0,1>& weights const matrix<double,0,1>& weights,
double bias
); );
/*! /*!
requires requires
...@@ -183,11 +186,13 @@ namespace dlib ...@@ -183,11 +186,13 @@ namespace dlib
- #get_feature_extractor() == feature_extractor() - #get_feature_extractor() == feature_extractor()
(i.e. it will have its default value) (i.e. it will have its default value)
- #get_weights() == weights - #get_weights() == weights
- #get_bias() == bias
- #forces_assignment() == false - #forces_assignment() == false
!*/ !*/
assignment_function( assignment_function(
const matrix<double,0,1>& weights, const matrix<double,0,1>& weights,
double bias,
const feature_extractor& fe const feature_extractor& fe
); );
/*! /*!
...@@ -196,11 +201,13 @@ namespace dlib ...@@ -196,11 +201,13 @@ namespace dlib
ensures ensures
- #get_feature_extractor() == fe - #get_feature_extractor() == fe
- #get_weights() == weights - #get_weights() == weights
- #get_bias() == bias
- #forces_assignment() == false - #forces_assignment() == false
!*/ !*/
assignment_function( assignment_function(
const matrix<double,0,1>& weights, const matrix<double,0,1>& weights,
double bias,
const feature_extractor& fe, const feature_extractor& fe,
bool force_assignment bool force_assignment
); );
...@@ -210,6 +217,7 @@ namespace dlib ...@@ -210,6 +217,7 @@ namespace dlib
ensures ensures
- #get_feature_extractor() == fe - #get_feature_extractor() == fe
- #get_weights() == weights - #get_weights() == weights
- #get_bias() == bias
- #forces_assignment() == force_assignment - #forces_assignment() == force_assignment
!*/ !*/
...@@ -228,6 +236,13 @@ namespace dlib ...@@ -228,6 +236,13 @@ namespace dlib
The length of the vector is get_feature_extractor().num_features(). The length of the vector is get_feature_extractor().num_features().
!*/ !*/
double get_bias (
) const;
/*!
ensures
- returns the bias parameter associated with this assignment function.
!*/
bool forces_assignment ( bool forces_assignment (
) const; ) const;
/*! /*!
......
...@@ -184,9 +184,13 @@ namespace dlib ...@@ -184,9 +184,13 @@ namespace dlib
matrix<double,0,1> weights; matrix<double,0,1> weights;
solver(prob, weights, num_nonnegative_weights(fe)); // Take the min here because we want to prevent the user from accidentally
// forcing the bias term to be non-negative.
const unsigned long num_nonneg = std::min(fe.num_features(),num_nonnegative_weights(fe));
solver(prob, weights, num_nonneg);
return assignment_function<feature_extractor>(weights,fe,force_assignment); const double bias = weights(weights.size()-1);
return assignment_function<feature_extractor>(colm(weights,0,weights.size()-1), bias,fe,force_assignment);
} }
......
...@@ -14,16 +14,41 @@ ...@@ -14,16 +14,41 @@
namespace dlib namespace dlib
{ {
template <long n, typename T>
struct column_matrix_static_resize
{
typedef T type;
};
template <long n, typename T, long NR, long NC, typename MM, typename L>
struct column_matrix_static_resize<n, matrix<T,NR,NC,MM,L> >
{
typedef matrix<T,NR+n,NC,MM,L> type;
};
template <long n, typename T, long NC, typename MM, typename L>
struct column_matrix_static_resize<n, matrix<T,0,NC,MM,L> >
{
typedef matrix<T,0,NC,MM,L> type;
};
template <typename T>
struct add_one_to_static_feat_size
{
typedef typename column_matrix_static_resize<1,typename T::feature_vector_type>::type type;
};
// ----------------------------------------------------------------------------------------
template < template <
typename feature_extractor typename feature_extractor
> >
class structural_svm_assignment_problem : noncopyable, class structural_svm_assignment_problem : noncopyable,
public structural_svm_problem_threaded<matrix<double,0,1>, typename feature_extractor::feature_vector_type > public structural_svm_problem_threaded<matrix<double,0,1>, typename add_one_to_static_feat_size<feature_extractor>::type >
{ {
public: public:
typedef matrix<double,0,1> matrix_type; typedef matrix<double,0,1> matrix_type;
typedef typename feature_extractor::feature_vector_type feature_vector_type; typedef typename add_one_to_static_feat_size<feature_extractor>::type feature_vector_type;
typedef typename feature_extractor::lhs_element lhs_element; typedef typename feature_extractor::lhs_element lhs_element;
typedef typename feature_extractor::rhs_element rhs_element; typedef typename feature_extractor::rhs_element rhs_element;
...@@ -77,7 +102,7 @@ namespace dlib ...@@ -77,7 +102,7 @@ namespace dlib
virtual long get_num_dimensions ( virtual long get_num_dimensions (
) const ) const
{ {
return fe.num_features(); return fe.num_features()+1; // +1 for the bias term
} }
virtual long get_num_samples ( virtual long get_num_samples (
...@@ -94,14 +119,15 @@ namespace dlib ...@@ -94,14 +119,15 @@ namespace dlib
) const ) const
{ {
typename feature_extractor::feature_vector_type feats; typename feature_extractor::feature_vector_type feats;
psi.set_size(fe.num_features()); psi.set_size(get_num_dimensions());
psi = 0; psi = 0;
for (unsigned long i = 0; i < sample.first.size(); ++i) for (unsigned long i = 0; i < sample.first.size(); ++i)
{ {
if (label[i] != -1) if (label[i] != -1)
{ {
fe.get_features(sample.first[i], sample.second[label[i]], feats); fe.get_features(sample.first[i], sample.second[label[i]], feats);
psi += feats; set_rowm(psi,range(0,feats.size()-1)) += feats;
psi(get_num_dimensions()-1) += 1;
} }
} }
} }
...@@ -123,15 +149,18 @@ namespace dlib ...@@ -123,15 +149,18 @@ namespace dlib
) const ) const
{ {
psi.clear(); psi.clear();
typename feature_extractor::feature_vector_type feats; feature_vector_type feats;
int num_assignments = 0;
for (unsigned long i = 0; i < sample.first.size(); ++i) for (unsigned long i = 0; i < sample.first.size(); ++i)
{ {
if (label[i] != -1) if (label[i] != -1)
{ {
fe.get_features(sample.first[i], sample.second[label[i]], feats); fe.get_features(sample.first[i], sample.second[label[i]], feats);
append_to_sparse_vect(psi, feats); append_to_sparse_vect(psi, feats);
++num_assignments;
} }
} }
psi.push_back(std::make_pair(get_num_dimensions()-1,num_assignments));
} }
virtual void get_truth_joint_feature_vector ( virtual void get_truth_joint_feature_vector (
...@@ -176,7 +205,8 @@ namespace dlib ...@@ -176,7 +205,8 @@ namespace dlib
if (c < (long)samples[idx].second.size()) if (c < (long)samples[idx].second.size())
{ {
fe.get_features(samples[idx].first[r], samples[idx].second[c], feats); fe.get_features(samples[idx].first[r], samples[idx].second[c], feats);
cost(r,c) = dot(current_solution, feats); const double bias = current_solution(current_solution.size()-1);
cost(r,c) = dot(colm(current_solution,0,current_solution.size()-1), feats) + bias;
// add in the loss since this corresponds to an incorrect prediction. // add in the loss since this corresponds to an incorrect prediction.
if (c != labels[idx][r]) if (c != labels[idx][r])
......
...@@ -27,9 +27,9 @@ namespace dlib ...@@ -27,9 +27,9 @@ namespace dlib
the example_feature_extractor defined in dlib/svm/assignment_function_abstract.h. the example_feature_extractor defined in dlib/svm/assignment_function_abstract.h.
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This object is a tool for learning the weight vector needed to use This object is a tool for learning the parameters needed to use an
an assignment_function object. It learns the parameter vector by assignment_function object. It learns the parameters by formulating the
formulating the problem as a structural SVM problem. problem as a structural SVM problem.
!*/ !*/
public: public:
...@@ -56,8 +56,8 @@ namespace dlib ...@@ -56,8 +56,8 @@ namespace dlib
- This object attempts to learn a mapping from the given samples to the - This object attempts to learn a mapping from the given samples to the
given labels. In particular, it attempts to learn to predict labels[i] given labels. In particular, it attempts to learn to predict labels[i]
based on samples[i]. Or in other words, this object can be used to learn based on samples[i]. Or in other words, this object can be used to learn
a parameter vector, w, such that an assignment_function declared as: a parameter vector and bias, w and b, such that an assignment_function declared as:
assignment_function<feature_extractor> assigner(w,fe,force_assignment) assignment_function<feature_extractor> assigner(w,b,fe,force_assignment)
results in an assigner object which attempts to compute the following mapping: results in an assigner object which attempts to compute the following mapping:
labels[i] == labeler(samples[i]) labels[i] == labeler(samples[i])
- This object will use num_threads threads during the optimization - This object will use num_threads threads during the optimization
......
...@@ -29,14 +29,14 @@ namespace ...@@ -29,14 +29,14 @@ namespace
struct feature_extractor_dense struct feature_extractor_dense
{ {
typedef matrix<double,4,1> feature_vector_type; typedef matrix<double,3,1> feature_vector_type;
typedef ::lhs_element lhs_element; typedef ::lhs_element lhs_element;
typedef ::rhs_element rhs_element; typedef ::rhs_element rhs_element;
unsigned long num_features() const unsigned long num_features() const
{ {
return 4; return 3;
} }
void get_features ( void get_features (
...@@ -45,7 +45,7 @@ namespace ...@@ -45,7 +45,7 @@ namespace
feature_vector_type& feats feature_vector_type& feats
) const ) const
{ {
feats = join_cols(squared(left - right), ones_matrix<double>(1,1)); feats = squared(left - right);
} }
}; };
...@@ -64,7 +64,7 @@ namespace ...@@ -64,7 +64,7 @@ namespace
unsigned long num_features() const unsigned long num_features() const
{ {
return 4; return 3;
} }
void get_features ( void get_features (
...@@ -77,7 +77,6 @@ namespace ...@@ -77,7 +77,6 @@ namespace
feats.push_back(make_pair(0,squared(left-right)(0))); feats.push_back(make_pair(0,squared(left-right)(0)));
feats.push_back(make_pair(1,squared(left-right)(1))); feats.push_back(make_pair(1,squared(left-right)(1)));
feats.push_back(make_pair(2,squared(left-right)(2))); feats.push_back(make_pair(2,squared(left-right)(2)));
feats.push_back(make_pair(3,1.0));
} }
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment