Commit 0660dc02 authored by Davis King's avatar Davis King

Made python svm struct interface work with sparse vectors

parent 8c111ee7
......@@ -13,10 +13,15 @@ using namespace dlib;
using namespace std;
using namespace boost::python;
class svm_struct_dense : public structural_svm_problem<matrix<double,0,1> >
template <typename psi_type>
class svm_struct_prob : public structural_svm_problem<matrix<double,0,1>, psi_type>
{
typedef structural_svm_problem<matrix<double,0,1>, psi_type> base;
typedef typename base::feature_vector_type feature_vector_type;
typedef typename base::matrix_type matrix_type;
typedef typename base::scalar_type scalar_type;
public:
svm_struct_dense (
svm_struct_prob (
object& problem_,
long num_dimensions_,
long num_samples_
......@@ -71,16 +76,8 @@ private:
// ----------------------------------------------------------------------------------------
/*
class svm_struct_sparse : public structural_svm_problem<matrix<double,0,1>,
std::vector<std::pair<unsigned long,double> >
{
};
*/
// ----------------------------------------------------------------------------------------
matrix<double,0,1> solve_structural_svm_problem(
template <typename psi_type>
matrix<double,0,1> solve_structural_svm_problem_impl(
object problem
)
{
......@@ -101,6 +98,8 @@ matrix<double,0,1> solve_structural_svm_problem(
const long num_samples = extract<long>(problem.attr("num_samples"));
const long num_dimensions = extract<long>(problem.attr("num_dimensions"));
pyassert(num_samples > 0, "You can't train a Structural-SVM if you don't have any training samples.");
if (be_verbose)
{
cout << "C: " << C << endl;
......@@ -113,8 +112,7 @@ matrix<double,0,1> solve_structural_svm_problem(
cout << endl;
}
svm_struct_dense prob(problem, num_dimensions, num_samples);
svm_struct_prob<psi_type> prob(problem, num_dimensions, num_samples);
prob.set_c(C);
prob.set_epsilon(eps);
prob.set_max_cache_size(max_cache_size);
......@@ -132,6 +130,20 @@ matrix<double,0,1> solve_structural_svm_problem(
// ----------------------------------------------------------------------------------------
matrix<double,0,1> solve_structural_svm_problem(
object problem
)
{
// Check if the python code is using sparse or dense vectors to represent PSI()
extract<matrix<double,0,1> > isdense(problem.attr("get_truth_joint_feature_vector")(0));
if (isdense.check())
return solve_structural_svm_problem_impl<matrix<double,0,1> >(problem);
else
return solve_structural_svm_problem_impl<std::vector<std::pair<unsigned long,double> > >(problem);
}
// ----------------------------------------------------------------------------------------
void bind_svm_struct()
{
using boost::python::arg;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment