Commit 44fd6f42 authored by Davis King's avatar Davis King

refined example

parent aa46752d
...@@ -152,7 +152,7 @@ class three_class_classifier_problem: ...@@ -152,7 +152,7 @@ class three_class_classifier_problem:
# There are also a number of optional arguments: # There are also a number of optional arguments:
# epsilon is the stopping tolerance. The optimizer will run until R(w) is within # epsilon is the stopping tolerance. The optimizer will run until R(w) is within
# epsilon of its optimal value. If you don't set this then it defaults to 0.001 # epsilon of its optimal value. If you don't set this then it defaults to 0.001.
#epsilon = 1e-13 #epsilon = 1e-13
# Uncomment this and the optimizer will print its progress to standard out. You will # Uncomment this and the optimizer will print its progress to standard out. You will
...@@ -172,9 +172,9 @@ class three_class_classifier_problem: ...@@ -172,9 +172,9 @@ class three_class_classifier_problem:
def __init__(self, samples, labels): def __init__(self, samples, labels):
# dlib.solve_structural_svm_problem() also expects the class to have num_samples # dlib.solve_structural_svm_problem() expects the class to have num_samples and
# and num_dimensions fields. These fields are expected to contain the number of # num_dimensions fields. These fields should contain the number of training
# training samples and the dimensionality of the psi feature vector respectively. # samples and the dimensionality of the PSI feature vector respectively.
self.num_samples = len(samples) self.num_samples = len(samples)
self.num_dimensions = len(samples[0])*3 self.num_dimensions = len(samples[0])*3
...@@ -237,7 +237,8 @@ class three_class_classifier_problem: ...@@ -237,7 +237,8 @@ class three_class_classifier_problem:
# it the current value of the parameter weights and the separation_oracle() is supposed # it the current value of the parameter weights and the separation_oracle() is supposed
# to find the label that most violates the structural SVM objective function for the # to find the label that most violates the structural SVM objective function for the
# idx-th sample. Then the separation oracle reports the corresponding PSI vector and # idx-th sample. Then the separation oracle reports the corresponding PSI vector and
# loss value. # To be more precise, separation_oracle() has the following contract: # loss value. To state this more precisely, the separation_oracle() member function
# has the following contract:
# requires # requires
# - 0 <= idx < self.num_samples # - 0 <= idx < self.num_samples
# - len(current_solution) == self.num_dimensions # - len(current_solution) == self.num_dimensions
...@@ -266,6 +267,9 @@ class three_class_classifier_problem: ...@@ -266,6 +267,9 @@ class three_class_classifier_problem:
# Add in the loss-augmentation. Recall that we maximize LOSS(idx,y) + F(X,y) in # Add in the loss-augmentation. Recall that we maximize LOSS(idx,y) + F(X,y) in
# the separate oracle, not just F(X,y) as we normally would in predict_label(). # the separate oracle, not just F(X,y) as we normally would in predict_label().
# Therefore, we must add in this extra amount to account for the loss-augmentation.
# For our simple multi-class classifier, we incur a loss of 1 if we don't predict
# the correct label and a loss of 0 if we get the right label.
if (self.labels[idx] != 0): if (self.labels[idx] != 0):
scores[0] += 1 scores[0] += 1
if (self.labels[idx] != 1): if (self.labels[idx] != 1):
...@@ -275,8 +279,8 @@ class three_class_classifier_problem: ...@@ -275,8 +279,8 @@ class three_class_classifier_problem:
# Now figure out which classifier has the largest loss-augmented score. # Now figure out which classifier has the largest loss-augmented score.
max_scoring_label = scores.index(max(scores)) max_scoring_label = scores.index(max(scores))
# We incur a loss of 1 if we don't predict the correct label and a loss of 0 if we # And finally record the loss that was associated with that predicted label.
# get the right answer. # Again, the loss is 1 if the label is incorrect and 0 otherwise.
if (max_scoring_label == self.labels[idx]): if (max_scoring_label == self.labels[idx]):
loss = 0 loss = 0
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment