Commit f8534f99 authored by Davis King

Modified the svm_pegasos class so that the user can set independent lambda
parameters for each class.  This also slightly breaks backwards compatibility
with the previous interface and changes the serialization format of this class.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403149
parent 3b6f3f20
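
A minimal usage sketch of the new per-class interface follows (not part of the commit; the sample type, kernel choice, and lambda values are illustrative assumptions, while the set_lambda_class1()/set_lambda_class2(), train(), and operator() calls are the svm_pegasos API touched by this change):

    #include <iostream>
    #include <dlib/svm.h>

    using namespace dlib;

    int main()
    {
        // Illustrative types: any dlib kernel/sample type would do here.
        typedef matrix<double, 2, 1> sample_type;
        typedef radial_basis_kernel<sample_type> kernel_type;

        // Default-constructed trainer: both lambdas start at 0.0001.
        svm_pegasos<kernel_type> trainer;

        // New in this commit: regularization can now differ per class, e.g.
        // to regularize the -1 class more heavily than the +1 class.
        // The numbers below are made up for illustration.
        trainer.set_lambda_class1(0.00001);  // lambda applied to +1 samples
        trainer.set_lambda_class2(0.001);    // lambda applied to -1 samples

        // Online training proceeds as before, one (sample, label) pair at a time.
        sample_type samp;
        samp = 1.0, 2.0;
        double rate = trainer.train(samp, +1);  // returns the current learning rate

        samp = -1.0, -2.5;
        rate = trainer.train(samp, -1);

        // Evaluate the current decision function on the last sample.
        std::cout << "rate: " << rate << "  decision value: " << trainer(samp) << std::endl;
        return 0;
    }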
@@ -39,7 +39,8 @@ namespace dlib
        svm_pegasos (
        ) :
            max_sv(40),
-           lambda(0.0001),
+           lambda_c1(0.0001),
+           lambda_c2(0.0001),
            tau(0.01),
            tolerance(0.01),
            train_count(0),
@@ -55,17 +56,18 @@ namespace dlib
        ) :
            max_sv(max_num_sv),
            kernel(kernel_),
-           lambda(lambda_),
+           lambda_c1(lambda_),
+           lambda_c2(lambda_),
            tau(0.01),
            tolerance(tolerance_),
            train_count(0),
            w(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false)
        {
            // make sure requires clause is not broken
-           DLIB_ASSERT(lambda > 0 && tolerance > 0 && max_num_sv > 0,
+           DLIB_ASSERT(lambda_ > 0 && tolerance > 0 && max_num_sv > 0,
                "\tsvm_pegasos::svm_pegasos(kernel,lambda,tolerance)"
                << "\n\t invalid inputs were given to this function"
-               << "\n\t lambda: " << lambda
+               << "\n\t lambda_: " << lambda_
                << "\n\t max_num_sv: " << max_num_sv
                );
        }
@@ -128,16 +130,55 @@ namespace dlib
            DLIB_ASSERT(0 < lambda_,
                "\tvoid svm_pegasos::set_lambda(lambda_)"
                << "\n\t invalid inputs were given to this function"
-               << "\n\t lambda: " << lambda_
+               << "\n\t lambda_: " << lambda_
                );

-           lambda = lambda_;
+           lambda_c1 = lambda_;
+           lambda_c2 = lambda_;
+           max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
+           clear();
+       }
+
+       void set_lambda_class1 (
+           scalar_type lambda_
+       )
+       {
+           // make sure requires clause is not broken
+           DLIB_ASSERT(0 < lambda_,
+               "\tvoid svm_pegasos::set_lambda_class1(lambda_)"
+               << "\n\t invalid inputs were given to this function"
+               << "\n\t lambda_: " << lambda_
+               );
+
+           lambda_c1 = lambda_;
+           max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
            clear();
        }

-       const scalar_type get_lambda (
+       void set_lambda_class2 (
+           scalar_type lambda_
+       )
+       {
+           // make sure requires clause is not broken
+           DLIB_ASSERT(0 < lambda_,
+               "\tvoid svm_pegasos::set_lambda_class2(lambda_)"
+               << "\n\t invalid inputs were given to this function"
+               << "\n\t lambda_: " << lambda_
+               );
+
+           lambda_c2 = lambda_;
+           max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
+           clear();
+       }
+
+       const scalar_type get_lambda_class1 (
        ) const
        {
-           return lambda;
+           return lambda_c1;
+       }
+
+       const scalar_type get_lambda_class2 (
+       ) const
+       {
+           return lambda_c2;
        }

        const scalar_type get_tolerance (
@@ -169,6 +210,9 @@ namespace dlib
                << "\n\t invalid inputs were given to this function"
                << "\n\t y: " << y
                );

+           const double lambda = (y==+1)? lambda_c1 : lambda_c2;
+
            ++train_count;
            const scalar_type learning_rate = 1/(lambda*train_count);
@@ -180,7 +224,7 @@ namespace dlib
                w.train(x, 1 - learning_rate*lambda, y*learning_rate);

                scalar_type wnorm = std::sqrt(w.squared_norm());
-               scalar_type temp = (1/std::sqrt(lambda))/(wnorm);
+               scalar_type temp = max_wnorm/wnorm;
                if (temp < 1)
                    w.scale_by(temp);
            }
@@ -189,7 +233,8 @@ namespace dlib
                w.scale_by(1 - learning_rate*lambda);
            }

-           return learning_rate;
+           // return the current learning rate
+           return 1/(std::min(lambda_c1,lambda_c2)*train_count);
        }

        scalar_type operator() (
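
A note on the two new members used above (a reading of the change, not text from the commit): in the Pegasos algorithm the optimizer of the SVM objective is known to lie in a ball of radius 1/sqrt(lambda), which is why each update projects w back onto that ball. With separate per-class regularizers the code caches the looser of the two bounds,

    \|w^{*}\| \le \frac{1}{\sqrt{\lambda}}
    \qquad\Longrightarrow\qquad
    \text{max\_wnorm} = \frac{1}{\sqrt{\min(\lambda_{c1},\,\lambda_{c2})}}

and the same min(lambda_c1, lambda_c2) appears in the learning rate that train() now reports.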
@@ -212,7 +257,9 @@ namespace dlib
        {
            exchange(max_sv, item.max_sv);
            exchange(kernel, item.kernel);
-           exchange(lambda, item.lambda);
+           exchange(lambda_c1, item.lambda_c1);
+           exchange(lambda_c2, item.lambda_c2);
+           exchange(max_wnorm, item.max_wnorm);
            exchange(tau, item.tau);
            exchange(tolerance, item.tolerance);
            exchange(train_count, item.train_count);
@@ -223,7 +270,9 @@ namespace dlib
        {
            serialize(item.max_sv, out);
            serialize(item.kernel, out);
-           serialize(item.lambda, out);
+           serialize(item.lambda_c1, out);
+           serialize(item.lambda_c2, out);
+           serialize(item.max_wnorm, out);
            serialize(item.tau, out);
            serialize(item.tolerance, out);
            serialize(item.train_count, out);
@@ -234,7 +283,9 @@ namespace dlib
        {
            deserialize(item.max_sv, in);
            deserialize(item.kernel, in);
-           deserialize(item.lambda, in);
+           deserialize(item.lambda_c1, in);
+           deserialize(item.lambda_c2, in);
+           deserialize(item.max_wnorm, in);
            deserialize(item.tau, in);
            deserialize(item.tolerance, in);
            deserialize(item.train_count, in);
@@ -245,7 +296,9 @@ namespace dlib
        unsigned long max_sv;
        kernel_type kernel;
-       scalar_type lambda;
+       scalar_type lambda_c1;
+       scalar_type lambda_c2;
+       scalar_type max_wnorm;
        scalar_type tau;
        scalar_type tolerance;
        scalar_type train_count;
@@ -273,7 +326,8 @@ namespace dlib
    )
    {
        dest.set_tolerance(source.get_tolerance());
-       dest.set_lambda(source.get_lambda());
+       dest.set_lambda_class1(source.get_lambda_class1());
+       dest.set_lambda_class2(source.get_lambda_class2());
        dest.set_max_num_sv(source.get_max_num_sv());
    }
@@ -61,7 +61,8 @@ namespace dlib
        /*!
            ensures
                - this object is properly initialized
-               - #get_lambda() == 0.0001
+               - #get_lambda_class1() == 0.0001
+               - #get_lambda_class2() == 0.0001
                - #get_tolerance() == 0.01
                - #get_train_count() == 0
                - #get_max_num_sv() == 40
@@ -80,7 +81,8 @@ namespace dlib
                - max_num_sv > 0
            ensures
                - this object is properly initialized
-               - #get_lambda() == lambda_
+               - #get_lambda_class1() == lambda_
+               - #get_lambda_class2() == lambda_
                - #get_tolerance() == tolerance_
                - #get_kernel() == kernel_
                - #get_train_count() == 0
@@ -94,28 +96,38 @@ namespace dlib
                - #get_train_count() == 0
                - clears out any memory of previous calls to train()
                - doesn't change any of the algorithm parameters.  I.e.
-                   - #get_lambda() == get_lambda()
+                   - #get_lambda_class1() == get_lambda_class1()
+                   - #get_lambda_class2() == get_lambda_class2()
                    - #get_tolerance() == get_tolerance()
                    - #get_kernel() == get_kernel()
                    - #get_max_num_sv() == get_max_num_sv()
        !*/

-       const scalar_type get_lambda (
+       const scalar_type get_lambda_class1 (
        ) const;
        /*!
            ensures
-               - returns the SVM regularization term.  It is the parameter that
-                 determines the trade off between trying to fit the training data
-                 exactly or allowing more errors but hopefully improving the
-                 generalization ability of the resulting classifier.  Smaller
-                 values encourage exact fitting while larger values may encourage
-                 better generalization.  It is also worth noting that the number
-                 of iterations it takes for this algorithm to converge is
+               - returns the SVM regularization term for the +1 class.  It is the
+                 parameter that determines the trade off between trying to fit the
+                 +1 training data exactly or allowing more errors but hopefully
+                 improving the generalization ability of the resulting classifier.
+                 Smaller values encourage exact fitting while larger values may
+                 encourage better generalization.  It is also worth noting that the
+                 number of iterations it takes for this algorithm to converge is
                  proportional to 1/lambda.  So smaller values of this term cause
                  the running time of this algorithm to increase.  For more
                  information you should consult the paper referenced above.
        !*/

+       const scalar_type get_lambda_class2 (
+       ) const;
+       /*!
+           ensures
+               - returns the SVM regularization term for the -1 class.  It has
+                 the same properties as the get_lambda_class1() parameter except
+                 that it applies to the -1 class.
+       !*/
+
        const scalar_type get_tolerance (
        ) const;
        /*!
@@ -183,11 +195,42 @@ namespace dlib
            requires
                - lambda_ > 0
            ensures
-               - #get_lambda() == tol
+               - #get_lambda_class1() == lambda_
+               - #get_lambda_class2() == lambda_
                - #get_train_count() == 0
                  (i.e. clears any memory of previous training)
        !*/

+       void set_lambda_class1 (
+           scalar_type lambda_
+       );
+       /*!
+           requires
+               - lambda_ > 0
+           ensures
+               - #get_lambda_class1() == lambda_
+               - #get_train_count() == 0
+                 (i.e. clears any memory of previous training)
+       !*/
+
+       void set_lambda_class2 (
+           scalar_type lambda_
+       );
+       /*!
+           requires
+               - lambda_ > 0
+           ensures
+               - #get_lambda_class2() == lambda_
+               - #get_train_count() == 0
+                 (i.e. clears any memory of previous training)
+       !*/
+
+       const scalar_type get_lambda_class1 (
+       ) const;
+
+       const scalar_type get_lambda_class2 (
+       ) const;
+
        unsigned long get_train_count (
        ) const;
        /*!
@@ -207,7 +250,7 @@ namespace dlib
                - trains this svm using the given sample x and label y
                - #get_train_count() == get_train_count() + 1
                - returns the current learning rate
-                 (i.e. 1/(get_lambda()*get_train_count()))
+                 (i.e. 1/(get_train_count()*min(get_lambda_class1(),get_lambda_class2())) )
        !*/

        scalar_type operator() (
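
For a quick numeric illustration of the documented learning rate (values invented): with get_lambda_class1() == 1e-3, get_lambda_class2() == 1e-2, and get_train_count() == 100, a call to train() reports

    \frac{1}{100 \cdot \min(10^{-3},\,10^{-2})} = \frac{1}{100 \cdot 10^{-3}} = 10.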
@@ -291,7 +334,8 @@ namespace dlib
        ensures
            - copies all the parameters from the source trainer to the dest trainer.
            - #dest.get_tolerance() == source.get_tolerance()
-           - #dest.get_lambda() == source.get_lambda()
+           - #dest.get_lambda_class1() == source.get_lambda_class1()
+           - #dest.get_lambda_class2() == source.get_lambda_class2()
            - #dest.get_max_num_sv() == source.get_max_num_sv()
    !*/