Commit cc9ff97a authored by Davis King's avatar Davis King

Cleaned up python svm struct code a little.

parent d0a054f1
#!/usr/bin/python #!/usr/bin/python
# The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt # The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
# #
# # This is an example illustrating the use of the structural SVM solver from the dlib C++
# Library. This example will briefly introduce it and then walk through an example showing
# how to use it to create a simple multi-class classifier.
#
# #
# COMPILING THE DLIB PYTHON INTERFACE # COMPILING THE DLIB PYTHON INTERFACE
# Dlib comes with a compiled python interface for python 2.7 on MS Windows. If # Dlib comes with a compiled python interface for python 2.7 on MS Windows. If
...@@ -15,6 +18,7 @@ ...@@ -15,6 +18,7 @@
import dlib import dlib
def dot(a, b): def dot(a, b):
"Compute the dot product between the two vectors a and b."
return sum(i*j for i,j in zip(a,b)) return sum(i*j for i,j in zip(a,b))
...@@ -23,30 +27,35 @@ class three_class_classifier_problem: ...@@ -23,30 +27,35 @@ class three_class_classifier_problem:
be_verbose = True be_verbose = True
epsilon = 0.0001 epsilon = 0.0001
def __init__(self, samples, labels): def __init__(self, samples, labels):
self.num_samples = len(samples) self.num_samples = len(samples)
self.num_dimensions = len(samples[0])*3 self.num_dimensions = len(samples[0])*3
self.samples = samples self.samples = samples
self.labels = labels self.labels = labels
def make_psi(self, psi, vector, label):
def make_psi(self, vector, label):
psi = dlib.vector()
psi.resize(self.num_dimensions) psi.resize(self.num_dimensions)
dims = len(vector) dims = len(vector)
if (label == 1): if (label == 0):
for i in range(0,dims): for i in range(0,dims):
psi[i] = vector[i] psi[i] = vector[i]
elif (label == 2): elif (label == 1):
for i in range(dims,2*dims): for i in range(dims,2*dims):
psi[i] = vector[i-dims] psi[i] = vector[i-dims]
else: else: # the label must be 2
for i in range(2*dims,3*dims): for i in range(2*dims,3*dims):
psi[i] = vector[i-2*dims] psi[i] = vector[i-2*dims]
return psi
def get_truth_joint_feature_vector(self, idx):
return self.make_psi(self.samples[idx], self.labels[idx])
def get_truth_joint_feature_vector(self, idx, psi):
self.make_psi(psi, self.samples[idx], self.labels[idx])
def separation_oracle(self, idx, current_solution, psi): def separation_oracle(self, idx, current_solution):
samp = samples[idx] samp = samples[idx]
dims = len(samp) dims = len(samp)
scores = [0,0,0] scores = [0,0,0]
...@@ -56,29 +65,28 @@ class three_class_classifier_problem: ...@@ -56,29 +65,28 @@ class three_class_classifier_problem:
scores[2] = dot(current_solution[2*dims:3*dims], samp) scores[2] = dot(current_solution[2*dims:3*dims], samp)
# Add in the loss-augmentation # Add in the loss-augmentation
if (labels[idx] != 1): if (labels[idx] != 0):
scores[0] += 1 scores[0] += 1
if (labels[idx] != 2): if (labels[idx] != 1):
scores[1] += 1 scores[1] += 1
if (labels[idx] != 3): if (labels[idx] != 2):
scores[2] += 1 scores[2] += 1
# Now figure out which classifier has the largest loss-augmented score. # Now figure out which classifier has the largest loss-augmented score.
max_scoring_label = scores.index(max(scores))+1 max_scoring_label = scores.index(max(scores))
if (max_scoring_label == labels[idx]): if (max_scoring_label == labels[idx]):
loss = 0 loss = 0
else: else:
loss = 1 loss = 1
self.make_psi(psi, samp, max_scoring_label) psi = self.make_psi(samp, max_scoring_label)
return loss return loss,psi
samples = [ [0,0,1], [0,1,0], [1,0,0]]; samples = [[0,0,1], [0,1,0], [1,0,0]];
labels = [1, 2, 3] labels = [0,1,2]
problem = three_class_classifier_problem(samples, labels) problem = three_class_classifier_problem(samples, labels)
weights = dlib.solve_structural_svm_problem(problem) weights = dlib.solve_structural_svm_problem(problem)
......
...@@ -37,7 +37,7 @@ public: ...@@ -37,7 +37,7 @@ public:
feature_vector_type& psi feature_vector_type& psi
) const ) const
{ {
problem.attr("get_truth_joint_feature_vector")(idx,boost::ref(psi)); psi = extract<feature_vector_type&>(problem.attr("get_truth_joint_feature_vector")(idx));
} }
virtual void separation_oracle ( virtual void separation_oracle (
...@@ -47,7 +47,19 @@ public: ...@@ -47,7 +47,19 @@ public:
feature_vector_type& psi feature_vector_type& psi
) const ) const
{ {
loss = extract<double>(problem.attr("separation_oracle")(idx,boost::ref(current_solution),boost::ref(psi))); object res = problem.attr("separation_oracle")(idx,boost::ref(current_solution));
pyassert(len(res) == 2, "separation_oracle() must return two objects, the loss and the psi vector");
// let the user supply the output arguments in any order.
if (extract<double>(res[0]).check())
{
loss = extract<double>(res[0]);
psi = extract<feature_vector_type&>(res[1]);
}
else
{
psi = extract<feature_vector_type&>(res[0]);
loss = extract<double>(res[1]);
}
} }
private: private:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment