Commit f8534f99 authored by Davis King

Modified the svm_pegasos class so that the user can set independent lambda parameters for each class. This also slightly breaks backwards compatibility with the previous interface and changes the serialization format of this class.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403149
parent 3b6f3f20
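For context, here is a minimal usage sketch of the new per-class interface (not part of the patch; the kernel choice, gamma value, and sample values are illustrative assumptions):

#include <dlib/svm.h>
#include <iostream>

int main()
{
    typedef dlib::matrix<double,2,1> sample_type;
    typedef dlib::radial_basis_kernel<sample_type> kernel_type;

    // online pegasos trainer; defaults are as in the constructor shown below
    dlib::svm_pegasos<kernel_type> trainer;
    trainer.set_kernel(kernel_type(0.1));

    // regularize each class independently, e.g. for unbalanced data
    trainer.set_lambda_class1(0.0001); // applies to +1 labeled samples
    trainer.set_lambda_class2(0.01);   // applies to -1 labeled samples

    sample_type x;
    x = 1, 2;
    trainer.train(x, +1);   // online update driven by get_lambda_class1()
    x = -1, -2;
    trainer.train(x, -1);   // online update driven by get_lambda_class2()

    // evaluate the current decision function; the sign gives the predicted class
    std::cout << "decision value: " << trainer(x) << std::endl;
    return 0;
}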
@@ -39,7 +39,8 @@ namespace dlib
svm_pegasos (
) :
max_sv(40),
lambda(0.0001),
lambda_c1(0.0001),
lambda_c2(0.0001),
tau(0.01),
tolerance(0.01),
train_count(0),
@@ -55,17 +56,18 @@ namespace dlib
) :
max_sv(max_num_sv),
kernel(kernel_),
lambda(lambda_),
lambda_c1(lambda_),
lambda_c2(lambda_),
tau(0.01),
tolerance(tolerance_),
train_count(0),
w(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false)
{
// make sure requires clause is not broken
DLIB_ASSERT(lambda > 0 && tolerance > 0 && max_num_sv > 0,
DLIB_ASSERT(lambda_ > 0 && tolerance > 0 && max_num_sv > 0,
"\tsvm_pegasos::svm_pegasos(kernel,lambda,tolerance)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t lambda: " << lambda
<< "\n\t lambda_: " << lambda_
<< "\n\t max_num_sv: " << max_num_sv
);
}
@@ -128,16 +130,55 @@ namespace dlib
DLIB_ASSERT(0 < lambda_,
"\tvoid svm_pegasos::set_lambda(lambda_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t lambda: " << lambda_
<< "\n\t lambda_: " << lambda_
);
lambda = lambda_;
lambda_c1 = lambda_;
lambda_c2 = lambda_;
max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
clear();
}
void set_lambda_class1 (
scalar_type lambda_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < lambda_,
"\tvoid svm_pegasos::set_lambda_class1(lambda_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t lambda_: " << lambda_
);
lambda_c1 = lambda_;
max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
clear();
}
const scalar_type get_lambda (
void set_lambda_class2 (
scalar_type lambda_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < lambda_,
"\tvoid svm_pegasos::set_lambda_class2(lambda_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t lambda_: " << lambda_
);
lambda_c2 = lambda_;
max_wnorm = 1/std::sqrt(std::min(lambda_c1, lambda_c2));
clear();
}
const scalar_type get_lambda_class1 (
) const
{
return lambda;
return lambda_c1;
}
const scalar_type get_lambda_class2 (
) const
{
return lambda_c2;
}
const scalar_type get_tolerance (
@@ -169,6 +210,9 @@ namespace dlib
<< "\n\t invalid inputs were given to this function"
<< "\n\t y: " << y
);
const double lambda = (y==+1)? lambda_c1 : lambda_c2;
++train_count;
const scalar_type learning_rate = 1/(lambda*train_count);
@@ -180,7 +224,7 @@ namespace dlib
w.train(x, 1 - learning_rate*lambda, y*learning_rate);
scalar_type wnorm = std::sqrt(w.squared_norm());
scalar_type temp = (1/std::sqrt(lambda))/(wnorm);
scalar_type temp = max_wnorm/wnorm;
if (temp < 1)
w.scale_by(temp);
}
@@ -189,7 +233,8 @@ namespace dlib
w.scale_by(1 - learning_rate*lambda);
}
return learning_rate;
// return the current learning rate
return 1/(std::min(lambda_c1,lambda_c2)*train_count);
}
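For reference, the per-sample update implemented by this hunk is the usual Pegasos step with a class-dependent regularizer; the following is just a restatement of the code above, not part of the patch. With $\lambda_y = \lambda_{c1}$ when $y = +1$, $\lambda_y = \lambda_{c2}$ when $y = -1$, and $\eta_t = 1/(\lambda_y t)$ at training step $t$:

$$w \leftarrow (1 - \eta_t \lambda_y)\, w + \eta_t\, y\, k(x,\cdot) \quad \text{(when the margin is violated)}$$

followed by rescaling so that $\|w\| \le \text{max\_wnorm} = 1/\sqrt{\min(\lambda_{c1}, \lambda_{c2})}$; when the margin is not violated only the shrink $w \leftarrow (1 - \eta_t \lambda_y)\, w$ is applied. The value returned to the caller is the learning rate $1/(\min(\lambda_{c1}, \lambda_{c2})\, t)$.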
scalar_type operator() (
@@ -212,7 +257,9 @@ namespace dlib
{
exchange(max_sv, item.max_sv);
exchange(kernel, item.kernel);
exchange(lambda, item.lambda);
exchange(lambda_c1, item.lambda_c1);
exchange(lambda_c2, item.lambda_c2);
exchange(max_wnorm, item.max_wnorm);
exchange(tau, item.tau);
exchange(tolerance, item.tolerance);
exchange(train_count, item.train_count);
@@ -223,7 +270,9 @@ namespace dlib
{
serialize(item.max_sv, out);
serialize(item.kernel, out);
serialize(item.lambda, out);
serialize(item.lambda_c1, out);
serialize(item.lambda_c2, out);
serialize(item.max_wnorm, out);
serialize(item.tau, out);
serialize(item.tolerance, out);
serialize(item.train_count, out);
@@ -234,7 +283,9 @@ namespace dlib
{
deserialize(item.max_sv, in);
deserialize(item.kernel, in);
deserialize(item.lambda, in);
deserialize(item.lambda_c1, in);
deserialize(item.lambda_c2, in);
deserialize(item.max_wnorm, in);
deserialize(item.tau, in);
deserialize(item.tolerance, in);
deserialize(item.train_count, in);
@@ -245,7 +296,9 @@ namespace dlib
unsigned long max_sv;
kernel_type kernel;
scalar_type lambda;
scalar_type lambda_c1;
scalar_type lambda_c2;
scalar_type max_wnorm;
scalar_type tau;
scalar_type tolerance;
scalar_type train_count;
@@ -273,7 +326,8 @@ namespace dlib
)
{
dest.set_tolerance(source.get_tolerance());
dest.set_lambda(source.get_lambda());
dest.set_lambda_class1(source.get_lambda_class1());
dest.set_lambda_class2(source.get_lambda_class2());
dest.set_max_num_sv(source.get_max_num_sv());
}
@@ -61,7 +61,8 @@ namespace dlib
/*!
ensures
- this object is properly initialized
- #get_lambda() == 0.0001
- #get_lambda_class1() == 0.0001
- #get_lambda_class2() == 0.0001
- #get_tolerance() == 0.01
- #get_train_count() == 0
- #get_max_num_sv() == 40
@@ -80,7 +81,8 @@ namespace dlib
- max_num_sv > 0
ensures
- this object is properly initialized
- #get_lambda() == lambda_
- #get_lambda_class1() == lambda_
- #get_lambda_class2() == lambda_
- #get_tolerance() == tolerance_
- #get_kernel() == kernel_
- #get_train_count() == 0
@@ -94,28 +96,38 @@ namespace dlib
- #get_train_count() == 0
- clears out any memory of previous calls to train()
- doesn't change any of the algorithm parameters. I.e.
- #get_lambda() == get_lambda()
- #get_tolerance() == get_tolerance()
- #get_kernel() == get_kernel()
- #get_max_num_sv() == get_max_num_sv()
- #get_lambda_class1() == get_lambda_class1()
- #get_lambda_class2() == get_lambda_class2()
- #get_tolerance() == get_tolerance()
- #get_kernel() == get_kernel()
- #get_max_num_sv() == get_max_num_sv()
!*/
const scalar_type get_lambda (
const scalar_type get_lambda_class1 (
) const;
/*!
ensures
- returns the SVM regularization term. It is the parameter that
determines the trade off between trying to fit the training data
exactly or allowing more errors but hopefully improving the
generalization ability of the resulting classifier. Smaller
values encourage exact fitting while larger values may encourage
better generalization. It is also worth noting that the number
of iterations it takes for this algorithm to converge is
- returns the SVM regularization term for the +1 class. It is the
parameter that determines the trade off between trying to fit the
+1 training data exactly or allowing more errors but hopefully
improving the generalization ability of the resulting classifier.
Smaller values encourage exact fitting while larger values may
encourage better generalization. It is also worth noting that the
number of iterations it takes for this algorithm to converge is
proportional to 1/lambda. So smaller values of this term cause
the running time of this algorithm to increase. For more
information you should consult the paper referenced above.
!*/
const scalar_type get_lambda_class2 (
) const;
/*!
ensures
- returns the SVM regularization term for the -1 class. It has
the same properties as the get_lambda_class1() parameter except that
it applies to the -1 class.
!*/
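As a worked illustration of these parameters (the numbers are assumptions, not from the patch): with get_lambda_class1() == 0.0001 and get_lambda_class2() == 0.01, the trainer keeps $\|w\| \le 1/\sqrt{\min(0.0001,\,0.01)} = 1/\sqrt{0.0001} = 100$, and after $t$ calls to train() the reported learning rate is $1/(0.0001\, t)$, consistent with the note above that the number of iterations needed grows like $1/\lambda$.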
const scalar_type get_tolerance (
) const;
/*!
@@ -183,11 +195,42 @@ namespace dlib
requires
- lambda_ > 0
ensures
- #get_lambda() == tol
- #get_lambda_class1() == lambda_
- #get_lambda_class2() == lambda_
- #get_train_count() == 0
(i.e. clears any memory of previous training)
!*/
void set_lambda_class1 (
scalar_type lambda_
);
/*!
requires
- lambda_ > 0
ensures
- #get_lambda_class1() == lambda_
- #get_train_count() == 0
(i.e. clears any memory of previous training)
!*/
void set_lambda_class2 (
scalar_type lambda_
);
/*!
requires
- lambda_ > 0
ensures
- #get_lambda_class2() == lambda_
- #get_train_count() == 0
(i.e. clears any memory of previous training)
!*/
const scalar_type get_lambda_class1 (
) const;
const scalar_type get_lambda_class2 (
) const;
unsigned long get_train_count (
) const;
/*!
@@ -207,7 +250,7 @@ namespace dlib
- trains this svm using the given sample x and label y
- #get_train_count() == get_train_count() + 1
- returns the current learning rate
(i.e. 1/(get_lambda()*get_train_count()))
(i.e. 1/(get_train_count()*min(get_lambda_class1(),get_lambda_class2())) )
!*/
scalar_type operator() (
@@ -291,7 +334,8 @@ namespace dlib
ensures
- copies all the parameters from the source trainer to the dest trainer.
- #dest.get_tolerance() == source.get_tolerance()
- #dest.get_lambda() == source.get_lambda()
- #dest.get_lambda_class1() == source.get_lambda_class1()
- #dest.get_lambda_class2() == source.get_lambda_class2()
- #dest.get_max_num_sv() == source.get_max_num_sv()
!*/
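A small usage sketch of the settings-copying helper documented above (it is an assumption here that this is dlib's replicate_settings() overload for svm_pegasos; kernel_type is as in the earlier sketch):

dlib::svm_pegasos<kernel_type> a, b;
a.set_lambda_class1(0.0001);
a.set_lambda_class2(0.01);
dlib::replicate_settings(a, b); // b now has a's tolerance, per-class lambdas, and max_num_sv
// per the setter specs, this also resets any training b had accumulated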