Commit 3fa93d7e authored by Davis King's avatar Davis King

Fixed the code so it works with sparse vectors.

parent e7552958
......@@ -147,13 +147,49 @@ namespace dlib
matrix<double,0,1> w;
solver(prob, w, prob.get_num_edge_weights());
vector_type edge_weights = rowm(w,range(0, prob.get_num_edge_weights()-1));
vector_type node_weights = rowm(w,range(prob.get_num_edge_weights(),w.size()-1));
vector_type edge_weights;
vector_type node_weights;
populate_weights(w, edge_weights, node_weights, prob.get_num_edge_weights());
return graph_labeler<vector_type>(edge_weights, node_weights);
}
private:
template <typename T>
typename enable_if<is_matrix<T> >::type populate_weights (
const matrix<double,0,1>& w,
T& edge_weights,
T& node_weights,
long split_idx
) const
{
edge_weights = rowm(w,range(0, split_idx-1));
node_weights = rowm(w,range(split_idx,w.size()-1));
}
template <typename T>
typename disable_if<is_matrix<T> >::type populate_weights (
const matrix<double,0,1>& w,
T& edge_weights,
T& node_weights,
long split_idx
) const
{
edge_weights.clear();
node_weights.clear();
for (long i = 0; i < split_idx; ++i)
{
if (w(i) != 0)
edge_weights.insert(edge_weights.end(), std::make_pair(i,w(i)));
}
for (long i = split_idx; i < w.size(); ++i)
{
if (w(i) != 0)
node_weights.insert(node_weights.end(), std::make_pair(i-split_idx,w(i)));
}
}
double C;
oca solver;
double eps;
......
......@@ -12,6 +12,7 @@
#include <iterator>
#include "structural_svm_problem_threaded.h"
#include "../graph.h"
#include "sparse_vector.h"
// ----------------------------------------------------------------------------------------
......@@ -170,36 +171,22 @@ namespace dlib
"\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()"
<< "\n\t invalid inputs were given to this function");
using namespace dlib::sparse_vector;
// Figure out how many dimensions are in a node vector. Just pick
// the first node we find and use it as the representative example.
node_dims = 0;
for (unsigned long i = 0; i < samples.size(); ++i)
{
if (samples[i].number_of_nodes() > 0)
{
node_dims = samples[i].node(0).data.size();
break;
}
}
// Figure out how many dimensions are in an edge vector. Just pick
// the first edge we find and use it as the representative example.
// figure out how many dimensions are in the node and edge vectors.
node_dims = 0;
edge_dims = 0;
for (unsigned long i = 0; i < samples.size(); ++i)
{
for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j)
{
if (samples[i].node(j).number_of_neighbors() != 0)
node_dims = std::max(node_dims,(long)max_index_plus_one(samples[i].node(j).data));
for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n)
{
edge_dims = samples[i].node(j).edge(0).size();
break;
edge_dims = std::max(edge_dims, (long)max_index_plus_one(samples[i].node(j).edge(n)));
}
}
// if we found an edge then stop
if (edge_dims != 0)
break;
}
}
......@@ -270,9 +257,9 @@ namespace dlib
unsigned long offset
) const
{
for (unsigned long i = 0; i < vect.size(); ++i)
for (typename T::const_iterator i = vect.begin(); i != vect.end(); ++i)
{
psi.push_back(std::make_pair(vect[i].first+offset, vect[i].second));
psi.insert(psi.end(), std::make_pair(i->first+offset, i->second));
}
}
......@@ -282,9 +269,9 @@ namespace dlib
const T& vect
) const
{
for (unsigned long i = 0; i < vect.size(); ++i)
for (typename T::const_iterator i = vect.begin(); i != vect.end(); ++i)
{
psi.push_back(std::make_pair(vect[i].first, -vect[i].second));
psi.insert(psi.end(), std::make_pair(i->first, -i->second));
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment