Commit 11ed2630 authored by Davis King's avatar Davis King

Added more docs and tests.

parent 62cd23ce
......@@ -46,6 +46,7 @@ namespace dlib
std::vector<feature>& feats
) const
{
DLIB_ASSERT(max_num_feats() != 0);
num = std::min(num, num_feats);
feats.clear();
......@@ -66,6 +67,7 @@ namespace dlib
const feature& f
) const
{
DLIB_ASSERT(max_num_feats() != 0);
return item(f);
}
......@@ -290,6 +292,7 @@ namespace dlib
size_t num
)
{
DLIB_CASSERT(num > 0);
num_trees = num;
}
......@@ -297,6 +300,7 @@ namespace dlib
double frac
)
{
DLIB_CASSERT(0 < frac && frac <= 1);
feature_subsampling_frac = frac;
}
......@@ -344,7 +348,7 @@ namespace dlib
trained_function_type train (
const std::vector<sample_type>& x,
const std::vector<double>& y,
std::vector<double>& oob_values // predicted y, basically like LOO-CV
std::vector<double>& oob_values
) const
{
return do_train(x,y,oob_values,true);
......@@ -355,7 +359,7 @@ namespace dlib
trained_function_type do_train (
const std::vector<sample_type>& x,
const std::vector<double>& y,
std::vector<double>& oob_values, // predicted y, basically like LOO-CV
std::vector<double>& oob_values,
bool compute_oob_values
) const
{
......
......@@ -21,7 +21,7 @@ namespace dlib
This particular feature extract does almost nothing since it works on
vectors in R^n and simply selects elements from each vector. However, the
tools below are templated and allow you to design your own feature extracts
tools below are templated and allow you to design your own feature extractors
that operate on whatever object types you create. So for example, maybe
you want to perform regression on images rather than vectors. Moreover,
your feature extraction could be more complex. Maybe you are selecting
......@@ -37,7 +37,9 @@ namespace dlib
THREAD SAFETY
It is safe to call const members of this object from multiple threads.
It is safe to call const members of this object from multiple threads. ANY
USER DEFINED FEATURE EXTRACTORS MUST ALSO MEET THIS GUARONTEE AS WELL SINCE
IT IS ASSUMED BY THE RANDOM FOREST TRAINING ROUTINES.
!*/
public:
......@@ -67,7 +69,10 @@ namespace dlib
the training vectors.
- #max_num_feats() == x[0].size()
(In general, setup() sets max_num_feats() to some non-zero value so that
the other methods of this object can then be called.)
the other methods of this object can then be called. The point of setup()
is to allow a feature extractor to gather whatever statistics it needs from
training data. That is, more complex feature extraction strategies my
themselves be trained from data.)
!*/
void get_random_features (
......@@ -79,11 +84,16 @@ namespace dlib
requires
- max_num_feats() != 0
ensures
- This function pulls out num randomly selected features from sample_type. If
sample_type was a simple object like a dense vector then the features could just
be integer indices into the vector. But for other objects it might be
something more complex.
- #feats.size() == min(num, max_num_feats())
- This function randomly identifies num features and stores them into feats.
These feature objects can then be used with extract_feature_value() to
obtain a value from any particular sample_type object. This value is the
"feature value" used by a decision tree algorithm to deice how to split
and traverse trees.
- The above two conditions define the behavior of get_random_features() in
general. For this specific implementation of the feature extraction interface
this function selects num integer values from the range [0, max_num_feats()),
without replacement. These values are stored into feats.
!*/
double extract_feature_value (
......@@ -104,9 +114,10 @@ namespace dlib
) const;
/*!
ensures
- returns the number of distinct features this object might extract.
- returns the number of distinct features this object might extract. That is,
a feature extractor essentially defines a mapping from sample_type objects to
vectors in R^max_num_feats().
!*/
};
void serialize(const dense_feature_extractor& item, std::ostream& out);
......@@ -150,6 +161,10 @@ namespace dlib
class random_forest_regression_function
{
/*!
REQUIREMENTS ON feature_extractor
feature_extractor must be dense_feature_extractor or a type with a
compatible interface.
WHAT THIS OBJECT REPRESENTS
This object represents a regression forest. This is a collection of
decision trees that take an object as input and each vote on a real value
......@@ -181,6 +196,7 @@ namespace dlib
- for all valid i:
- leaves[i].size() > 0
- trees[i].size()+leaves[i].size() > the maximal left or right index values in trees[i].
(i.e. each left or right value must index to some existing internal tree node or leaf node).
ensures
- #get_internal_tree_nodes() == trees_
- #get_tree_leaves() == leaves_
......@@ -244,8 +260,20 @@ namespace dlib
class random_forest_regression_trainer
{
/*!
WHAT THIS OBJECT REPRESENTS
REQUIREMENTS ON feature_extractor
feature_extractor must be dense_feature_extractor or a type with a
compatible interface.
WHAT THIS OBJECT REPRESENTS
This object implements Breiman's classic random forest regression
algorithm. The algorithm learns to map objects, nominally vectors in R^n,
into the reals. It essentially optimizes the mean squared error by fitting
a bunch of decision trees, each of which vote on the output value of the
regressor. The final prediction is obtained by averaging all the
predictions.
For more information on the algorithm see:
Breiman, Leo. "Random forests." Machine learning 45.1 (2001): 5-32.
!*/
public:
......@@ -268,17 +296,35 @@ namespace dlib
const feature_extractor_type& get_feature_extractor (
) const;
/*!
ensures
- returns the feature extractor used when train() is invoked.
!*/
void set_feature_extractor (
const feature_extractor_type& feat_extractor
);
/*!
ensures
- #get_feature_extractor() == feat_extractor
!*/
void set_seed (
const std::string& seed
);
/*!
ensures
- #get_random_seed() == seed
!*/
const std::string& get_random_seed (
) const;
/*!
ensures
- A central part of this algorithm is random selection of both training
samples and features. This function returns the seed used to initialized
the random number generator used for these random selections.
!*/
size_t get_num_trees (
) const;
......@@ -313,7 +359,7 @@ namespace dlib
ensures
- When we build trees, at each node we don't look at all the available
features. We consider only get_feature_subsampling_frac() fraction of
them at random.
them, selected at random.
!*/
void set_min_samples_per_leaf (
......@@ -331,7 +377,8 @@ namespace dlib
/*!
ensures
- When building trees, each leaf node in a tree will contain at least
get_min_samples_per_leaf() samples.
get_min_samples_per_leaf() samples. This means that the output votes of
each tree are averages of at least get_min_samples_per_leaf() y values.
!*/
void be_verbose (
......@@ -349,17 +396,60 @@ namespace dlib
- this object will not print anything to standard out
!*/
trained_function_type train (
random_forest_regression_function<feature_extractor> train (
const std::vector<sample_type>& x,
const std::vector<double>& y
const std::vector<double>& y,
std::vector<double>& oob_values
) const;
/*!
requires
- x.size() == y.size()
- x.size() > 0
- Running following code:
auto fe = get_feature_extractor()
fe.setup(x,y);
Must be valid and result in fe.max_num_feats() != 0
ensures
- This function fits a regression forest to the given training data. The
goal being to regress x to y in the mean squared sense. It therefore
fits regression trees and returns the resulting random_forest_regression_function
RF, which will have the following properties:
- RF.get_num_trees() == get_num_trees()
- for all valid i:
- RF(x[i]) should output a value close to y[i]
- RF.get_feature_extractor() will be a copy of this->get_feature_extractor()
that has been configured by a call the feature extractor's setup() routine.
To run the algorithm we need to use a feature extractor. We obtain a
valid feature extractor by making a copy of get_feature_extractor(), then
invoking setup(x,y) on it. This feature extractor is what is used to fit
the trees and is also the feature extractor stored in the returned random
forest.
- #oob_values.size() == y.size()
- for all valid i:
- #oob_values[i] == the "out of bag" prediction for y[i]. It is
calculated by computing the average output from trees not trained on
y[i]. This is similar to a leave-one-out cross-validation prediction
of y[i] and can be used to estimate the generalization error of the
regression forest.
- Training uses all the available CPU cores.
!*/
trained_function_type train (
random_forest_regression_function<feature_extractor> train (
const std::vector<sample_type>& x,
const std::vector<double>& y,
std::vector<double>& oob_values // predicted y, basically like LOO-CV
const std::vector<double>& y
) const;
/*!
requires
- x.size() == y.size()
- x.size() > 0
- Running following code:
auto fe = get_feature_extractor()
fe.setup(x,y);
Must be valid and result in fe.max_num_feats() != 0
ensures
- This function is identical to train(x,y,oob_values) except that the
oob_values are not calculated.
!*/
};
// ----------------------------------------------------------------------------------------
......
......@@ -41,6 +41,8 @@ namespace
{
istringstream sin(get_decoded_string());
print_spinner();
typedef matrix<double,0,1> sample_type;
std::vector<double> labels;
std::vector<sample_type> samples;
......@@ -61,7 +63,7 @@ namespace
auto result = test_regression_function(df, samples, labels);
// train: 2.239 0.987173 0.970669 1.1399
dlog << LINFO << "train: " << trans(result);
dlog << LINFO << "train: " << result;
DLIB_TEST_MSG(result(0) < 2.3, result(0));
running_stats<double> rs;
......@@ -69,6 +71,18 @@ namespace
rs.add(std::pow(oobs[i]-labels[i],2.0));
dlog << LINFO << "OOB MSE: "<< rs.mean();
DLIB_TEST_MSG(rs.mean() < 10.2, rs.mean());
print_spinner();
stringstream ss;
serialize(df, ss);
decltype(df) df2;
deserialize(df2, ss);
DLIB_TEST(df2.get_num_trees() == 1000);
result = test_regression_function(df2, samples, labels);
// train: 2.239 0.987173 0.970669 1.1399
dlog << LINFO << "serialized train results: " << result;
DLIB_TEST_MSG(result(0) < 2.3, result(0));
}
} a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment