Commit d575fdce authored by Davis King's avatar Davis King

Switched all the graph labeling stuff to use bool as a node label

rather than the node_label type from the min_cut object.  This
should make the interface much less confusing.
parent ba33b82e
......@@ -24,7 +24,7 @@ namespace dlib
public:
typedef std::vector<node_label> label_type;
typedef std::vector<bool> label_type;
typedef label_type result_type;
graph_labeler()
......@@ -56,7 +56,7 @@ namespace dlib
template <typename graph_type>
void operator() (
const graph_type& sample,
result_type& labels
std::vector<bool>& labels
) const
{
// make sure requires clause is not broken
......@@ -113,7 +113,6 @@ namespace dlib
}
#endif
labels.clear();
graph<double,double>::kernel_1a g;
copy_graph_structure(sample, g);
......@@ -133,15 +132,24 @@ namespace dlib
}
find_max_factor_graph_potts(g, labels);
labels.clear();
std::vector<node_label> temp;
find_max_factor_graph_potts(g, temp);
for (unsigned long i = 0; i < temp.size(); ++i)
{
if (temp[i] != 0)
labels.push_back(true);
else
labels.push_back(false);
}
}
template <typename graph_type>
result_type operator() (
std::vector<bool> operator() (
const graph_type& sample
) const
{
result_type temp;
std::vector<bool> temp;
(*this)(sample, temp);
return temp;
}
......
......@@ -44,7 +44,7 @@ namespace dlib
public:
typedef std::vector<node_label> label_type;
typedef std::vector<bool> label_type;
typedef label_type result_type;
graph_labeler(
......@@ -73,9 +73,9 @@ namespace dlib
/*!
ensures
- Recall that the score function for an edge is a linear function of
the vector stored at that edge in the graph. This means there is some
vector E which we dot product with the vector in the graph to compute
the score. Therefore, this function returns that E vector which defines
the vector stored at that edge. This means there is some vector, E,
which we dot product with the vector in the graph to compute the
score. Therefore, this function returns that E vector which defines
the edge score function.
!*/
......@@ -84,16 +84,16 @@ namespace dlib
/*!
ensures
- Recall that the score function for a node is a linear function of
the vector stored in that node in the graph. This means there is some
vector W which we dot product with the vector in the graph to compute
the score. Therefore, this function returns that W vector which defines
the node score function.
the vector stored in that node. This means there is some vector, W,
which we dot product with the vector in the graph to compute the score.
Therefore, this function returns that W vector which defines the node
score function.
!*/
template <typename graph_type>
void operator() (
const graph_type& sample,
result_type& labels
std::vector<bool>& labels
) const;
/*!
requires
......@@ -111,10 +111,7 @@ namespace dlib
in #labels.
- #labels.size() == sample.number_of_nodes()
- for all valid i:
- if (sample.node(i) is predicted to have a label of true) then
- #labels[i] != 0
- else
- #labels[i] == 0
- #labels[i] == the label of the node sample.node(i).
- The labels are computed by creating a graph, G, with scalar values on each node
and edge. The scalar values are calculated according to the following:
- for all valid i:
......@@ -125,7 +122,7 @@ namespace dlib
!*/
template <typename graph_type>
result_type operator() (
std::vector<bool> operator() (
const graph_type& sample
) const;
/*!
......
......@@ -20,7 +20,7 @@ namespace dlib
matrix<double,1,2> test_graph_labeling_function (
const graph_labeler& labeler,
const dlib::array<graph_type>& samples,
const std::vector<std::vector<node_label> >& labels
const std::vector<std::vector<bool> >& labels
)
{
DLIB_ASSERT(is_graph_labeling_problem(samples, labels) ,
......@@ -31,7 +31,7 @@ namespace dlib
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
std::vector<node_label> temp;
std::vector<bool> temp;
unsigned long num_pos_correct = 0;
unsigned long num_pos = 0;
unsigned long num_neg_correct = 0;
......@@ -79,7 +79,7 @@ namespace dlib
matrix<double,1,2> cross_validate_graph_labeling_trainer (
const trainer_type& trainer,
const dlib::array<graph_type>& samples,
const std::vector<std::vector<node_label> >& labels,
const std::vector<std::vector<bool> >& labels,
const long folds
)
{
......@@ -93,7 +93,7 @@ namespace dlib
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
);
typedef std::vector<node_label> label_type;
typedef std::vector<bool> label_type;
const long num_in_test = samples.size()/folds;
const long num_in_train = samples.size() - num_in_test;
......@@ -105,7 +105,7 @@ namespace dlib
long next_test_idx = 0;
std::vector<node_label> temp;
std::vector<bool> temp;
unsigned long num_pos_correct = 0;
unsigned long num_pos = 0;
unsigned long num_neg_correct = 0;
......
......@@ -19,7 +19,7 @@ namespace dlib
matrix<double,1,2> test_graph_labeling_function (
const graph_labeler& labeler,
const dlib::array<graph_type>& samples,
const std::vector<std::vector<node_label> >& labels
const std::vector<std::vector<bool> >& labels
);
/*!
requires
......@@ -48,7 +48,7 @@ namespace dlib
matrix<double,1,2> cross_validate_graph_labeling_trainer (
const trainer_type& trainer,
const dlib::array<graph_type>& samples,
const std::vector<std::vector<node_label> >& labels,
const std::vector<std::vector<bool> >& labels,
const long folds
);
/*!
......
......@@ -21,7 +21,7 @@ namespace dlib
class structural_graph_labeling_trainer
{
public:
typedef std::vector<node_label> label_type;
typedef std::vector<bool> label_type;
typedef graph_labeler<vector_type> trained_function_type;
structural_graph_labeling_trainer (
......
......@@ -36,7 +36,7 @@ namespace dlib
!*/
public:
typedef std::vector<node_label> label_type;
typedef std::vector<bool> label_type;
typedef graph_labeler<vector_type> trained_function_type;
structural_graph_labeling_trainer (
......
......@@ -26,7 +26,7 @@ namespace dlib
>
bool is_graph_labeling_problem (
const dlib::array<graph_type>& samples,
const std::vector<std::vector<node_label> >& labels
const std::vector<std::vector<bool> >& labels
)
{
typedef typename graph_type::type node_vector_type;
......@@ -130,7 +130,7 @@ namespace dlib
typedef graph_type sample_type;
typedef std::vector<node_label> label_type;
typedef std::vector<bool> label_type;
structural_svm_graph_labeling_problem(
const dlib::array<sample_type>& samples_,
......@@ -234,20 +234,17 @@ namespace dlib
psi = 0;
for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
{
const bool label_i = (label[i]!=0);
// accumulate the node vectors
if (label_i == true)
if (label[i] == true)
set_rowm(psi, range(edge_dims, psi.size()-1)) += sample.node(i).data;
for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
{
const unsigned long j = sample.node(i).neighbor(n).index();
const bool label_j = (label[j]!=0);
// Don't double count edges. Also only include the vector if
// the labels disagree.
if (i < j && label_i != label_j)
if (i < j && label[i] != label[j])
{
set_rowm(psi, range(0, edge_dims-1)) -= sample.node(i).edge(n);
}
......@@ -290,20 +287,17 @@ namespace dlib
psi.clear();
for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
{
const bool label_i = (label[i]!=0);
// accumulate the node vectors
if (label_i == true)
if (label[i] == true)
add_to_sparse_vect(psi, sample.node(i).data, edge_dims);
for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
{
const unsigned long j = sample.node(i).neighbor(n).index();
const bool label_j = (label[j]!=0);
// Don't double count edges. Also only include the vector if
// the labels disagree.
if (i < j && label_i != label_j)
if (i < j && label[i] != label[j])
{
subtract_from_sparse_vect(psi, sample.node(i).edge(n));
}
......@@ -338,8 +332,7 @@ namespace dlib
// Include a loss augmentation so that we will get the proper loss augmented
// max when we use find_max_factor_graph_potts() below.
const bool label_i = (labels[idx][i]!=0);
if (label_i)
if (labels[idx][i])
g.node(i).data -= loss_pos;
else
g.node(i).data += loss_neg;
......@@ -361,12 +354,15 @@ namespace dlib
find_max_factor_graph_potts(g, labeling);
std::vector<bool> bool_labeling;
bool_labeling.reserve(labeling.size());
// figure out the loss
loss = 0;
for (unsigned long i = 0; i < labeling.size(); ++i)
{
const bool true_label = (labels[idx][i]!= 0);
const bool true_label = labels[idx][i];
const bool pred_label = (labeling[i]!= 0);
bool_labeling.push_back(pred_label);
if (true_label != pred_label)
{
if (true_label == true)
......@@ -377,7 +373,7 @@ namespace dlib
}
// compute psi
get_joint_feature_vector(samp, labeling, psi);
get_joint_feature_vector(samp, bool_labeling, psi);
}
const dlib::array<sample_type>& samples;
......
......@@ -22,7 +22,7 @@ namespace dlib
>
bool is_graph_labeling_problem (
const dlib::array<graph_type>& samples,
const std::vector<std::vector<node_label> >& labels
const std::vector<std::vector<bool> >& labels
);
/*!
requires
......@@ -83,7 +83,7 @@ namespace dlib
typedef matrix<double,0,1> matrix_type;
typedef typename graph_type::type feature_vector_type;
typedef graph_type sample_type;
typedef std::vector<node_label> label_type;
typedef std::vector<bool> label_type;
structural_svm_graph_labeling_problem(
const dlib::array<sample_type>& samples,
......
......@@ -34,16 +34,16 @@ namespace
//samples.clear();
//labels.clear();
std::vector<node_label> label;
std::vector<bool> label;
graph_type g;
// ---------------------------
g.set_number_of_nodes(4);
label.resize(g.number_of_nodes());
g.node(0).data = 0, 0, 1; label[0] = 1;
g.node(1).data = 0, 0, 1; label[1] = 1;
g.node(2).data = 0, 1, 0; label[2] = 0;
g.node(3).data = 0, 1, 0; label[3] = 0;
g.node(0).data = 0, 0, 1; label[0] = true;
g.node(1).data = 0, 0, 1; label[1] = true;
g.node(2).data = 0, 1, 0; label[2] = false;
g.node(3).data = 0, 1, 0; label[3] = false;
g.add_edge(0,1);
g.add_edge(1,2);
......@@ -61,10 +61,10 @@ namespace
g.clear();
g.set_number_of_nodes(4);
label.resize(g.number_of_nodes());
g.node(0).data = 0, 0, 1; label[0] = 1;
g.node(1).data = 0, 0, 0; label[1] = 1;
g.node(2).data = 0, 1, 0; label[2] = 0;
g.node(3).data = 0, 0, 0; label[3] = 0;
g.node(0).data = 0, 0, 1; label[0] = true;
g.node(1).data = 0, 0, 0; label[1] = true;
g.node(2).data = 0, 1, 0; label[2] = false;
g.node(3).data = 0, 0, 0; label[3] = false;
g.add_edge(0,1);
g.add_edge(1,2);
......@@ -82,10 +82,10 @@ namespace
g.clear();
g.set_number_of_nodes(4);
label.resize(g.number_of_nodes());
g.node(0).data = 0, 1, 0; label[0] = 0;
g.node(1).data = 0, 1, 0; label[1] = 0;
g.node(2).data = 0, 1, 0; label[2] = 0;
g.node(3).data = 0, 0, 0; label[3] = 0;
g.node(0).data = 0, 1, 0; label[0] = false;
g.node(1).data = 0, 1, 0; label[1] = false;
g.node(2).data = 0, 1, 0; label[2] = false;
g.node(3).data = 0, 0, 0; label[3] = false;
g.add_edge(0,1);
g.add_edge(1,2);
......@@ -117,17 +117,17 @@ namespace
//samples.clear();
//labels.clear();
std::vector<node_label> label;
std::vector<bool> label;
graph_type g;
typename graph_type::edge_type v;
// ---------------------------
g.set_number_of_nodes(4);
label.resize(g.number_of_nodes());
g.node(0).data[2] = 1; label[0] = 1;
g.node(1).data[2] = 1; label[1] = 1;
g.node(2).data[1] = 1; label[2] = 0;
g.node(3).data[1] = 1; label[3] = 0;
g.node(0).data[2] = 1; label[0] = true;
g.node(1).data[2] = 1; label[1] = true;
g.node(2).data[1] = 1; label[2] = false;
g.node(3).data[1] = 1; label[3] = false;
g.add_edge(0,1);
g.add_edge(1,2);
......@@ -147,11 +147,11 @@ namespace
g.clear();
g.set_number_of_nodes(5);
label.resize(g.number_of_nodes());
g.node(0).data[2] = 1; label[0] = 1;
g.node(1).data[0] = 0; label[1] = 1;
g.node(2).data[1] = 1; label[2] = 0;
g.node(3).data[0] = 0; label[3] = 0;
label[4] = 1;
g.node(0).data[2] = 1; label[0] = true;
g.node(1).data[0] = 0; label[1] = true;
g.node(2).data[1] = 1; label[2] = false;
g.node(3).data[0] = 0; label[3] = false;
label[4] = true;
g.add_edge(0,1);
g.add_edge(1,4);
......@@ -171,10 +171,10 @@ namespace
g.clear();
g.set_number_of_nodes(4);
label.resize(g.number_of_nodes());
g.node(0).data[1] = 1; label[0] = 0;
g.node(1).data[1] = 1; label[1] = 0;
g.node(2).data[1] = 1; label[2] = 0;
g.node(3).data[1] = 0; label[3] = 0;
g.node(0).data[1] = 1; label[0] = false;
g.node(1).data[1] = 1; label[1] = false;
g.node(2).data[1] = 1; label[2] = false;
g.node(3).data[1] = 0; label[3] = false;
g.add_edge(0,1);
g.add_edge(1,2);
......@@ -208,16 +208,16 @@ namespace
//samples.clear();
//labels.clear();
std::vector<node_label> label;
std::vector<bool> label;
graph_type g;
// ---------------------------
g.set_number_of_nodes(4);
label.resize(g.number_of_nodes());
g.node(0).data = 0, 0, 1; label[0] = 1;
g.node(1).data = 0, 0, 1; label[1] = 1;
g.node(2).data = 0, 1, 0; label[2] = 0;
g.node(3).data = 0, 1, 0; label[3] = 0;
g.node(0).data = 0, 0, 1; label[0] = true;
g.node(1).data = 0, 0, 1; label[1] = true;
g.node(2).data = 0, 1, 0; label[2] = false;
g.node(3).data = 0, 1, 0; label[3] = false;
g.add_edge(0,1);
g.add_edge(1,2);
......@@ -249,17 +249,17 @@ namespace
//samples.clear();
//labels.clear();
std::vector<node_label> label;
std::vector<bool> label;
graph_type g;
typename graph_type::edge_type v;
// ---------------------------
g.set_number_of_nodes(4);
label.resize(g.number_of_nodes());
g.node(0).data[2] = 1; label[0] = 1;
g.node(1).data[2] = 1; label[1] = 1;
g.node(2).data[1] = 1; label[2] = 0;
g.node(3).data[1] = 1; label[3] = 0;
g.node(0).data[2] = 1; label[0] = true;
g.node(1).data[2] = 1; label[1] = true;
g.node(2).data[1] = 1; label[2] = false;
g.node(3).data[1] = 1; label[3] = false;
g.add_edge(0,1);
g.add_edge(1,2);
......@@ -289,7 +289,7 @@ namespace
>
void test1(
const dlib::array<graph_type>& samples,
const std::vector<std::vector<node_label> >& labels
const std::vector<std::vector<bool> >& labels
)
{
dlog << LINFO << "begin test1()";
......@@ -307,7 +307,7 @@ namespace
labeler = graph_labeler<vector_type>();
deserialize(labeler, sin);
std::vector<node_label> temp;
std::vector<bool> temp;
for (unsigned long k = 0; k < samples.size(); ++k)
{
temp = labeler(samples[k]);
......@@ -353,7 +353,7 @@ namespace
typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
dlib::array<graph_type> samples;
std::vector<std::vector<node_label> > labels;
std::vector<std::vector<bool> > labels;
make_data<graph_type>(samples, labels);
make_data<graph_type>(samples, labels);
......@@ -372,7 +372,7 @@ namespace
typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
dlib::array<graph_type> samples;
std::vector<std::vector<node_label> > labels;
std::vector<std::vector<bool> > labels;
make_data<graph_type>(samples, labels);
make_data<graph_type>(samples, labels);
......@@ -391,7 +391,7 @@ namespace
typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
dlib::array<graph_type> samples;
std::vector<std::vector<node_label> > labels;
std::vector<std::vector<bool> > labels;
make_data_sparse<graph_type>(samples, labels);
make_data_sparse<graph_type>(samples, labels);
......@@ -413,7 +413,7 @@ namespace
typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
dlib::array<graph_type> samples;
std::vector<std::vector<node_label> > labels;
std::vector<std::vector<bool> > labels;
make_data2<graph_type>(samples, labels);
make_data2<graph_type>(samples, labels);
......@@ -432,7 +432,7 @@ namespace
typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
dlib::array<graph_type> samples;
std::vector<std::vector<node_label> > labels;
std::vector<std::vector<bool> > labels;
make_data2_sparse<graph_type>(samples, labels);
make_data2_sparse<graph_type>(samples, labels);
......@@ -451,7 +451,7 @@ namespace
typedef dlib::graph<node_vector_type, edge_vector_type>::kernel_1a_c graph_type;
dlib::array<graph_type> samples;
std::vector<std::vector<node_label> > labels;
std::vector<std::vector<bool> > labels;
make_data2_sparse<graph_type>(samples, labels);
make_data2_sparse<graph_type>(samples, labels);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment