Commit 2932d6d3 authored by Davis King's avatar Davis King

Fixed a minor bug and did some cleanup

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404014
parent 00325e75
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
//#include "local/make_label_kernel_matrix.h" //#include "local/make_label_kernel_matrix.h"
#include "svm_c_trainer_abstract.h" #include "svm_c_trainer_abstract.h"
#include "calculate_rho_and_b.h"
#include <cmath> #include <cmath>
#include <limits> #include <limits>
#include <sstream> #include <sstream>
...@@ -236,8 +235,8 @@ namespace dlib ...@@ -236,8 +235,8 @@ namespace dlib
alpha, alpha,
eps); eps);
scalar_type rho, b; scalar_type b;
calculate_rho_and_b(y,alpha,solver.get_gradient(),rho,b); calculate_b(y,alpha,solver.get_gradient(),Cpos,Cneg,b);
alpha = pointwise_multiply(alpha,y); alpha = pointwise_multiply(alpha,y);
// count the number of support vectors // count the number of support vectors
...@@ -263,11 +262,80 @@ namespace dlib ...@@ -263,11 +262,80 @@ namespace dlib
} }
// now return the decision function // now return the decision function
return decision_function<K> (sv_alpha, b*rho, kernel_function, support_vectors); return decision_function<K> (sv_alpha, b, kernel_function, support_vectors);
} }
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template <
typename scalar_vector_type,
typename scalar_vector_type2
>
void calculate_b(
const scalar_vector_type2& y,
const scalar_vector_type& alpha,
const scalar_vector_type& df,
const scalar_type& Cpos,
const scalar_type& Cneg,
scalar_type& b
) const
{
using namespace std;
long num_free = 0;
scalar_type sum_free = 0;
scalar_type upper_bound = -numeric_limits<scalar_type>::infinity();
scalar_type lower_bound = numeric_limits<scalar_type>::infinity();
for(long i = 0; i < alpha.nr(); ++i)
{
if(y(i) == 1)
{
if(alpha(i) == Cpos)
{
if (df(i) > upper_bound)
upper_bound = df(i);
}
else if(alpha(i) == 0)
{
if (df(i) < lower_bound)
lower_bound = df(i);
}
else
{
++num_free;
sum_free += df(i);
}
}
else
{
if(alpha(i) == Cneg)
{
if (-df(i) > upper_bound)
upper_bound = -df(i);
}
else if(alpha(i) == 0)
{
if (-df(i) < lower_bound)
lower_bound = -df(i);
}
else
{
++num_free;
sum_free -= df(i);
}
}
}
if(num_free > 0)
b = sum_free/num_free;
else
b = (upper_bound+lower_bound)/2;
}
// ------------------------------------------------------------------------------------
kernel_type kernel_function; kernel_type kernel_function;
scalar_type Cpos; scalar_type Cpos;
scalar_type Cneg; scalar_type Cneg;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment