Commit 44fd6f42 authored by Davis King

refined example

parent aa46752d
@@ -152,7 +152,7 @@ class three_class_classifier_problem:
# There are also a number of optional arguments:
# epsilon is the stopping tolerance. The optimizer will run until R(w) is within
-# epsilon of its optimal value. If you don't set this then it defaults to 0.001
+# epsilon of its optimal value. If you don't set this then it defaults to 0.001.
#epsilon = 1e-13
# Uncomment this and the optimizer will print its progress to standard out. You will
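A rough sketch of how these optional attributes fit together once the problem object is built (the attribute name be_verbose and the exact values are illustrative assumptions, not part of this commit; the samples and labels lists come from earlier in the example):

import dlib

problem = three_class_classifier_problem(samples, labels)
problem.epsilon = 0.001    # stop once R(w) is within 0.001 of its optimal value
problem.be_verbose = True  # assumed name: print optimizer progress to standard out
weights = dlib.solve_structural_svm_problem(problem)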
@@ -172,9 +172,9 @@ class three_class_classifier_problem:
def __init__(self, samples, labels):
-# dlib.solve_structural_svm_problem() also expects the class to have num_samples
-# and num_dimensions fields. These fields are expected to contain the number of
-# training samples and the dimensionality of the psi feature vector respectively.
+# dlib.solve_structural_svm_problem() expects the class to have num_samples and
+# num_dimensions fields. These fields should contain the number of training
+# samples and the dimensionality of the PSI feature vector respectively.
self.num_samples = len(samples)
self.num_dimensions = len(samples[0])*3
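The factor of 3 in num_dimensions comes from stacking one copy of the feature vector per class into a single PSI vector, so that dot(w, PSI(x,y)) scores x with the y-th classifier. A minimal sketch of that layout, using a hypothetical make_psi helper (the name is not shown in these hunks):

def make_psi(self, x, label):
    # Place x into the block of PSI reserved for the given label; the other
    # two blocks stay zero, since dlib.vector elements are zero-initialized.
    psi = dlib.vector()
    psi.resize(self.num_dimensions)
    dims = len(x)
    for i in range(dims):
        psi[dims*label + i] = x[i]
    return psi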
@@ -237,7 +237,8 @@ class three_class_classifier_problem:
# it the current value of the parameter weights and the separation_oracle() is supposed
# to find the label that most violates the structural SVM objective function for the
# idx-th sample. Then the separation oracle reports the corresponding PSI vector and
-# loss value. # To be more precise, separation_oracle() has the following contract:
+# loss value. To state this more precisely, the separation_oracle() member function
+# has the following contract:
# requires
# - 0 <= idx < self.num_samples
# - len(current_solution) == self.num_dimensions
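In dlib's Python binding the oracle reports these two outputs as a pair, so the tail of separation_oracle() looks roughly like this (a sketch assuming the hypothetical make_psi helper above; samp is the idx-th training sample, and max_scoring_label is computed as the following hunks show):

loss = 0 if max_scoring_label == self.labels[idx] else 1
psi = self.make_psi(samp, max_scoring_label)
return loss, psi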
@@ -264,8 +265,11 @@ class three_class_classifier_problem:
scores[1] = dot(current_solution[dims:2*dims], samp)
scores[2] = dot(current_solution[2*dims:3*dims], samp)
-# Add in the loss-augmentation. Recall that we maximize LOSS(idx,y) + F(X,y) in
+# Add in the loss-augmentation. Recall that we maximize LOSS(idx,y) + F(X,y) in
# the separation oracle, not just F(X,y) as we normally would in predict_label().
# Therefore, we must add in this extra amount to account for the loss-augmentation.
+# For our simple multi-class classifier, we incur a loss of 1 if we don't predict
+# the correct label and a loss of 0 if we get the right label.
if (self.labels[idx] != 0):
    scores[0] += 1
if (self.labels[idx] != 1):
@@ -275,8 +279,8 @@ class three_class_classifier_problem:
# Now figure out which classifier has the largest loss-augmented score.
max_scoring_label = scores.index(max(scores))
-# We incur a loss of 1 if we don't predict the correct label and a loss of 0 if we
-# get the right answer.
+# And finally record the loss that was associated with that predicted label.
+# Again, the loss is 1 if the label is incorrect and 0 otherwise.
if (max_scoring_label == self.labels[idx]):
    loss = 0
else:
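For contrast with the loss-augmented scoring above, the plain prediction rule the comments call predict_label() is the same argmax without the loss term. A sketch, including the small sum-of-products dot() helper the snippets above rely on:

def dot(a, b):
    # Plain dot product between two equal-length vectors.
    return sum(i*j for i, j in zip(a, b))

def predict_label(weights, sample):
    # Split the weight vector into three per-class classifiers and return
    # the index of the one that scores the sample highest.
    dims = len(sample)
    scores = [dot(weights[0:dims], sample),
              dot(weights[dims:2*dims], sample),
              dot(weights[2*dims:3*dims], sample)]
    return scores.index(max(scores))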