Commit 6eee12f2 authored by Davis King's avatar Davis King

Improved the assert messages related to badly formed graph labeling

problems.
parent 0246088a
...@@ -23,13 +23,15 @@ namespace dlib ...@@ -23,13 +23,15 @@ namespace dlib
const std::vector<std::vector<bool> >& labels const std::vector<std::vector<bool> >& labels
) )
{ {
DLIB_ASSERT(is_graph_labeling_problem(samples, labels) , #ifdef ENABLE_ASSERTS
std::string reason_for_failure;
DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) ,
"\t matrix test_graph_labeling_function()" "\t matrix test_graph_labeling_function()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t samples.size(): " << samples.size() << "\n\t samples.size(): " << samples.size()
<< "\n\t is_graph_labeling_problem(samples,labels): " << is_graph_labeling_problem(samples,labels) << "\n\t reason_for_failure: " << reason_for_failure
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
); );
#endif
std::vector<bool> temp; std::vector<bool> temp;
unsigned long num_pos_correct = 0; unsigned long num_pos_correct = 0;
...@@ -83,15 +85,20 @@ namespace dlib ...@@ -83,15 +85,20 @@ namespace dlib
const long folds const long folds
) )
{ {
DLIB_ASSERT(is_graph_labeling_problem(samples, labels) && #ifdef ENABLE_ASSERTS
1 < folds && folds <= static_cast<long>(samples.size()), std::string reason_for_failure;
DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure),
"\t matrix cross_validate_graph_labeling_trainer()" "\t matrix cross_validate_graph_labeling_trainer()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t samples.size(): " << samples.size() << "\n\t samples.size(): " << samples.size()
<< "\n\t reason_for_failure: " << reason_for_failure
);
DLIB_ASSERT( 1 < folds && folds <= static_cast<long>(samples.size()),
"\t matrix cross_validate_graph_labeling_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t folds: " << folds << "\n\t folds: " << folds
<< "\n\t is_graph_labeling_problem(samples,labels): " << is_graph_labeling_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
); );
#endif
typedef std::vector<bool> label_type; typedef std::vector<bool> label_type;
......
...@@ -171,17 +171,25 @@ namespace dlib ...@@ -171,17 +171,25 @@ namespace dlib
const std::vector<std::vector<double> >& losses const std::vector<std::vector<double> >& losses
) const ) const
{ {
DLIB_ASSERT(is_graph_labeling_problem(samples, labels) == true && #ifdef ENABLE_ASSERTS
(losses.size() == 0 || sizes_match(labels, losses) == true) && std::string reason_for_failure;
all_values_are_nonnegative(losses) == true, DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true ,
"\t void structural_graph_labeling_trainer::train()" "\t void structural_graph_labeling_trainer::train()"
<< "\n\t Invalid inputs were given to this function." << "\n\t Invalid inputs were given to this function."
<< "\n\t reason_for_failure: " << reason_for_failure
<< "\n\t samples.size(): " << samples.size() << "\n\t samples.size(): " << samples.size()
<< "\n\t labels.size(): " << labels.size() << "\n\t labels.size(): " << labels.size()
<< "\n\t this: " << this );
DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
all_values_are_nonnegative(losses) == true,
"\t void structural_graph_labeling_trainer::train()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t labels.size(): " << labels.size()
<< "\n\t losses.size(): " << losses.size() << "\n\t losses.size(): " << losses.size()
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses) << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses) << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
<< "\n\t this: " << this ); << "\n\t this: " << this );
#endif
structural_svm_graph_labeling_problem<graph_type> prob(samples, labels, losses, num_threads); structural_svm_graph_labeling_problem<graph_type> prob(samples, labels, losses, num_threads);
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "structural_svm_problem_threaded.h" #include "structural_svm_problem_threaded.h"
#include "../graph.h" #include "../graph.h"
#include "sparse_vector.h" #include "sparse_vector.h"
#include <sstream>
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -26,7 +27,8 @@ namespace dlib ...@@ -26,7 +27,8 @@ namespace dlib
> >
bool is_graph_labeling_problem ( bool is_graph_labeling_problem (
const dlib::array<graph_type>& samples, const dlib::array<graph_type>& samples,
const std::vector<std::vector<bool> >& labels const std::vector<std::vector<bool> >& labels,
std::string& reason_for_failure
) )
{ {
typedef typename graph_type::type node_vector_type; typedef typename graph_type::type node_vector_type;
...@@ -36,9 +38,14 @@ namespace dlib ...@@ -36,9 +38,14 @@ namespace dlib
(!is_matrix<node_vector_type>::value && !is_matrix<edge_vector_type>::value)); (!is_matrix<node_vector_type>::value && !is_matrix<edge_vector_type>::value));
std::ostringstream sout;
reason_for_failure.clear();
if (!is_learning_problem(samples, labels)) if (!is_learning_problem(samples, labels))
{
reason_for_failure = "is_learning_problem(samples, labels) returned false.";
return false; return false;
}
const bool ismat = is_matrix<typename graph_type::type>::value; const bool ismat = is_matrix<typename graph_type::type>::value;
...@@ -49,33 +56,61 @@ namespace dlib ...@@ -49,33 +56,61 @@ namespace dlib
for (unsigned long i = 0; i < samples.size(); ++i) for (unsigned long i = 0; i < samples.size(); ++i)
{ {
if (samples[i].number_of_nodes() != labels[i].size()) if (samples[i].number_of_nodes() != labels[i].size())
{
sout << "samples["<<i<<"].number_of_nodes() doesn't match labels["<<i<<"].size().";
reason_for_failure = sout.str();
return false; return false;
}
if (graph_contains_length_one_cycle(samples[i])) if (graph_contains_length_one_cycle(samples[i]))
{
sout << "graph_contains_length_one_cycle(samples["<<i<<"]) returned true.";
reason_for_failure = sout.str();
return false; return false;
}
for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j) for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j)
{ {
if (ismat && samples[i].node(j).data.size() == 0) if (ismat && samples[i].node(j).data.size() == 0)
{
sout << "A graph contains an empty vector at node: samples["<<i<<"].node("<<j<<").data.";
reason_for_failure = sout.str();
return false; return false;
}
if (ismat && node_dims == -1) if (ismat && node_dims == -1)
node_dims = samples[i].node(j).data.size(); node_dims = samples[i].node(j).data.size();
// all nodes must have vectors of the same size. // all nodes must have vectors of the same size.
if (ismat && (long)samples[i].node(j).data.size() != node_dims) if (ismat && (long)samples[i].node(j).data.size() != node_dims)
{
sout << "Not all node vectors in samples["<<i<<"] are the same dimension.";
reason_for_failure = sout.str();
return false; return false;
}
for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n) for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n)
{ {
if (ismat && samples[i].node(j).edge(n).size() == 0) if (ismat && samples[i].node(j).edge(n).size() == 0)
{
sout << "A graph contains an empty vector at edge: samples["<<i<<"].node("<<j<<").edge("<<n<<").";
reason_for_failure = sout.str();
return false; return false;
}
if (min(samples[i].node(j).edge(n)) < 0) if (min(samples[i].node(j).edge(n)) < 0)
{
sout << "A graph contains negative values on an edge vector at: samples["<<i<<"].node("<<j<<").edge("<<n<<").";
reason_for_failure = sout.str();
return false; return false;
}
if (ismat && edge_dims == -1) if (ismat && edge_dims == -1)
edge_dims = samples[i].node(j).edge(n).size(); edge_dims = samples[i].node(j).edge(n).size();
// all edges must have vectors of the same size. // all edges must have vectors of the same size.
if (ismat && (long)samples[i].node(j).edge(n).size() != edge_dims) if (ismat && (long)samples[i].node(j).edge(n).size() != edge_dims)
{
sout << "Not all edge vectors in samples["<<i<<"] are the same dimension.";
reason_for_failure = sout.str();
return false; return false;
}
} }
} }
} }
...@@ -83,6 +118,18 @@ namespace dlib ...@@ -83,6 +118,18 @@ namespace dlib
return true; return true;
} }
template <
typename graph_type
>
bool is_graph_labeling_problem (
const dlib::array<graph_type>& samples,
const std::vector<std::vector<bool> >& labels
)
{
std::string reason_for_failure;
return is_graph_labeling_problem(samples, labels, reason_for_failure);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -184,17 +231,25 @@ namespace dlib ...@@ -184,17 +231,25 @@ namespace dlib
losses(losses_) losses(losses_)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(is_graph_labeling_problem(samples, labels) == true && #ifdef ENABLE_ASSERTS
(losses.size() == 0 || sizes_match(labels, losses) == true) && std::string reason_for_failure;
all_values_are_nonnegative(losses) == true, DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true ,
"\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()" "\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()"
<< "\n\t Invalid inputs were given to this function." << "\n\t Invalid inputs were given to this function."
<< "\n\t reason_for_failure: " << reason_for_failure
<< "\n\t samples.size(): " << samples.size() << "\n\t samples.size(): " << samples.size()
<< "\n\t labels.size(): " << labels.size() << "\n\t labels.size(): " << labels.size()
<< "\n\t this: " << this );
DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
all_values_are_nonnegative(losses) == true,
"\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t labels.size(): " << labels.size()
<< "\n\t losses.size(): " << losses.size() << "\n\t losses.size(): " << losses.size()
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses) << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses) << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
<< "\n\t this: " << this ); << "\n\t this: " << this );
#endif
loss_pos = 1.0; loss_pos = 1.0;
loss_neg = 1.0; loss_neg = 1.0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment