Commit 3fc43218 authored by Davis King's avatar Davis King

Refactored code a little and added more comments to spec.

parent fcfafe67
......@@ -43,13 +43,13 @@ namespace dlib
double dot(
const matrix_exp<EXP>& lambda,
const feature_extractor& fe,
unsigned long position,
const matrix_exp<EXP2>& label_states,
const std::vector<sample_type>& x
const std::vector<sample_type>& sequence,
const matrix_exp<EXP2>& candidate_labeling,
unsigned long position
)
{
dot_functor<EXP> dot(lambda);
fe.get_features(dot, position, label_states, x);
fe.get_features(dot, sequence, candidate_labeling, position);
return dot.value;
}
......@@ -79,7 +79,7 @@ namespace dlib
const feature_extractor& fe_,
const matrix<double,0,1>& weights_
) :
x(x_),
sequence(x_),
fe(fe_),
weights(weights_)
{
......@@ -88,7 +88,7 @@ namespace dlib
unsigned long number_of_nodes(
) const
{
return x.size();
return sequence.size();
}
template <
......@@ -99,13 +99,13 @@ namespace dlib
const matrix_exp<EXP>& node_states
) const
{
if (fe.reject_labeling(node_id, node_states, x))
if (fe.reject_labeling(sequence, node_states, node_id))
return -std::numeric_limits<double>::infinity();
return fe_helpers::dot(weights, fe, node_id, node_states, x);
return fe_helpers::dot(weights, fe, sequence, node_states, node_id);
}
const sample_sequence_type& x;
const sample_sequence_type& sequence;
const feature_extractor& fe;
const matrix<double,0,1>& weights;
};
......
......@@ -10,6 +10,97 @@
namespace dlib
{
// ----------------------------------------------------------------------------------------
class example_feature_extractor
{
/*!
WHAT THIS OBJECT REPRESENTS
!*/
public:
typedef word_type sample_type;
example_feature_extractor (
);
unsigned long num_features (
) const;
unsigned long order(
) const;
unsigned long num_labels(
) const;
template <typename EXP>
bool reject_labeling (
const std::vector<sample_type>& sequence,
const matrix_exp<EXP>& candidate_labeling,
unsigned long position
) const;
/*!
requires
- EXP::type == unsigned long
(i.e. candidate_labeling contains unsigned longs)
- position < sequence.size()
- candidate_labeling.size() == min(position, order) + 1
- is_vector(candidate_labeling) == true
- max(candidate_labeling) < num_labels()
ensures
- if (the given candidate_labeling for sequence[position] is
always the wrong labeling) then
- returns true
(note that reject_labeling() is just an optional tool to allow
you to overrule the learning algorithm. You don't have to use
it. So if you prefer you can set reject_labeling() to always
return false.)
- else
- returns false
!*/
template <typename feature_setter, typename EXP>
void get_features (
feature_setter& set_feature,
const std::vector<sample_type>& sequence,
const matrix_exp<EXP>& candidate_labeling,
unsigned long position
) const;
/*!
requires
- EXP::type == unsigned long
(i.e. candidate_labeling contains unsigned longs)
- position < sequence.size()
- candidate_labeling.size() == min(position, order) + 1
- is_vector(candidate_labeling) == true
- max(candidate_labeling) < num_labels()
- set_feature is a function object which allows expressions of the form:
- set_features((unsigned long)feature_index, (double)feature_value);
- set_features((unsigned long)feature_index);
ensures
!*/
};
// ----------------------------------------------------------------------------------------
void serialize(
const example_feature_extractor& item,
std::ostream& out
);
/*!
provides serialization support
!*/
void deserialize(
example_feature_extractor& item,
std::istream& in
);
/*!
provides deserialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
......@@ -18,6 +109,10 @@ namespace dlib
class sequence_labeler
{
/*!
REQUIREMENTS ON feature_extractor
It must be an object that implements an interface compatible with
the example_feature_extractor discussed above.
WHAT THIS OBJECT REPRESENTS
!*/
......@@ -30,18 +125,29 @@ namespace dlib
sequence_labeler() {}
sequence_labeler(
const feature_extractor& fe_,
const matrix<double,0,1>& weights_
const feature_extractor& fe,
const matrix<double,0,1>& weights
);
/*!
requires
- fe.num_features() == weights.size()
ensures
- #get_feature_extractor() == fe
- #get_weights() == weights
!*/
const feature_extractor& get_feature_extractor (
) const;
const matrix<double,0,1>& get_weights (
) const;
/*!
ensures
- returns a vector of length get_feature_extractor().num_features()
!*/
unsigned long num_labels (
) const { return fe.num_labels(); }
) const { return get_feature_extractor().num_labels(); }
labeled_sequence_type operator() (
const sample_sequence_type& x
......
......@@ -26,7 +26,7 @@ namespace dlib
typedef sequence_labeler<feature_extractor> trained_function_type;
structural_sequence_labeling_trainer (
explicit structural_sequence_labeling_trainer (
const feature_extractor& fe_
) : fe(fe_)
{}
......
......@@ -29,7 +29,7 @@ namespace dlib
typedef sequence_labeler<feature_extractor> trained_function_type;
structural_sequence_labeling_trainer (
explicit structural_sequence_labeling_trainer (
const feature_extractor& fe_
) : fe(fe_)
{}
......
......@@ -48,13 +48,13 @@ namespace dlib
void get_feature_vector(
std::vector<std::pair<unsigned long, double> >& feats,
const feature_extractor& fe,
unsigned long position,
const matrix_exp<EXP2>& label_states,
const std::vector<sample_type>& x
const std::vector<sample_type>& sequence,
const matrix_exp<EXP2>& candidate_labeling,
unsigned long position
)
{
get_feats_functor funct(feats);
fe.get_features(funct, position, label_states, x);
fe.get_features(funct, sequence,candidate_labeling, position);
}
}
......@@ -108,12 +108,12 @@ namespace dlib
const int order = fe.order();
matrix<unsigned long,0,1> label_states;
matrix<unsigned long,0,1> candidate_labeling;
for (unsigned long i = 0; i < sample.size(); ++i)
{
label_states = rowm(vector_to_matrix(label), range(i, std::max((int)i-order,0)));
candidate_labeling = rowm(vector_to_matrix(label), range(i, std::max((int)i-order,0)));
fe_helpers::get_feature_vector(psi,fe,i,label_states, sample);
fe_helpers::get_feature_vector(psi,fe,sample,candidate_labeling, i);
}
}
......@@ -132,12 +132,12 @@ namespace dlib
unsigned long num_states() const { return fe.num_labels(); }
map_prob(
const std::vector<sample_type>& sample_,
const std::vector<sample_type>& sequence_,
const std::vector<unsigned long>& label_,
const feature_extractor& fe_,
const matrix<double,0,1>& weights_
) :
sample(sample_),
sequence(sequence_),
label(label_),
fe(fe_),
weights(weights_)
......@@ -147,7 +147,7 @@ namespace dlib
unsigned long number_of_nodes(
) const
{
return sample.size();
return sequence.size();
}
template <
......@@ -165,10 +165,10 @@ namespace dlib
if (node_states(0) != label[node_id])
loss = 1;
return fe_helpers::dot(weights, fe, node_id, node_states, sample) + loss;
return fe_helpers::dot(weights, fe, sequence, node_states, node_id) + loss;
}
const std::vector<sample_type>& sample;
const std::vector<sample_type>& sequence;
const std::vector<unsigned long>& label;
const feature_extractor& fe;
const matrix<double,0,1>& weights;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment