Commit 2d57a754 authored by Davis King's avatar Davis King

Gave the shape_predictor_trainer the ability to learn from datasets where some landmarks are missing.
parent 3286a0a3
...@@ -563,6 +563,7 @@ namespace dlib ...@@ -563,6 +563,7 @@ namespace dlib
// make sure the objects agree on the number of parts and that there is at // make sure the objects agree on the number of parts and that there is at
// least one full_object_detection. // least one full_object_detection.
unsigned long num_parts = 0; unsigned long num_parts = 0;
std::vector<int> part_present;
for (unsigned long i = 0; i < objects.size(); ++i) for (unsigned long i = 0; i < objects.size(); ++i)
{ {
for (unsigned long j = 0; j < objects[i].size(); ++j) for (unsigned long j = 0; j < objects[i].size(); ++j)
...@@ -574,6 +575,7 @@ namespace dlib ...@@ -574,6 +575,7 @@ namespace dlib
"\t shape_predictor shape_predictor_trainer::train()" "\t shape_predictor shape_predictor_trainer::train()"
<< "\n\t You can't give objects that don't have any parts to the trainer." << "\n\t You can't give objects that don't have any parts to the trainer."
); );
part_present.resize(num_parts);
} }
else else
{ {
...@@ -584,12 +586,22 @@ namespace dlib ...@@ -584,12 +586,22 @@ namespace dlib
<< "\n\t num_parts: " << num_parts << "\n\t num_parts: " << num_parts
); );
} }
for (unsigned long p = 0; p < objects[i][j].num_parts(); ++p)
{
if (objects[i][j].part(p) != OBJECT_PART_NOT_PRESENT)
part_present[p] = 1;
}
} }
} }
DLIB_CASSERT(num_parts != 0, DLIB_CASSERT(num_parts != 0,
"\t shape_predictor shape_predictor_trainer::train()" "\t shape_predictor shape_predictor_trainer::train()"
<< "\n\t You must give at least one full_object_detection if you want to train a shape model and it must have parts." << "\n\t You must give at least one full_object_detection if you want to train a shape model and it must have parts."
); );
DLIB_CASSERT(sum(mat(part_present)) == (long)num_parts,
"\t shape_predictor shape_predictor_trainer::train()"
<< "\n\t Each part must appear at least once in this training data. That is, "
<< "\n\t you can't have a part that is always set to OBJECT_PART_NOT_PRESENT."
);
...@@ -646,19 +658,33 @@ namespace dlib ...@@ -646,19 +658,33 @@ namespace dlib
private: private:
static matrix<float,0,1> object_to_shape ( static void object_to_shape (
const full_object_detection& obj const full_object_detection& obj,
matrix<float,0,1>& shape,
matrix<float,0,1>& present // a mask telling which elements of #shape are present.
) )
{ {
matrix<float,0,1> shape(obj.num_parts()*2); shape.set_size(obj.num_parts()*2);
present.set_size(obj.num_parts()*2);
const point_transform_affine tform_from_img = impl::normalizing_tform(obj.get_rect()); const point_transform_affine tform_from_img = impl::normalizing_tform(obj.get_rect());
for (unsigned long i = 0; i < obj.num_parts(); ++i) for (unsigned long i = 0; i < obj.num_parts(); ++i)
{ {
vector<float,2> p = tform_from_img(obj.part(i)); if (obj.part(i) != OBJECT_PART_NOT_PRESENT)
shape(2*i) = p.x(); {
shape(2*i+1) = p.y(); vector<float,2> p = tform_from_img(obj.part(i));
shape(2*i) = p.x();
shape(2*i+1) = p.y();
present(2*i) = 1;
present(2*i+1) = 1;
}
else
{
shape(2*i) = 0;
shape(2*i+1) = 0;
present(2*i) = 0;
present(2*i+1) = 0;
}
} }
return shape;
} }
struct training_sample struct training_sample
...@@ -671,7 +697,9 @@ namespace dlib ...@@ -671,7 +697,9 @@ namespace dlib
pixel when you look it up relative to the shape in current_shape. pixel when you look it up relative to the shape in current_shape.
- target_shape == The truth shape. Stays constant during the whole - target_shape == The truth shape. Stays constant during the whole
training process. training process (except for the parts that are not present, those are
always equal to the current_shape values).
- present == 0/1 mask saying which parts of target_shape are present.
- rect == the position of the object in the image_idx-th image. All shape - rect == the position of the object in the image_idx-th image. All shape
coordinates are coded relative to this rectangle. coordinates are coded relative to this rectangle.
!*/ !*/
...@@ -679,6 +707,7 @@ namespace dlib ...@@ -679,6 +707,7 @@ namespace dlib
unsigned long image_idx; unsigned long image_idx;
rectangle rect; rectangle rect;
matrix<float,0,1> target_shape; matrix<float,0,1> target_shape;
matrix<float,0,1> present;
matrix<float,0,1> current_shape; matrix<float,0,1> current_shape;
std::vector<float> feature_pixel_values; std::vector<float> feature_pixel_values;
...@@ -688,6 +717,7 @@ namespace dlib ...@@ -688,6 +717,7 @@ namespace dlib
std::swap(image_idx, item.image_idx); std::swap(image_idx, item.image_idx);
std::swap(rect, item.rect); std::swap(rect, item.rect);
target_shape.swap(item.target_shape); target_shape.swap(item.target_shape);
present.swap(item.present);
current_shape.swap(item.current_shape); current_shape.swap(item.current_shape);
feature_pixel_values.swap(item.feature_pixel_values); feature_pixel_values.swap(item.feature_pixel_values);
} }
...@@ -727,17 +757,38 @@ namespace dlib ...@@ -727,17 +757,38 @@ namespace dlib
// Now all the parts contain the ranges for the leaves so we can use them to // Now all the parts contain the ranges for the leaves so we can use them to
// compute the average leaf values. // compute the average leaf values.
matrix<float,0,1> present_counts(samples[0].target_shape.size());
tree.leaf_values.resize(parts.size()); tree.leaf_values.resize(parts.size());
for (unsigned long i = 0; i < parts.size(); ++i) for (unsigned long i = 0; i < parts.size(); ++i)
{ {
// Get the present counts for each dimension so we can divide each
// dimension by the number of observations we have on it to find the mean
// displacement in each leaf.
present_counts = 0;
for (unsigned long j = parts[i].first; j < parts[i].second; ++j)
present_counts += samples[j].present;
present_counts = dlib::reciprocal(present_counts);
if (parts[i].second != parts[i].first) if (parts[i].second != parts[i].first)
tree.leaf_values[i] = sums[num_split_nodes+i]*get_nu()/(parts[i].second - parts[i].first); tree.leaf_values[i] = pointwise_multiply(present_counts,sums[num_split_nodes+i]*get_nu());
else else
tree.leaf_values[i] = zeros_matrix(samples[0].target_shape); tree.leaf_values[i] = zeros_matrix(samples[0].target_shape);
// now adjust the current shape based on these predictions // now adjust the current shape based on these predictions
for (unsigned long j = parts[i].first; j < parts[i].second; ++j) for (unsigned long j = parts[i].first; j < parts[i].second; ++j)
{
samples[j].current_shape += tree.leaf_values[i]; samples[j].current_shape += tree.leaf_values[i];
// For parts that aren't present in the training data, we just make
// sure that the target shape always matches and therefore gives zero
// error. So this makes the algorithm simply ignore non-present
// landmarks.
for (long k = 0; k < samples[j].present.size(); ++k)
{
// if this part is not present
if (samples[j].present(k) == 0)
samples[j].target_shape(k) = samples[j].current_shape(k);
}
}
} }
return tree; return tree;
...@@ -867,7 +918,7 @@ namespace dlib ...@@ -867,7 +918,7 @@ namespace dlib
{ {
samples.clear(); samples.clear();
matrix<float,0,1> mean_shape; matrix<float,0,1> mean_shape;
long count = 0; matrix<float,0,1> count;
// first fill out the target shapes // first fill out the target shapes
for (unsigned long i = 0; i < objects.size(); ++i) for (unsigned long i = 0; i < objects.size(); ++i)
{ {
...@@ -876,15 +927,15 @@ namespace dlib ...@@ -876,15 +927,15 @@ namespace dlib
training_sample sample; training_sample sample;
sample.image_idx = i; sample.image_idx = i;
sample.rect = objects[i][j].get_rect(); sample.rect = objects[i][j].get_rect();
sample.target_shape = object_to_shape(objects[i][j]); object_to_shape(objects[i][j], sample.target_shape, sample.present);
for (unsigned long itr = 0; itr < get_oversampling_amount(); ++itr) for (unsigned long itr = 0; itr < get_oversampling_amount(); ++itr)
samples.push_back(sample); samples.push_back(sample);
mean_shape += sample.target_shape; mean_shape += sample.target_shape;
++count; count += sample.present;
} }
} }
mean_shape /= count; mean_shape = pointwise_multiply(mean_shape,reciprocal(count));
// now go pick random initial shapes // now go pick random initial shapes
for (unsigned long i = 0; i < samples.size(); ++i) for (unsigned long i = 0; i < samples.size(); ++i)
...@@ -897,12 +948,35 @@ namespace dlib ...@@ -897,12 +948,35 @@ namespace dlib
} }
else else
{ {
// Pick a random convex combination of two of the target shapes and use samples[i].current_shape.set_size(0);
// that as the initial shape for this sample.
const unsigned long rand_idx = rnd.get_random_32bit_number()%samples.size(); matrix<float,0,1> hits(mean_shape.size());
const unsigned long rand_idx2 = rnd.get_random_32bit_number()%samples.size(); hits = 0;
const double alpha = rnd.get_random_double();
samples[i].current_shape = alpha*samples[rand_idx].target_shape + (1-alpha)*samples[rand_idx2].target_shape; int iter = 0;
// Pick a few samples at random and randomly average them together to
// make the initial shape. Note that we make sure we get at least one
// observation (i.e. non-OBJECT_PART_NOT_PRESENT) on each part
// location.
while(min(hits) == 0 || iter < 2)
{
++iter;
const unsigned long rand_idx = rnd.get_random_32bit_number()%samples.size();
const double alpha = rnd.get_random_double()+0.1;
samples[i].current_shape += alpha*samples[rand_idx].target_shape;
hits += alpha*samples[rand_idx].present;
}
samples[i].current_shape = pointwise_multiply(samples[i].current_shape, reciprocal(hits));
}
}
for (unsigned long i = 0; i < samples.size(); ++i)
{
for (long k = 0; k < samples[i].present.size(); ++k)
{
// if this part is not present
if (samples[i].present(k) == 0)
samples[i].target_shape(k) = samples[i].current_shape(k);
} }
} }
...@@ -1029,8 +1103,11 @@ namespace dlib ...@@ -1029,8 +1103,11 @@ namespace dlib
for (unsigned long k = 0; k < det.num_parts(); ++k) for (unsigned long k = 0; k < det.num_parts(); ++k)
{ {
double score = length(det.part(k) - objects[i][j].part(k))/scale; if (objects[i][j].part(k) != OBJECT_PART_NOT_PRESENT)
rs.add(score); {
double score = length(det.part(k) - objects[i][j].part(k))/scale;
rs.add(score);
}
} }
} }
} }
......
...@@ -359,6 +359,9 @@ namespace dlib ...@@ -359,6 +359,9 @@ namespace dlib
- images.size() > 0 - images.size() > 0
- for some i: objects[i].size() != 0 - for some i: objects[i].size() != 0
(i.e. there has to be at least one full_object_detection in the training set) (i.e. there has to be at least one full_object_detection in the training set)
- for all valid p, there must exist i and j such that:
objects[i][j].part(p) != OBJECT_PART_NOT_PRESENT.
(i.e. You can't define a part that is always set to OBJECT_PART_NOT_PRESENT.)
- for all valid i,j,k,l: - for all valid i,j,k,l:
- objects[i][j].num_parts() == objects[k][l].num_parts() - objects[i][j].num_parts() == objects[k][l].num_parts()
(i.e. all objects must agree on the number of parts) (i.e. all objects must agree on the number of parts)
...@@ -370,6 +373,10 @@ namespace dlib ...@@ -370,6 +373,10 @@ namespace dlib
shape_predictor, SP, such that: shape_predictor, SP, such that:
SP(images[i], objects[i][j].get_rect()) == objects[i][j] SP(images[i], objects[i][j].get_rect()) == objects[i][j]
This learned SP object is then returned. This learned SP object is then returned.
- Not all parts are required to be observed for all objects. So if you
have training instances with missing parts then set the part positions
equal to OBJECT_PART_NOT_PRESENT and this algorithm will basically ignore
those missing parts.
!*/ !*/
}; };
...@@ -408,6 +415,8 @@ namespace dlib ...@@ -408,6 +415,8 @@ namespace dlib
and compare the result with the truth part positions in objects[i][j]. We and compare the result with the truth part positions in objects[i][j]. We
then return the average distance (measured in pixels) between a predicted then return the average distance (measured in pixels) between a predicted
part location and its true position. part location and its true position.
- Note that any parts in objects that are set to OBJECT_PART_NOT_PRESENT are
simply ignored.
- if (scales.size() != 0) then - if (scales.size() != 0) then
- Each time we compute the distance between a predicted part location and - Each time we compute the distance between a predicted part location and
its true location in objects[i][j] we divide the distance by its true location in objects[i][j] we divide the distance by
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment