Commit 0642f20a authored by Davis King's avatar Davis King

Cleaned up the pegasos code and filled out the spec file.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402906
parent f97644f3
......@@ -132,18 +132,12 @@ namespace dlib
const sample_type& x,
const scalar_type& y
)
/*!
requires
- y == 1 || y == -1
ensures
- trains this svm using the given sample x and label y
- returns the current learning rate
!*/
{
// make sure requires clause is not broken
DLIB_ASSERT(y == -1 || y == 1,
"\tdecision_function svm_pegasos::train(x,y)"
"\tscalar_type svm_pegasos::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t y: " << y
);
++train_count;
const scalar_type learning_rate = 1/(lambda*train_count);
......@@ -152,7 +146,7 @@ namespace dlib
if (y*w.inner_product(x) < 1)
{
// w = (1-learning_rate*lambda) + y*learning_rate*x
// compute: w = (1-learning_rate*lambda) + y*learning_rate*x
w.train(x, 1 - learning_rate*lambda, y*learning_rate);
scalar_type wnorm = std::sqrt(w.squared_norm());
......@@ -194,6 +188,26 @@ namespace dlib
exchange(w, item.w);
}
friend void serialize(const svm_pegasos& item, std::ostream& out)
{
serialize(item.kernel, out);
serialize(item.lambda, out);
serialize(item.tau, out);
serialize(item.tolerance, out);
serialize(item.train_count, out);
serialize(item.w, out);
}
friend void deserialize(svm_pegasos& item, std::istream& in)
{
deserialize(item.kernel, in);
deserialize(item.lambda, in);
deserialize(item.tau, in);
deserialize(item.tolerance, in);
deserialize(item.train_count, in);
deserialize(item.w, in);
}
private:
kernel_type kernel;
......@@ -205,6 +219,14 @@ namespace dlib
}; // end of class svm_pegasos
template <
typename K
>
void swap (
svm_pegasos<K>& a,
svm_pegasos<K>& b
) { a.swap(b); }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -239,7 +261,7 @@ namespace dlib
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < min_learning_rate_,
"\tsvm_pegasos_trainer::svm_pegasos_trainer(kernel,lambda)"
"\tbatch_trainer::batch_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t min_learning_rate_: " << min_learning_rate_
);
......@@ -251,6 +273,12 @@ namespace dlib
return trainer.get_kernel();
}
const scalar_type get_min_learning_rate (
) const
{
return min_learning_rate;
}
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
......@@ -277,17 +305,6 @@ namespace dlib
typedef typename decision_function<kernel_type>::sample_vector_type sample_vector_type;
typedef typename decision_function<kernel_type>::scalar_vector_type scalar_vector_type;
// make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
"\tdecision_function batch_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
);
dlib::rand::kernel_1a rnd;
trainer_type my_trainer(trainer);
......
......@@ -31,6 +31,17 @@ namespace dlib
Pegasos: Primal estimated sub-gradient solver for SVM (2007)
by Yoram Singer, Nathan Srebro
In ICML
This SVM training algorithm has two interesting properties. First, the
pegasos algorithm itself converges to the solution in an amount of time
unrelated to the size of the training set (in addition to being quite fast
to begin with). This makes it an appropriate algorithm for learning from
very large datasets. Second, this object uses the dlib::kcentroid object
to maintain a sparse approximation of the learned decision function.
This means that the number of support vectors in the resulting decision
function is also unrelated to the size of the dataset (in normal SVM
training algorithms, the number of support vectors grows approximately
linearly with the size of the training set).
!*/
public:
......@@ -76,6 +87,39 @@ namespace dlib
(e.g. clears out any memory of previous calls to train())
!*/
const scalar_type get_lambda (
) const;
/*!
ensures
- returns the SVM regularization term. It is the parameter that
determines the trade off between trying to fit the training data
exactly or allowing more errors but hopefully improving the
generalization ability of the resulting classifier. Smaller
values encourage exact fitting while larger values may encourage
better generalization. It is also worth noting that the number
of iterations it takes for this algorithm to converge is
proportional to 1/lambda. So smaller values of this term cause
the running time of this algorithm to increase. For more
information you should consult the paper referenced above.
!*/
const scalar_type get_tolerance (
) const;
/*!
ensures
- returns the tolerance used by the internal kcentroid object to
represent the learned decision function. Smaller values of this
tolerance will result in a more accurate representation of the
decision function but will use more support vectors.
!*/
const kernel_type get_kernel (
) const;
/*!
ensures
- returns the kernel used by this object
!*/
void set_kernel (
kernel_type k
);
......@@ -118,21 +162,6 @@ namespace dlib
since this object was constructed or last cleared.
!*/
const scalar_type get_lambda (
) const;
/*!
!*/
const scalar_type get_tolerance (
) const;
/*!
!*/
const kernel_type get_kernel (
) const;
/*!
!*/
scalar_type train (
const sample_type& x,
const scalar_type& y
......@@ -142,28 +171,78 @@ namespace dlib
- y == 1 || y == -1
ensures
- trains this svm using the given sample x and label y
- #get_train_count() == get_train_count() + 1
- returns the current learning rate
(i.e. 1/(get_lambda()*get_train_count()))
!*/
scalar_type operator() (
const sample_type& x
) const;
/*!
ensures
- classifies the given x sample using the decision function
this object has learned so far.
- if (x is a sample predicted have +1 label) then
- returns a number >= 0
- else
- returns a number < 0
!*/
const decision_function<kernel_type> get_decision_function (
) const;
/*!
ensures
- returns a decision function F that represents the function learned
by this object so far. I.e. it is the case that:
- for all x: F(x) == (*this)(x)
!*/
void swap (
svm_pegasos& item
);
/*!
ensures
- swaps *this and item
!*/
};
// ----------------------------------------------------------------------------------------
template <
typename kern_type
>
void swap(
svm_pegasos<kern_type>& a,
svm_pegasos<kern_type>& b
) { a.swap(b); }
/*!
provides a global swap function
!*/
template <
typename kern_type
>
void serialize (
const svm_pegasos<kern_type>& item,
std::ostream& out
);
/*!
provides serialization support for svm_pegasos objects
!*/
template <
typename kern_type
>
void deserialize (
svm_pegasos<kern_type>& item,
std::istream& in
);
/*!
provides serialization support for svm_pegasos objects
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -173,6 +252,18 @@ namespace dlib
>
class batch_trainer
{
/*!
REQUIREMENTS ON trainer_type
- trainer_type == some kind of online trainer object (e.g. svm_pegasos)
WHAT THIS OBJECT REPRESENTS
This is a trainer object that is meant to wrap online trainer objects
that create decision_functions. It turns an online learning algorithm
such as svm_pegasos into a batch learning object. This allows you to
use objects like svm_pegasos with functions (e.g. cross_validate_trainer)
that expect batch mode training objects.
!*/
public:
typedef typename trainer_type::kernel_type kernel_type;
typedef typename trainer_type::scalar_type scalar_type;
......@@ -184,19 +275,43 @@ namespace dlib
batch_trainer (
);
/*!
ensures
- This object is in an uninitialized state. You must
construct a real one with the other constructor and assign it
to this instance before you use this object.
!*/
batch_trainer (
const trainer_type& trainer_,
const trainer_type& online_trainer,
const scalar_type min_learning_rate_,
bool verbose_
);
/*!
requires
- min_learning_rate_ > 0
ensures
- returns a batch trainer object that uses the given online_trainer object
to train a decision function.
- #get_kernel() == trainer.get_kernel()
- #get_min_learning_rate() == min_learning_rate_
- if (verbose_ == true) then
- this object will output status messages to standard out while
training is under way.
!*/
const scalar_type get_min_learning_rate (
) const;
/*!
ensures
- returns the min learning rate that the online trainer must reach
before this object considers training to be complete.
!*/
const kernel_type get_kernel (
) const;
/*!
ensures
- returns the kernel used by this trainer object
!*/
template <
......@@ -208,6 +323,14 @@ namespace dlib
const in_scalar_vector_type& y
) const;
/*!
ensures
- trains and returns a decision_function using the trainer that was
supplied to this object's constructor.
- training continues until the online training object indicates that
its learning rate has dropped below get_min_learning_rate().
throws
- std::bad_alloc
- any exceptions thrown by the trainer_type object
!*/
};
......@@ -222,6 +345,13 @@ namespace dlib
const typename trainer_type::scalar_type min_learning_rate = 0.1
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false); }
/*!
requires
- min_learning_rate > 0
- trainer_type == some kind of online trainer object that creates decision_function
objects (e.g. svm_pegasos)
ensures
- returns a batch_trainer object that has been instantiated with the
given arguments.
!*/
// ----------------------------------------------------------------------------------------
......@@ -234,6 +364,13 @@ namespace dlib
const typename trainer_type::scalar_type min_learning_rate = 0.1
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true); }
/*!
requires
- min_learning_rate > 0
- trainer_type == some kind of online trainer object that creates decision_function
objects (e.g. svm_pegasos)
ensures
- returns a batch_trainer object that has been instantiated with the
given arguments (and is verbose).
!*/
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment