Commit 3fc43218 authored by Davis King

Refactored code a little and added more comments to spec.

parent fcfafe67
...@@ -43,13 +43,13 @@ namespace dlib ...@@ -43,13 +43,13 @@ namespace dlib
double dot( double dot(
const matrix_exp<EXP>& lambda, const matrix_exp<EXP>& lambda,
const feature_extractor& fe, const feature_extractor& fe,
unsigned long position, const std::vector<sample_type>& sequence,
const matrix_exp<EXP2>& label_states, const matrix_exp<EXP2>& candidate_labeling,
const std::vector<sample_type>& x unsigned long position
) )
{ {
dot_functor<EXP> dot(lambda); dot_functor<EXP> dot(lambda);
fe.get_features(dot, position, label_states, x); fe.get_features(dot, sequence, candidate_labeling, position);
return dot.value; return dot.value;
} }
...@@ -79,7 +79,7 @@ namespace dlib ...@@ -79,7 +79,7 @@ namespace dlib
const feature_extractor& fe_, const feature_extractor& fe_,
const matrix<double,0,1>& weights_ const matrix<double,0,1>& weights_
) : ) :
x(x_), sequence(x_),
fe(fe_), fe(fe_),
weights(weights_) weights(weights_)
{ {
...@@ -88,7 +88,7 @@ namespace dlib ...@@ -88,7 +88,7 @@ namespace dlib
unsigned long number_of_nodes( unsigned long number_of_nodes(
) const ) const
{ {
return x.size(); return sequence.size();
} }
template < template <
...@@ -99,13 +99,13 @@ namespace dlib ...@@ -99,13 +99,13 @@ namespace dlib
const matrix_exp<EXP>& node_states const matrix_exp<EXP>& node_states
) const ) const
{ {
if (fe.reject_labeling(node_id, node_states, x)) if (fe.reject_labeling(sequence, node_states, node_id))
return -std::numeric_limits<double>::infinity(); return -std::numeric_limits<double>::infinity();
return fe_helpers::dot(weights, fe, node_id, node_states, x); return fe_helpers::dot(weights, fe, sequence, node_states, node_id);
} }
const sample_sequence_type& x; const sample_sequence_type& sequence;
const feature_extractor& fe; const feature_extractor& fe;
const matrix<double,0,1>& weights; const matrix<double,0,1>& weights;
}; };
......
...@@ -10,6 +10,97 @@ ...@@ -10,6 +10,97 @@
namespace dlib namespace dlib
{ {
// ----------------------------------------------------------------------------------------
class example_feature_extractor
{
/*!
WHAT THIS OBJECT REPRESENTS
    This object is a specification only: it declares the interface that a
    feature extractor must provide in order to be used with the
    sequence_labeler (whose feature_extractor template argument must be
    compatible with this interface).  It shows which member functions are
    needed and what each of them must guarantee.

    NOTE(review): word_type is not declared in this spec; it presumably
    stands in for whatever type one element of the input sequence has --
    confirm.
!*/
public:
typedef word_type sample_type;
example_feature_extractor (
);
/*!
    Default constructor.  No post-conditions are specified here.
!*/
unsigned long num_features (
) const;
/*!
    ensures
        - returns the dimensionality of the feature space.  A
          sequence_labeler constructed with this object requires its
          weight vector to satisfy weights.size() == num_features().
        - NOTE(review): presumably every feature_index given to
          set_feature() in get_features() is < num_features() -- confirm.
!*/
unsigned long order(
) const;
/*!
    ensures
        - returns the maximum number of previous labels that accompany the
          label of the current position in a candidate_labeling.  That is,
          candidate_labeling.size() == min(position, order()) + 1 in
          reject_labeling() and get_features().
!*/
unsigned long num_labels(
) const;
/*!
    ensures
        - returns the number of distinct labels.  All label values given to
          this object are < num_labels().
!*/
template <typename EXP>
bool reject_labeling (
const std::vector<sample_type>& sequence,
const matrix_exp<EXP>& candidate_labeling,
unsigned long position
) const;
/*!
    requires
        - EXP::type == unsigned long
          (i.e. candidate_labeling contains unsigned longs)
        - position < sequence.size()
        - candidate_labeling.size() == min(position, order()) + 1
        - is_vector(candidate_labeling) == true
        - max(candidate_labeling) < num_labels()
    ensures
        - if (the given candidate_labeling for sequence[position] is
          always the wrong labeling) then
            - returns true
              (note that reject_labeling() is just an optional tool to allow
              you to overrule the learning algorithm.  You don't have to use
              it.  So if you prefer you can set reject_labeling() to always
              return false.)
        - else
            - returns false
    notes
        - NOTE(review): the callers visible in this commit fill
          candidate_labeling with the labels of positions position,
          position-1, ..., max(position-order(),0), so element 0 appears to
          be the label proposed for sequence[position] -- confirm.
!*/
template <typename feature_setter, typename EXP>
void get_features (
feature_setter& set_feature,
const std::vector<sample_type>& sequence,
const matrix_exp<EXP>& candidate_labeling,
unsigned long position
) const;
/*!
    requires
        - EXP::type == unsigned long
          (i.e. candidate_labeling contains unsigned longs)
        - position < sequence.size()
        - candidate_labeling.size() == min(position, order()) + 1
        - is_vector(candidate_labeling) == true
        - max(candidate_labeling) < num_labels()
        - set_feature is a function object which allows expressions of the form:
            - set_feature((unsigned long)feature_index, (double)feature_value);
            - set_feature((unsigned long)feature_index);
    ensures
        - reports the features associated with labeling sequence[position]
          according to candidate_labeling.  The features are not returned;
          they are communicated by invoking set_feature.
        - NOTE(review): presumably set_feature is called once per non-zero
          feature, and the single argument form is shorthand for a
          feature_value of 1 -- confirm against the functors that are
          passed in by the labeler and trainer.
!*/
};
// ----------------------------------------------------------------------------------------
void serialize(
const example_feature_extractor& item,
std::ostream& out
);
/*!
    provides serialization support

    item is the feature extractor to be saved and out is the stream that
    receives it.
    NOTE(review): presumably the output can be read back with the matching
    deserialize() declared in this spec -- confirm against a concrete
    implementation.
!*/
void deserialize(
example_feature_extractor& item,
std::istream& in
);
/*!
    provides deserialization support

    item is the feature extractor to be overwritten and in is the stream
    that is read from.
    NOTE(review): presumably in must contain data produced by the matching
    serialize() declared in this spec -- confirm against a concrete
    implementation.
!*/
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -18,6 +109,10 @@ namespace dlib ...@@ -18,6 +109,10 @@ namespace dlib
class sequence_labeler class sequence_labeler
{ {
/*! /*!
REQUIREMENTS ON feature_extractor
It must be an object that implements an interface compatible with
the example_feature_extractor discussed above.
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
!*/ !*/
...@@ -30,18 +125,29 @@ namespace dlib ...@@ -30,18 +125,29 @@ namespace dlib
sequence_labeler() {} sequence_labeler() {}
sequence_labeler( sequence_labeler(
const feature_extractor& fe_, const feature_extractor& fe,
const matrix<double,0,1>& weights_ const matrix<double,0,1>& weights
); );
/*!
requires
- fe.num_features() == weights.size()
ensures
- #get_feature_extractor() == fe
- #get_weights() == weights
!*/
const feature_extractor& get_feature_extractor ( const feature_extractor& get_feature_extractor (
) const; ) const;
const matrix<double,0,1>& get_weights ( const matrix<double,0,1>& get_weights (
) const; ) const;
/*!
ensures
- returns a vector of length get_feature_extractor().num_features()
!*/
unsigned long num_labels ( unsigned long num_labels (
) const { return fe.num_labels(); } ) const { return get_feature_extractor().num_labels(); }
labeled_sequence_type operator() ( labeled_sequence_type operator() (
const sample_sequence_type& x const sample_sequence_type& x
......
...@@ -26,7 +26,7 @@ namespace dlib ...@@ -26,7 +26,7 @@ namespace dlib
typedef sequence_labeler<feature_extractor> trained_function_type; typedef sequence_labeler<feature_extractor> trained_function_type;
structural_sequence_labeling_trainer ( explicit structural_sequence_labeling_trainer (
const feature_extractor& fe_ const feature_extractor& fe_
) : fe(fe_) ) : fe(fe_)
{} {}
......
...@@ -29,7 +29,7 @@ namespace dlib ...@@ -29,7 +29,7 @@ namespace dlib
typedef sequence_labeler<feature_extractor> trained_function_type; typedef sequence_labeler<feature_extractor> trained_function_type;
structural_sequence_labeling_trainer ( explicit structural_sequence_labeling_trainer (
const feature_extractor& fe_ const feature_extractor& fe_
) : fe(fe_) ) : fe(fe_)
{} {}
......
...@@ -48,13 +48,13 @@ namespace dlib ...@@ -48,13 +48,13 @@ namespace dlib
void get_feature_vector( void get_feature_vector(
std::vector<std::pair<unsigned long, double> >& feats, std::vector<std::pair<unsigned long, double> >& feats,
const feature_extractor& fe, const feature_extractor& fe,
unsigned long position, const std::vector<sample_type>& sequence,
const matrix_exp<EXP2>& label_states, const matrix_exp<EXP2>& candidate_labeling,
const std::vector<sample_type>& x unsigned long position
) )
{ {
get_feats_functor funct(feats); get_feats_functor funct(feats);
fe.get_features(funct, position, label_states, x); fe.get_features(funct, sequence,candidate_labeling, position);
} }
} }
...@@ -108,12 +108,12 @@ namespace dlib ...@@ -108,12 +108,12 @@ namespace dlib
const int order = fe.order(); const int order = fe.order();
matrix<unsigned long,0,1> label_states; matrix<unsigned long,0,1> candidate_labeling;
for (unsigned long i = 0; i < sample.size(); ++i) for (unsigned long i = 0; i < sample.size(); ++i)
{ {
label_states = rowm(vector_to_matrix(label), range(i, std::max((int)i-order,0))); candidate_labeling = rowm(vector_to_matrix(label), range(i, std::max((int)i-order,0)));
fe_helpers::get_feature_vector(psi,fe,i,label_states, sample); fe_helpers::get_feature_vector(psi,fe,sample,candidate_labeling, i);
} }
} }
...@@ -132,12 +132,12 @@ namespace dlib ...@@ -132,12 +132,12 @@ namespace dlib
unsigned long num_states() const { return fe.num_labels(); } unsigned long num_states() const { return fe.num_labels(); }
map_prob( map_prob(
const std::vector<sample_type>& sample_, const std::vector<sample_type>& sequence_,
const std::vector<unsigned long>& label_, const std::vector<unsigned long>& label_,
const feature_extractor& fe_, const feature_extractor& fe_,
const matrix<double,0,1>& weights_ const matrix<double,0,1>& weights_
) : ) :
sample(sample_), sequence(sequence_),
label(label_), label(label_),
fe(fe_), fe(fe_),
weights(weights_) weights(weights_)
...@@ -147,7 +147,7 @@ namespace dlib ...@@ -147,7 +147,7 @@ namespace dlib
unsigned long number_of_nodes( unsigned long number_of_nodes(
) const ) const
{ {
return sample.size(); return sequence.size();
} }
template < template <
...@@ -165,10 +165,10 @@ namespace dlib ...@@ -165,10 +165,10 @@ namespace dlib
if (node_states(0) != label[node_id]) if (node_states(0) != label[node_id])
loss = 1; loss = 1;
return fe_helpers::dot(weights, fe, node_id, node_states, sample) + loss; return fe_helpers::dot(weights, fe, sequence, node_states, node_id) + loss;
} }
const std::vector<sample_type>& sample; const std::vector<sample_type>& sequence;
const std::vector<unsigned long>& label; const std::vector<unsigned long>& label;
const feature_extractor& fe; const feature_extractor& fe;
const matrix<double,0,1>& weights; const matrix<double,0,1>& weights;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment