Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
3eb0d973
Commit
3eb0d973
authored
Sep 15, 2011
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added the cross_validate_object_detection_trainer() and test_object_detection_function()
routines.
parent
0aa89e07
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
332 additions
and
0 deletions
+332
-0
svm.h
dlib/svm.h
+1
-0
cross_validate_object_detection_trainer.h
dlib/svm/cross_validate_object_detection_trainer.h
+242
-0
cross_validate_object_detection_trainer_abstract.h
dlib/svm/cross_validate_object_detection_trainer_abstract.h
+89
-0
No files found.
dlib/svm.h
View file @
3eb0d973
...
...
@@ -33,6 +33,7 @@
#include "svm/multiclass_tools.h"
#include "svm/cross_validate_multiclass_trainer.h"
#include "svm/cross_validate_regression_trainer.h"
#include "svm/cross_validate_object_detection_trainer.h"
#include "svm/one_vs_all_decision_function.h"
#include "svm/one_vs_all_trainer.h"
...
...
dlib/svm/cross_validate_object_detection_trainer.h
0 → 100644
View file @
3eb0d973
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_H__
#define DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_H__
#include "cross_validate_object_detection_trainer_abstract.h"
#include <vector>
#include "../matrix.h"
#include "svm.h"
#include "../geometry.h"
namespace
dlib
{
// ----------------------------------------------------------------------------------------
namespace
impl
{
unsigned
long
number_of_truth_hits
(
const
std
::
vector
<
rectangle
>&
truth_boxes
,
const
std
::
vector
<
rectangle
>&
boxes
,
const
double
overlap_eps
)
/*!
requires
- 0 < overlap_eps <= 1
ensures
- returns the number of elements in truth_boxes which are overlapped by an
element of boxes. In this context, two boxes, A and B, overlap if and only if
the following quantity is greater than overlap_eps:
A.intersect(B).area()/(A+B).area()
- No element of boxes is allowed to account for more than one element of truth_boxes.
- The returned number is in the range [0,truth_boxes.size()]
!*/
{
if
(
boxes
.
size
()
==
0
)
return
0
;
unsigned
long
count
=
0
;
std
::
vector
<
bool
>
used
(
boxes
.
size
(),
false
);
for
(
unsigned
long
i
=
0
;
i
<
truth_boxes
.
size
();
++
i
)
{
unsigned
long
best_idx
=
0
;
double
best_overlap
=
0
;
for
(
unsigned
long
j
=
0
;
j
<
boxes
.
size
();
++
j
)
{
if
(
used
[
j
])
continue
;
const
double
overlap
=
truth_boxes
[
i
].
intersect
(
boxes
[
j
]).
area
()
/
(
double
)(
truth_boxes
[
i
]
+
boxes
[
j
]).
area
();
if
(
overlap
>
best_overlap
)
{
best_overlap
=
overlap
;
best_idx
=
j
;
}
}
if
(
best_overlap
>
overlap_eps
&&
used
[
best_idx
]
==
false
)
{
used
[
best_idx
]
=
true
;
++
count
;
}
}
return
count
;
}
}
// ----------------------------------------------------------------------------------------
template
<
typename
object_detector_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
test_object_detection_function
(
const
object_detector_type
&
detector
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_rects
,
const
double
overlap_eps
=
0
.
5
)
{
// make sure requires clause is not broken
DLIB_ASSERT
(
is_learning_problem
(
images
,
truth_rects
)
==
true
&&
0
<
overlap_eps
&&
overlap_eps
<=
1
,
"
\t
matrix test_object_detection_function()"
<<
"
\n\t
invalid inputs were given to this function"
<<
"
\n\t
is_learning_problem(images,truth_rects): "
<<
is_learning_problem
(
images
,
truth_rects
)
<<
"
\n\t
overlap_eps: "
<<
overlap_eps
);
double
correct_hits
=
0
;
double
total_hits
=
0
;
double
total_true_targets
=
0
;
for
(
unsigned
long
i
=
0
;
i
<
images
.
size
();
++
i
)
{
const
std
::
vector
<
rectangle
>&
hits
=
detector
(
images
[
i
]);
total_hits
+=
hits
.
size
();
correct_hits
+=
impl
::
number_of_truth_hits
(
truth_rects
[
i
],
hits
,
overlap_eps
);
total_true_targets
+=
truth_rects
[
i
].
size
();
}
double
precision
,
recall
;
if
(
total_hits
==
0
)
precision
=
1
;
else
precision
=
correct_hits
/
total_hits
;
if
(
total_true_targets
==
0
)
recall
=
1
;
else
recall
=
correct_hits
/
total_true_targets
;
matrix
<
double
,
1
,
2
>
res
;
res
=
precision
,
recall
;
return
res
;
}
// ----------------------------------------------------------------------------------------
namespace
impl
{
template
<
typename
array_type
>
struct
array_subset_helper
{
array_subset_helper
(
const
array_type
&
array_
,
const
std
::
vector
<
unsigned
long
>&
idx_set_
)
:
array
(
array_
),
idx_set
(
idx_set_
)
{
}
unsigned
long
size
()
const
{
return
idx_set
.
size
();
}
typedef
typename
array_type
::
type
type
;
const
type
&
operator
[]
(
unsigned
long
idx
)
const
{
return
array
[
idx_set
[
idx
]];
}
private
:
const
array_type
&
array
;
const
std
::
vector
<
unsigned
long
>&
idx_set
;
};
}
// ----------------------------------------------------------------------------------------
template
<
typename
trainer_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
cross_validate_object_detection_trainer
(
const
trainer_type
&
trainer
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_rects
,
const
long
folds
,
const
double
overlap_eps
=
0
.
5
)
{
// make sure requires clause is not broken
DLIB_ASSERT
(
is_learning_problem
(
images
,
truth_rects
)
==
true
&&
0
<
overlap_eps
&&
overlap_eps
<=
1
&&
1
<
folds
&&
folds
<=
images
.
size
(),
"
\t
matrix cross_validate_object_detection_trainer()"
<<
"
\n\t
invalid inputs were given to this function"
<<
"
\n\t
is_learning_problem(images,truth_rects): "
<<
is_learning_problem
(
images
,
truth_rects
)
<<
"
\n\t
overlap_eps: "
<<
overlap_eps
<<
"
\n\t
folds: "
<<
folds
);
double
correct_hits
=
0
;
double
total_hits
=
0
;
double
total_true_targets
=
0
;
const
long
test_size
=
images
.
size
()
/
folds
;
unsigned
long
test_idx
=
0
;
for
(
long
iter
=
0
;
iter
<
folds
;
++
iter
)
{
std
::
vector
<
unsigned
long
>
train_idx_set
;
std
::
vector
<
unsigned
long
>
test_idx_set
;
for
(
unsigned
long
i
=
0
;
i
<
test_size
;
++
i
)
test_idx_set
.
push_back
(
test_idx
++
);
unsigned
long
train_idx
=
test_idx
%
images
.
size
();
std
::
vector
<
std
::
vector
<
rectangle
>
>
training_rects
;
for
(
unsigned
long
i
=
0
;
i
<
images
.
size
()
-
test_size
;
++
i
)
{
training_rects
.
push_back
(
truth_rects
[
train_idx
]);
train_idx_set
.
push_back
(
train_idx
);
train_idx
=
(
train_idx
+
1
)
%
images
.
size
();
}
impl
::
array_subset_helper
<
image_array_type
>
array_subset
(
images
,
train_idx_set
);
const
typename
trainer_type
::
trained_function_type
&
detector
=
trainer
.
train
(
array_subset
,
training_rects
);
for
(
unsigned
long
i
=
0
;
i
<
test_idx_set
.
size
();
++
i
)
{
const
std
::
vector
<
rectangle
>&
hits
=
detector
(
images
[
test_idx_set
[
i
]]);
total_hits
+=
hits
.
size
();
correct_hits
+=
impl
::
number_of_truth_hits
(
truth_rects
[
test_idx_set
[
i
]],
hits
,
overlap_eps
);
total_true_targets
+=
truth_rects
[
test_idx_set
[
i
]].
size
();
}
}
double
precision
,
recall
;
if
(
total_hits
==
0
)
precision
=
1
;
else
precision
=
correct_hits
/
total_hits
;
if
(
total_true_targets
==
0
)
recall
=
1
;
else
recall
=
correct_hits
/
total_true_targets
;
matrix
<
double
,
1
,
2
>
res
;
res
=
precision
,
recall
;
return
res
;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_H__
dlib/svm/cross_validate_object_detection_trainer_abstract.h
0 → 100644
View file @
3eb0d973
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_H__
#ifdef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_H__
#include <vector>
#include "../matrix.h"
#include "../geometry.h"
namespace
dlib
{
// ----------------------------------------------------------------------------------------
template
<
typename
object_detector_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
test_object_detection_function
(
const
object_detector_type
&
detector
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_rects
,
const
double
overlap_eps
=
0
.
5
);
/*!
requires
- is_learning_problem(images,truth_rects)
- 0 < overlap_eps <= 1
- object_detector_type == some kind of object detector function object
(e.g. object_detector)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector().
ensures
- Tests the given detector against the supplied object detection problem
and returns the precision and recall. Note that the task is to predict,
for each images[i], the set of object locations given by truth_rects[i].
- In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs
which correspond to a real target. A value of 1 means the detector
never produces any false alarms while a value of 0 means it only
produces false alarms.
- M(1) == the recall of the detector object. This is a number in the
range [0,1] which measure the fraction of targets found by the
detector. A value of 1 means the detector found all the targets
in truth_rects while a value of 0 means the detector didn't locate
any of the targets.
- The rule for deciding if a detector output, D, matches a truth rectangle,
T, is the following:
T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps
!*/
// ----------------------------------------------------------------------------------------
template
<
typename
trainer_type
,
typename
image_array_type
>
const
matrix
<
double
,
1
,
2
>
cross_validate_object_detection_trainer
(
const
trainer_type
&
trainer
,
const
image_array_type
&
images
,
const
std
::
vector
<
std
::
vector
<
rectangle
>
>&
truth_rects
,
const
long
folds
,
const
double
overlap_eps
=
0
.
5
);
/*!
requires
- is_learning_problem(images,truth_rects)
- 0 < overlap_eps <= 1
- 1 < folds <= images.size()
- trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector().
ensures
- Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested
using the output of the trainer and a matrix summarizing the results is
returned. The matrix contains the precision and recall of the trained
detectors and is defined identically to the test_object_detection_function()
routine defined at the top of this file.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_H__
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment