Commit 43946fcc authored by Davis King's avatar Davis King

Added a function that lets you test and train at the same time

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402251
parent 230b8b95
...@@ -105,10 +105,74 @@ namespace dlib ...@@ -105,10 +105,74 @@ namespace dlib
return std::sqrt(kernel(x,x) + bias - 2*temp); return std::sqrt(kernel(x,x) + bias - 2*temp);
} }
scalar_type test_and_train (
const sample_type& x
)
{
return train_and_maybe_test(x,true);
}
void train ( void train (
const sample_type& x const sample_type& x
) )
{ {
train_and_maybe_test(x,false);
}
void swap (
one_class& item
)
{
exchange(kernel, item.kernel);
dictionary.swap(item.dictionary);
alpha.swap(item.alpha);
K_inv.swap(item.K_inv);
K.swap(item.K);
exchange(tolerance, item.tolerance);
exchange(samples_seen, item.samples_seen);
exchange(bias, item.bias);
exchange(max_dis, item.max_dis);
a.swap(item.a);
k.swap(item.k);
}
unsigned long dictionary_size (
) const { return dictionary.size(); }
friend void serialize(const one_class& item, std::ostream& out)
{
serialize(item.kernel, out);
serialize(item.dictionary, out);
serialize(item.alpha, out);
serialize(item.K_inv, out);
serialize(item.K, out);
serialize(item.tolerance, out);
serialize(item.samples_seen, out);
serialize(item.bias, out);
serialize(item.max_dis, out);
}
friend void deserialize(one_class& item, std::istream& in)
{
deserialize(item.kernel, in);
deserialize(item.dictionary, in);
deserialize(item.alpha, in);
deserialize(item.K_inv, in);
deserialize(item.K, in);
deserialize(item.tolerance, in);
deserialize(item.samples_seen, in);
deserialize(item.bias, in);
deserialize(item.max_dis, in);
}
private:
scalar_type train_and_maybe_test (
const sample_type& x,
bool do_test
)
{
scalar_type test_result = 0;
const scalar_type kx = kernel(x,x); const scalar_type kx = kernel(x,x);
if (alpha.size() == 0) if (alpha.size() == 0)
{ {
...@@ -129,6 +193,11 @@ namespace dlib ...@@ -129,6 +193,11 @@ namespace dlib
for (long r = 0; r < k.nr(); ++r) for (long r = 0; r < k.nr(); ++r)
k(r) = kernel(x,dictionary[r]); k(r) = kernel(x,dictionary[r]);
if (do_test)
{
test_result = std::sqrt(kx + bias - 2*trans(vector_to_matrix(alpha))*k);
}
// compute the error we would have if we approximated the new x sample // compute the error we would have if we approximated the new x sample
// with the dictionary. That is, do the ALD test from the KRLS paper. // with the dictionary. That is, do the ALD test from the KRLS paper.
a = K_inv*k; a = K_inv*k;
...@@ -214,56 +283,10 @@ namespace dlib ...@@ -214,56 +283,10 @@ namespace dlib
if (samples_seen > max_dis) if (samples_seen > max_dis)
samples_seen = max_dis; samples_seen = max_dis;
}
void swap ( return test_result;
one_class& item
)
{
exchange(kernel, item.kernel);
dictionary.swap(item.dictionary);
alpha.swap(item.alpha);
K_inv.swap(item.K_inv);
K.swap(item.K);
exchange(tolerance, item.tolerance);
exchange(samples_seen, item.samples_seen);
exchange(bias, item.bias);
exchange(max_dis, item.max_dis);
a.swap(item.a);
k.swap(item.k);
} }
unsigned long dictionary_size (
) const { return dictionary.size(); }
friend void serialize(const one_class& item, std::ostream& out)
{
serialize(item.kernel, out);
serialize(item.dictionary, out);
serialize(item.alpha, out);
serialize(item.K_inv, out);
serialize(item.K, out);
serialize(item.tolerance, out);
serialize(item.samples_seen, out);
serialize(item.bias, out);
serialize(item.max_dis, out);
}
friend void deserialize(one_class& item, std::istream& in)
{
deserialize(item.kernel, in);
deserialize(item.dictionary, in);
deserialize(item.alpha, in);
deserialize(item.K_inv, in);
deserialize(item.K, in);
deserialize(item.tolerance, in);
deserialize(item.samples_seen, in);
deserialize(item.bias, in);
deserialize(item.max_dis, in);
}
private:
typedef std_allocator<sample_type, mem_manager_type> alloc_sample_type; typedef std_allocator<sample_type, mem_manager_type> alloc_sample_type;
typedef std_allocator<scalar_type, mem_manager_type> alloc_scalar_type; typedef std_allocator<scalar_type, mem_manager_type> alloc_scalar_type;
......
...@@ -125,6 +125,18 @@ namespace dlib ...@@ -125,6 +125,18 @@ namespace dlib
to this object so far. to this object so far.
!*/ !*/
scalar_type test_and_train (
const sample_type& x
);
/*!
ensures
- calls train(x)
- returns (*this)(x)
- The reason this function exists is because train() and operator()
both compute some of the same things. So this function is more efficient
than calling both individually.
!*/
void train ( void train (
const sample_type& x const sample_type& x
); );
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment