Commit 0660dc02 authored by Davis King's avatar Davis King

Made python svm struct interface work with sparse vectors

parent 8c111ee7
...@@ -13,10 +13,15 @@ using namespace dlib; ...@@ -13,10 +13,15 @@ using namespace dlib;
using namespace std; using namespace std;
using namespace boost::python; using namespace boost::python;
class svm_struct_dense : public structural_svm_problem<matrix<double,0,1> > template <typename psi_type>
class svm_struct_prob : public structural_svm_problem<matrix<double,0,1>, psi_type>
{ {
typedef structural_svm_problem<matrix<double,0,1>, psi_type> base;
typedef typename base::feature_vector_type feature_vector_type;
typedef typename base::matrix_type matrix_type;
typedef typename base::scalar_type scalar_type;
public: public:
svm_struct_dense ( svm_struct_prob (
object& problem_, object& problem_,
long num_dimensions_, long num_dimensions_,
long num_samples_ long num_samples_
...@@ -71,16 +76,8 @@ private: ...@@ -71,16 +76,8 @@ private:
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
/* template <typename psi_type>
class svm_struct_sparse : public structural_svm_problem<matrix<double,0,1>, matrix<double,0,1> solve_structural_svm_problem_impl(
std::vector<std::pair<unsigned long,double> >
{
};
*/
// ----------------------------------------------------------------------------------------
matrix<double,0,1> solve_structural_svm_problem(
object problem object problem
) )
{ {
...@@ -101,6 +98,8 @@ matrix<double,0,1> solve_structural_svm_problem( ...@@ -101,6 +98,8 @@ matrix<double,0,1> solve_structural_svm_problem(
const long num_samples = extract<long>(problem.attr("num_samples")); const long num_samples = extract<long>(problem.attr("num_samples"));
const long num_dimensions = extract<long>(problem.attr("num_dimensions")); const long num_dimensions = extract<long>(problem.attr("num_dimensions"));
pyassert(num_samples > 0, "You can't train a Structural-SVM if you don't have any training samples.");
if (be_verbose) if (be_verbose)
{ {
cout << "C: " << C << endl; cout << "C: " << C << endl;
...@@ -113,8 +112,7 @@ matrix<double,0,1> solve_structural_svm_problem( ...@@ -113,8 +112,7 @@ matrix<double,0,1> solve_structural_svm_problem(
cout << endl; cout << endl;
} }
svm_struct_prob<psi_type> prob(problem, num_dimensions, num_samples);
svm_struct_dense prob(problem, num_dimensions, num_samples);
prob.set_c(C); prob.set_c(C);
prob.set_epsilon(eps); prob.set_epsilon(eps);
prob.set_max_cache_size(max_cache_size); prob.set_max_cache_size(max_cache_size);
...@@ -132,6 +130,20 @@ matrix<double,0,1> solve_structural_svm_problem( ...@@ -132,6 +130,20 @@ matrix<double,0,1> solve_structural_svm_problem(
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
matrix<double,0,1> solve_structural_svm_problem(
object problem
)
{
// Check if the python code is using sparse or dense vectors to represent PSI()
extract<matrix<double,0,1> > isdense(problem.attr("get_truth_joint_feature_vector")(0));
if (isdense.check())
return solve_structural_svm_problem_impl<matrix<double,0,1> >(problem);
else
return solve_structural_svm_problem_impl<std::vector<std::pair<unsigned long,double> > >(problem);
}
// ----------------------------------------------------------------------------------------
void bind_svm_struct() void bind_svm_struct()
{ {
using boost::python::arg; using boost::python::arg;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment