Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
a73e7659
Commit
a73e7659
authored
Mar 31, 2013
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Made the object detector validation functions also output the mean average
precision measure.
parent
230dc754
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
89 additions
and
25 deletions
+89
-25
cross_validate_object_detection_trainer.h
dlib/svm/cross_validate_object_detection_trainer.h
+66
-11
cross_validate_object_detection_trainer_abstract.h
dlib/svm/cross_validate_object_detection_trainer_abstract.h
+23
-14
No files found.
dlib/svm/cross_validate_object_detection_trainer.h
View file @
a73e7659
...
...
@@ -9,6 +9,7 @@
#include "svm.h"
#include "../geometry.h"
#include "../image_processing/full_object_detection.h"
#include "../statistics.h"
namespace
dlib
{
...
...
@@ -20,7 +21,8 @@ namespace dlib
inline
unsigned
long
number_of_truth_hits
(
const
std
::
vector
<
full_object_detection
>&
truth_boxes
,
const
std
::
vector
<
rectangle
>&
boxes
,
const
double
overlap_eps
const
double
overlap_eps
,
double
&
ap
)
/*!
requires
...
...
@@ -32,10 +34,19 @@ namespace dlib
A.intersect(B).area()/(A+B).area()
- No element of boxes is allowed to account for more than one element of truth_boxes.
- The returned number is in the range [0,truth_boxes.size()]
- ap == the average precision of the given ordering in boxes relative
to truth_boxes.
!*/
{
if
(
boxes
.
size
()
==
0
)
{
if
(
truth_boxes
.
size
()
==
0
)
ap
=
1
;
else
ap
=
0
;
return
0
;
}
unsigned
long
count
=
0
;
std
::
vector
<
bool
>
used
(
boxes
.
size
(),
false
);
...
...
@@ -63,8 +74,20 @@ namespace dlib
}
}
ap
=
average_precision
(
used
,
truth_boxes
.
size
()
-
count
);
return
count
;
}
inline
unsigned
long
number_of_truth_hits
(
const
std
::
vector
<
full_object_detection
>&
truth_boxes
,
const
std
::
vector
<
rectangle
>&
boxes
,
const
double
overlap_eps
)
{
double
ap
;
return
number_of_truth_hits
(
truth_boxes
,
boxes
,
overlap_eps
,
ap
);
}
}
// ----------------------------------------------------------------------------------------
...
...
@@ -73,7 +96,7 @@ namespace dlib
typename
object_detector_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
test_object_detection_function
(
const
matrix
<
double
,
1
,
3
>
test_object_detection_function
(
object_detector_type
&
detector
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
full_object_detection
>
>&
truth_dets
,
...
...
@@ -94,14 +117,30 @@ namespace dlib
double
correct_hits
=
0
;
double
total_hits
=
0
;
double
total_true_targets
=
0
;
running_stats
<
double
>
map
;
for
(
unsigned
long
i
=
0
;
i
<
images
.
size
();
++
i
)
{
const
std
::
vector
<
rectangle
>&
hits
=
detector
(
images
[
i
]);
std
::
vector
<
std
::
pair
<
double
,
rectangle
>
>
all_dets
;
detector
(
images
[
i
],
all_dets
,
-
std
::
numeric_limits
<
double
>::
infinity
());
std
::
vector
<
rectangle
>
hits
;
for
(
unsigned
long
k
=
0
;
k
<
all_dets
.
size
();
++
k
)
{
if
(
all_dets
[
k
].
first
>=
0
)
hits
.
push_back
(
all_dets
[
k
].
second
);
}
total_hits
+=
hits
.
size
();
correct_hits
+=
impl
::
number_of_truth_hits
(
truth_dets
[
i
],
hits
,
overlap_eps
);
total_true_targets
+=
truth_dets
[
i
].
size
();
// now get the average precision
hits
.
clear
();
for
(
unsigned
long
k
=
0
;
k
<
all_dets
.
size
();
++
k
)
hits
.
push_back
(
all_dets
[
k
].
second
);
double
ap
;
impl
::
number_of_truth_hits
(
truth_dets
[
i
],
hits
,
overlap_eps
,
ap
);
map
.
add
(
ap
);
}
...
...
@@ -117,8 +156,8 @@ namespace dlib
else
recall
=
correct_hits
/
total_true_targets
;
matrix
<
double
,
1
,
2
>
res
;
res
=
precision
,
recall
;
matrix
<
double
,
1
,
3
>
res
;
res
=
precision
,
recall
,
map
.
mean
()
;
return
res
;
}
...
...
@@ -126,7 +165,7 @@ namespace dlib
typename
object_detector_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
test_object_detection_function
(
const
matrix
<
double
,
1
,
3
>
test_object_detection_function
(
object_detector_type
&
detector
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_dets
,
...
...
@@ -197,7 +236,7 @@ namespace dlib
typename
trainer_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
cross_validate_object_detection_trainer
(
const
matrix
<
double
,
1
,
3
>
cross_validate_object_detection_trainer
(
const
trainer_type
&
trainer
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
full_object_detection
>
>&
truth_dets
,
...
...
@@ -222,6 +261,7 @@ namespace dlib
const
long
test_size
=
images
.
size
()
/
folds
;
running_stats
<
double
>
map
;
unsigned
long
test_idx
=
0
;
for
(
long
iter
=
0
;
iter
<
folds
;
++
iter
)
{
...
...
@@ -245,11 +285,26 @@ namespace dlib
typename
trainer_type
::
trained_function_type
detector
=
trainer
.
train
(
array_subset
,
training_rects
);
for
(
unsigned
long
i
=
0
;
i
<
test_idx_set
.
size
();
++
i
)
{
const
std
::
vector
<
rectangle
>&
hits
=
detector
(
images
[
test_idx_set
[
i
]]);
std
::
vector
<
std
::
pair
<
double
,
rectangle
>
>
all_dets
;
detector
(
images
[
test_idx_set
[
i
]],
all_dets
,
-
std
::
numeric_limits
<
double
>::
infinity
());
std
::
vector
<
rectangle
>
hits
;
for
(
unsigned
long
k
=
0
;
k
<
all_dets
.
size
();
++
k
)
{
if
(
all_dets
[
k
].
first
>=
0
)
hits
.
push_back
(
all_dets
[
k
].
second
);
}
total_hits
+=
hits
.
size
();
correct_hits
+=
impl
::
number_of_truth_hits
(
truth_dets
[
test_idx_set
[
i
]],
hits
,
overlap_eps
);
total_true_targets
+=
truth_dets
[
test_idx_set
[
i
]].
size
();
// now get the average precision
hits
.
clear
();
for
(
unsigned
long
k
=
0
;
k
<
all_dets
.
size
();
++
k
)
hits
.
push_back
(
all_dets
[
k
].
second
);
double
ap
;
impl
::
number_of_truth_hits
(
truth_dets
[
test_idx_set
[
i
]],
hits
,
overlap_eps
,
ap
);
map
.
add
(
ap
);
}
}
...
...
@@ -268,8 +323,8 @@ namespace dlib
else
recall
=
correct_hits
/
total_true_targets
;
matrix
<
double
,
1
,
2
>
res
;
res
=
precision
,
recall
;
matrix
<
double
,
1
,
3
>
res
;
res
=
precision
,
recall
,
map
.
mean
()
;
return
res
;
}
...
...
@@ -277,7 +332,7 @@ namespace dlib
typename
trainer_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
cross_validate_object_detection_trainer
(
const
matrix
<
double
,
1
,
3
>
cross_validate_object_detection_trainer
(
const
trainer_type
&
trainer
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_dets
,
...
...
dlib/svm/cross_validate_object_detection_trainer_abstract.h
View file @
a73e7659
...
...
@@ -17,7 +17,7 @@ namespace dlib
typename
object_detector_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
test_object_detection_function
(
const
matrix
<
double
,
1
,
3
>
test_object_detection_function
(
object_detector_type
&
detector
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
full_object_detection
>
>&
truth_dets
,
...
...
@@ -32,9 +32,10 @@ namespace dlib
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector().
ensures
- Tests the given detector against the supplied object detection problem
and returns the precision and recall. Note that the task is to predict,
for each images[i], the set of object locations given by truth_dets[i].
- Tests the given detector against the supplied object detection problem and
returns the precision, recall, and mean average precision. Note that the
task is to predict, for each images[i], the set of object locations given by
truth_dets[i].
- In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs
...
...
@@ -42,10 +43,18 @@ namespace dlib
never produces any false alarms while a value of 0 means it only
produces false alarms.
- M(1) == the recall of the detector object. This is a number in the
range [0,1] which measure the fraction of targets found by the
range [0,1] which measure
s
the fraction of targets found by the
detector. A value of 1 means the detector found all the targets
in truth_dets while a value of 0 means the detector didn't locate
any of the targets.
- M(2) == the mean average precision of the detector object. This is a
number in the range [0,1] which measures the overall quality of the
detector when the detector is asked to output a ranked listing of all
possible detections. In particular, this is accomplished by setting the
detection threshold such that all possible detections are output. Then
the detections are ordered by their detection score and we use the
average_precision() routine to score each ranked listing, finally setting
M(2) to the mean value over all test images.
- The rule for deciding if a detector output, D, matches a truth rectangle,
T, is the following:
T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps
...
...
@@ -55,7 +64,7 @@ namespace dlib
typename
object_detector_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
test_object_detection_function
(
const
matrix
<
double
,
1
,
3
>
test_object_detection_function
(
object_detector_type
&
detector
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_dets
,
...
...
@@ -77,7 +86,7 @@ namespace dlib
typename
trainer_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
cross_validate_object_detection_trainer
(
const
matrix
<
double
,
1
,
3
>
cross_validate_object_detection_trainer
(
const
trainer_type
&
trainer
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
full_object_detection
>
>&
truth_dets
,
...
...
@@ -94,19 +103,19 @@ namespace dlib
and it must contain objects which can be accepted by detector().
- it is legal to call trainer.train(images, truth_dets)
ensures
- Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested
using the output of the trainer and a matrix summarizing the results is
returned. The matrix contains the precision
and recall of the trained
detectors and is defined identically to the test_object_detection_function()
routine defined at the top of this file.
- Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested
using the output of the trainer and a matrix summarizing the results is
returned. The matrix contains the precision
, recall, and mean average
precision of the trained detectors and is defined identically to the
test_object_detection_function()
routine defined at the top of this file.
!*/
template
<
typename
trainer_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
cross_validate_object_detection_trainer
(
const
matrix
<
double
,
1
,
3
>
cross_validate_object_detection_trainer
(
const
trainer_type
&
trainer
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_dets
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment