Commit 6eee12f2 authored by Davis King's avatar Davis King

Improved the assert messages related to badly formed graph labeling

problems.
parent 0246088a
......@@ -23,13 +23,15 @@ namespace dlib
const std::vector<std::vector<bool> >& labels
)
{
DLIB_ASSERT(is_graph_labeling_problem(samples, labels) ,
#ifdef ENABLE_ASSERTS
std::string reason_for_failure;
DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) ,
"\t matrix test_graph_labeling_function()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t samples.size(): " << samples.size()
<< "\n\t is_graph_labeling_problem(samples,labels): " << is_graph_labeling_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
<< "\n\t reason_for_failure: " << reason_for_failure
);
#endif
std::vector<bool> temp;
unsigned long num_pos_correct = 0;
......@@ -83,15 +85,20 @@ namespace dlib
const long folds
)
{
DLIB_ASSERT(is_graph_labeling_problem(samples, labels) &&
1 < folds && folds <= static_cast<long>(samples.size()),
#ifdef ENABLE_ASSERTS
std::string reason_for_failure;
DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure),
"\t matrix cross_validate_graph_labeling_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t samples.size(): " << samples.size()
<< "\n\t reason_for_failure: " << reason_for_failure
);
DLIB_ASSERT( 1 < folds && folds <= static_cast<long>(samples.size()),
"\t matrix cross_validate_graph_labeling_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t folds: " << folds
<< "\n\t is_graph_labeling_problem(samples,labels): " << is_graph_labeling_problem(samples,labels)
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
#endif
typedef std::vector<bool> label_type;
......
......@@ -171,17 +171,25 @@ namespace dlib
const std::vector<std::vector<double> >& losses
) const
{
DLIB_ASSERT(is_graph_labeling_problem(samples, labels) == true &&
(losses.size() == 0 || sizes_match(labels, losses) == true) &&
all_values_are_nonnegative(losses) == true,
#ifdef ENABLE_ASSERTS
std::string reason_for_failure;
DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true ,
"\t void structural_graph_labeling_trainer::train()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t reason_for_failure: " << reason_for_failure
<< "\n\t samples.size(): " << samples.size()
<< "\n\t labels.size(): " << labels.size()
<< "\n\t this: " << this );
DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
all_values_are_nonnegative(losses) == true,
"\t void structural_graph_labeling_trainer::train()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t labels.size(): " << labels.size()
<< "\n\t losses.size(): " << losses.size()
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
<< "\n\t this: " << this );
#endif
structural_svm_graph_labeling_problem<graph_type> prob(samples, labels, losses, num_threads);
......
......@@ -13,6 +13,7 @@
#include "structural_svm_problem_threaded.h"
#include "../graph.h"
#include "sparse_vector.h"
#include <sstream>
// ----------------------------------------------------------------------------------------
......@@ -26,7 +27,8 @@ namespace dlib
>
bool is_graph_labeling_problem (
const dlib::array<graph_type>& samples,
const std::vector<std::vector<bool> >& labels
const std::vector<std::vector<bool> >& labels,
std::string& reason_for_failure
)
{
typedef typename graph_type::type node_vector_type;
......@@ -36,9 +38,14 @@ namespace dlib
(!is_matrix<node_vector_type>::value && !is_matrix<edge_vector_type>::value));
std::ostringstream sout;
reason_for_failure.clear();
if (!is_learning_problem(samples, labels))
{
reason_for_failure = "is_learning_problem(samples, labels) returned false.";
return false;
}
const bool ismat = is_matrix<typename graph_type::type>::value;
......@@ -49,33 +56,61 @@ namespace dlib
for (unsigned long i = 0; i < samples.size(); ++i)
{
if (samples[i].number_of_nodes() != labels[i].size())
{
sout << "samples["<<i<<"].number_of_nodes() doesn't match labels["<<i<<"].size().";
reason_for_failure = sout.str();
return false;
}
if (graph_contains_length_one_cycle(samples[i]))
{
sout << "graph_contains_length_one_cycle(samples["<<i<<"]) returned true.";
reason_for_failure = sout.str();
return false;
}
for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j)
{
if (ismat && samples[i].node(j).data.size() == 0)
{
sout << "A graph contains an empty vector at node: samples["<<i<<"].node("<<j<<").data.";
reason_for_failure = sout.str();
return false;
}
if (ismat && node_dims == -1)
node_dims = samples[i].node(j).data.size();
// all nodes must have vectors of the same size.
if (ismat && (long)samples[i].node(j).data.size() != node_dims)
{
sout << "Not all node vectors in samples["<<i<<"] are the same dimension.";
reason_for_failure = sout.str();
return false;
}
for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n)
{
if (ismat && samples[i].node(j).edge(n).size() == 0)
{
sout << "A graph contains an empty vector at edge: samples["<<i<<"].node("<<j<<").edge("<<n<<").";
reason_for_failure = sout.str();
return false;
}
if (min(samples[i].node(j).edge(n)) < 0)
{
sout << "A graph contains negative values on an edge vector at: samples["<<i<<"].node("<<j<<").edge("<<n<<").";
reason_for_failure = sout.str();
return false;
}
if (ismat && edge_dims == -1)
edge_dims = samples[i].node(j).edge(n).size();
// all edges must have vectors of the same size.
if (ismat && (long)samples[i].node(j).edge(n).size() != edge_dims)
{
sout << "Not all edge vectors in samples["<<i<<"] are the same dimension.";
reason_for_failure = sout.str();
return false;
}
}
}
}
......@@ -83,6 +118,18 @@ namespace dlib
return true;
}
template <
typename graph_type
>
bool is_graph_labeling_problem (
const dlib::array<graph_type>& samples,
const std::vector<std::vector<bool> >& labels
)
{
std::string reason_for_failure;
return is_graph_labeling_problem(samples, labels, reason_for_failure);
}
// ----------------------------------------------------------------------------------------
template <
......@@ -184,17 +231,25 @@ namespace dlib
losses(losses_)
{
// make sure requires clause is not broken
DLIB_ASSERT(is_graph_labeling_problem(samples, labels) == true &&
(losses.size() == 0 || sizes_match(labels, losses) == true) &&
all_values_are_nonnegative(losses) == true,
#ifdef ENABLE_ASSERTS
std::string reason_for_failure;
DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true ,
"\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t reason_for_failure: " << reason_for_failure
<< "\n\t samples.size(): " << samples.size()
<< "\n\t labels.size(): " << labels.size()
<< "\n\t this: " << this );
DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
all_values_are_nonnegative(losses) == true,
"\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t labels.size(): " << labels.size()
<< "\n\t losses.size(): " << losses.size()
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
<< "\n\t this: " << this );
#endif
loss_pos = 1.0;
loss_neg = 1.0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment