Commit 35fb8d4d authored by Davis King

Added the is_learning_problem() predicate and used it to make a few
requires clauses more straightforward.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404021
parent cdd8218b
@@ -239,7 +239,7 @@ namespace dlib
         ) const
         {
             // make sure requires clause is not broken
-            DLIB_ASSERT(is_vector(x) && is_vector(y) && x.size() == y.size() && x.size() > 0,
+            DLIB_ASSERT(is_learning_problem(x,y),
                 "\t decision_function krr_trainer::train(x,y)"
                 << "\n\t invalid inputs were given to this function"
                 << "\n\t is_vector(x): " << is_vector(x)
...
@@ -233,9 +233,7 @@ namespace dlib
               Also, x should contain sample_type objects.
             - y == a matrix or something convertible to a matrix via vector_to_matrix().
               Also, y should contain scalar_type objects.
-            - is_vector(x) == true
-            - is_vector(y) == true
-            - x.size() == y.size() > 0
+            - is_learning_problem(x,y) == true
             - if (get_lambda() == 0 && will_use_regression_loss_for_loo_cv() == false) then
                 - is_binary_classification_problem(x,y) == true
                   (i.e. if you want this algorithm to estimate a lambda appropriate for
...
@@ -21,7 +21,7 @@ namespace dlib
     class rbf_network_trainer
     {
         /*!
-            This is an implemenation of an RBF network trainer that follows
+            This is an implementation of an RBF network trainer that follows
             the directions right off Wikipedia basically.  So nothing
             particularly fancy.  Although the way the centers are selected
             is somewhat unique.
@@ -103,7 +103,7 @@ namespace dlib
             typedef typename decision_function<kernel_type>::sample_vector_type sample_vector_type;
             // make sure requires clause is not broken
-            DLIB_ASSERT(x.nr() > 1 && x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1,
+            DLIB_ASSERT(is_learning_problem(x,y),
                 "\tdecision_function rbf_network_trainer::train(x,y)"
                 << "\n\t invalid inputs were given to this function"
                 << "\n\t x.nr(): " << x.nr()
...
@@ -93,9 +93,7 @@ namespace dlib
               Also, x should contain sample_type objects.
             - y == a matrix or something convertible to a matrix via vector_to_matrix().
               Also, y should contain scalar_type objects.
-            - x.nr() > 1
-            - x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1
-              (i.e. x and y are both column vectors of the same length)
+            - is_learning_problem(x,y) == true
         ensures
             - trains a RBF network given the training samples in x and
               labels in y and returns the resulting decision_function
...
@@ -21,6 +21,35 @@
 namespace dlib
 {
+// ----------------------------------------------------------------------------------------
+
+    template <
+        typename T,
+        typename U
+        >
+    inline bool is_learning_problem_impl (
+        const T& x,
+        const U& x_labels
+    )
+    {
+        return is_col_vector(x) &&
+               is_col_vector(x_labels) &&
+               x.size() == x_labels.size() &&
+               x.size() > 0;
+    }
+
+    template <
+        typename T,
+        typename U
+        >
+    inline bool is_learning_problem (
+        const T& x,
+        const U& x_labels
+    )
+    {
+        return is_learning_problem_impl(vector_to_matrix(x), vector_to_matrix(x_labels));
+    }
+
 // ----------------------------------------------------------------------------------------
 
     template <
@@ -34,9 +63,12 @@ namespace dlib
     {
         bool seen_neg_class = false;
         bool seen_pos_class = false;
-        if (x.nc() != 1 || x_labels.nc() != 1) return false;
-        if (x.nr() != x_labels.nr()) return false;
-        if (x.nr() <= 1) return false;
+
+        if (is_learning_problem_impl(x,x_labels) == false)
+            return false;
+
+        if (x.size() <= 1) return false;
+
         for (long r = 0; r < x_labels.nr(); ++r)
         {
             if (x_labels(r) != -1 && x_labels(r) != 1)
...
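For context, here is a minimal sketch of how the new predicate can be called from user code. It is illustrative only and not part of this commit; it assumes the samples and labels live in std::vectors (which vector_to_matrix() accepts) and that dlib/svm.h is on the include path:

    #include <dlib/svm.h>
    #include <vector>
    #include <iostream>

    int main()
    {
        typedef dlib::matrix<double,2,1> sample_type;

        std::vector<sample_type> samples;
        std::vector<double> labels;

        sample_type s;
        s = 1, 2;  samples.push_back(s);  labels.push_back(+1);
        s = 3, 4;  samples.push_back(s);  labels.push_back(-1);

        // Non-empty column vectors of equal length, so this is a valid
        // learning problem.
        std::cout << std::boolalpha
                  << dlib::is_learning_problem(samples, labels) << "\n";              // true

        // Both the +1 and -1 classes are present, so it is also a valid
        // binary classification problem.
        std::cout << dlib::is_binary_classification_problem(samples, labels) << "\n"; // true

        return 0;
    }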
@@ -24,7 +24,7 @@ namespace dlib
         typename T,
         typename U
         >
-    bool is_binary_classification_problem (
+    bool is_learning_problem (
         const T& x,
         const U& x_labels
     );
@@ -37,6 +37,24 @@ namespace dlib
                 - is_col_vector(x) == true
                 - is_col_vector(x_labels) == true
                 - x.size() == x_labels.size()
+                - x.size() > 0
+    !*/
+
+    template <
+        typename T,
+        typename U
+        >
+    bool is_binary_classification_problem (
+        const T& x,
+        const U& x_labels
+    );
+    /*!
+        requires
+            - T == a matrix or something convertible to a matrix via vector_to_matrix()
+            - U == a matrix or something convertible to a matrix via vector_to_matrix()
+        ensures
+            - returns true if all of the following are true and false otherwise:
+                - is_learning_problem(x, x_labels) == true
                 - x.size() > 1
                 - there exists at least one sample from both the +1 and -1 classes.
                   (i.e. all samples can't have the same label)
...
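To make the distinction drawn by the rewritten spec concrete, here is a small hypothetical snippet (again, not part of the commit): real-valued regression targets satisfy is_learning_problem() but not is_binary_classification_problem(), since the latter additionally requires +1/-1 labels with both classes present.

    #include <dlib/svm.h>
    #include <vector>
    #include <iostream>

    int main()
    {
        typedef dlib::matrix<double,1,1> sample_type;

        std::vector<sample_type> samples(3);   // sample values don't matter to these predicates
        std::vector<double> targets;
        targets.push_back(0.5);
        targets.push_back(2.7);
        targets.push_back(-1.3);

        std::cout << std::boolalpha
                  << dlib::is_learning_problem(samples, targets) << "\n"               // true
                  << dlib::is_binary_classification_problem(samples, targets) << "\n"; // false

        return 0;
    }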