Commit 230b8b95 authored by Davis King's avatar Davis King

Added the max discount parameter to the one class algorithm as well

as cleaned up the code a bit.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402250
parent 13ed137c
...@@ -39,13 +39,21 @@ namespace dlib ...@@ -39,13 +39,21 @@ namespace dlib
scalar_type tolerance_ = 0.001 scalar_type tolerance_ = 0.001
) : ) :
kernel(kernel_), kernel(kernel_),
tolerance(tolerance_) tolerance(tolerance_),
max_dis(1e6)
{ {
clear(); clear_dictionary();
} }
void set_tolerance (scalar_type tolerance_) void set_tolerance (scalar_type tolerance_)
{ {
// make sure requires clause is not broken
DLIB_ASSERT(tolerance_ >= 0,
"\tvoid one_class::set_tolerance"
<< "\n\tinvalid tolerance value"
<< "\n\ttolerance: " << tolerance_
<< "\n\tthis: " << this
);
tolerance = tolerance_; tolerance = tolerance_;
} }
...@@ -54,7 +62,29 @@ namespace dlib ...@@ -54,7 +62,29 @@ namespace dlib
return tolerance; return tolerance;
} }
void clear () void set_max_discount (
scalar_type value
)
{
// make sure requires clause is not broken
DLIB_ASSERT(value >= 0,
"\tvoid one_class::set_max_discount"
<< "\n\tinvalid discount value"
<< "\n\tvalue: " << value
<< "\n\tthis: " << this
);
max_dis = value;
if (samples_seen > value)
samples_seen = value;
}
scalar_type get_max_discount(
) const
{
return max_dis;
}
void clear_dictionary ()
{ {
dictionary.clear(); dictionary.clear();
alpha.clear(); alpha.clear();
...@@ -177,10 +207,13 @@ namespace dlib ...@@ -177,10 +207,13 @@ namespace dlib
} }
} }
++samples_seen;
// recompute the bias term // recompute the bias term
bias = sum(pointwise_multiply(K, vector_to_matrix(alpha)*trans(vector_to_matrix(alpha)))); bias = sum(pointwise_multiply(K, vector_to_matrix(alpha)*trans(vector_to_matrix(alpha))));
++samples_seen; if (samples_seen > max_dis)
samples_seen = max_dis;
} }
void swap ( void swap (
...@@ -195,6 +228,7 @@ namespace dlib ...@@ -195,6 +228,7 @@ namespace dlib
exchange(tolerance, item.tolerance); exchange(tolerance, item.tolerance);
exchange(samples_seen, item.samples_seen); exchange(samples_seen, item.samples_seen);
exchange(bias, item.bias); exchange(bias, item.bias);
exchange(max_dis, item.max_dis);
a.swap(item.a); a.swap(item.a);
k.swap(item.k); k.swap(item.k);
} }
...@@ -212,6 +246,7 @@ namespace dlib ...@@ -212,6 +246,7 @@ namespace dlib
serialize(item.tolerance, out); serialize(item.tolerance, out);
serialize(item.samples_seen, out); serialize(item.samples_seen, out);
serialize(item.bias, out); serialize(item.bias, out);
serialize(item.max_dis, out);
} }
friend void deserialize(one_class& item, std::istream& in) friend void deserialize(one_class& item, std::istream& in)
...@@ -224,6 +259,7 @@ namespace dlib ...@@ -224,6 +259,7 @@ namespace dlib
deserialize(item.tolerance, in); deserialize(item.tolerance, in);
deserialize(item.samples_seen, in); deserialize(item.samples_seen, in);
deserialize(item.bias, in); deserialize(item.bias, in);
deserialize(item.max_dis, in);
} }
private: private:
...@@ -245,6 +281,7 @@ namespace dlib ...@@ -245,6 +281,7 @@ namespace dlib
scalar_type tolerance; scalar_type tolerance;
scalar_type samples_seen; scalar_type samples_seen;
scalar_type bias; scalar_type bias;
scalar_type max_dis;
// temp variables here just so we don't have to reconstruct them over and over. Thus, // temp variables here just so we don't have to reconstruct them over and over. Thus,
......
...@@ -23,6 +23,7 @@ namespace dlib ...@@ -23,6 +23,7 @@ namespace dlib
INITIAL VALUE INITIAL VALUE
- dictionary_size() == 0 - dictionary_size() == 0
- max_discount() == 1e6
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This is an implementation of an online algorithm for recursively estimating the This is an implementation of an online algorithm for recursively estimating the
...@@ -57,6 +58,8 @@ namespace dlib ...@@ -57,6 +58,8 @@ namespace dlib
scalar_type tolerance_ scalar_type tolerance_
); );
/*! /*!
requires
- tolerance_ >= 0
ensures ensures
- #get_tolerance() == tolerance_ - #get_tolerance() == tolerance_
!*/ !*/
...@@ -74,14 +77,42 @@ namespace dlib ...@@ -74,14 +77,42 @@ namespace dlib
less accurate estimate but also in less support vectors. less accurate estimate but also in less support vectors.
!*/ !*/
void clear ( void set_max_discount (
scalar_type value
); );
/*! /*!
requires
- value > 0
ensures ensures
- clears out all learned data and puts this object back to its - #get_max_discount() == value
initial state. (e.g. #dictionary_size() == 0) !*/
- #get_tolerance() == get_tolerance()
(i.e. doesn't change the value of the tolerance) scalar_type get_max_discount(
) const;
/*!
ensures
- If you have shown this object N samples so far then it has found
the centroid of those N samples. That is, it has found the average
of all of them in some high dimensional feature space.
- if (N <= get_max_discount()) then
- The next sample you show this object will be added to the centroid
with a weight of 1/(N+1).
- else
- The next sample you show this object will be added to the centroid
with a weight of 1/(get_max_discount()+1).
- If you think your samples are from a stationary source then you
should set the max discount to some really big number. However,
if you think the source isn't stationary then use a smaller number.
This will cause the centroid in this object to be closer to the
centroid of the more recent points.
!*/
void clear_dictionary (
);
/*!
ensures
- clears out all learned data (e.g. #dictionary_size() == 0)
!*/ !*/
scalar_type operator() ( scalar_type operator() (
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment