Commit 1aa66674 authored Nov 25, 2017 by Davis King
Switched this example to use the svm C instead of nu trainer.
parent 0e7e4330
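(Aside, not part of the commit: dlib's svm_nu_trainer uses the nu-SVM formulation, where nu must stay strictly below a data-dependent limit given by maximum_nu(), while svm_c_trainer uses the classic C-SVM formulation, where any positive penalty works and each class can get its own. A minimal sketch of the two setups this commit swaps; the sample_type, labels, and parameter values here are illustrative assumptions, not code from the example.)

#include <dlib/svm.h>
#include <vector>
using namespace dlib;

typedef matrix<double, 2, 1> sample_type;              // assumed toy sample type
typedef radial_basis_kernel<sample_type> kernel_type;

// Sketch: the old and new trainer setups side by side.
// `labels` must contain both +1 and -1 entries.
void contrast_trainers(const std::vector<double>& labels)
{
    const double gamma = 0.1, c1 = 5, c2 = 5;    // illustrative values

    svm_nu_trainer<kernel_type> nu_trainer;      // old trainer: nu is bounded above
    nu_trainer.set_kernel(kernel_type(gamma));
    nu_trainer.set_nu(0.999*maximum_nu(labels)); // limit depends on the label balance

    svm_c_trainer<kernel_type> c_trainer;        // new trainer: any C > 0 is valid
    c_trainer.set_kernel(kernel_type(gamma));
    c_trainer.set_c_class1(c1);                  // separate penalty per class,
    c_trainer.set_c_class2(c2);                  // handy for unbalanced data
}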
Showing 1 changed file with 25 additions and 31 deletions.
examples/model_selection_ex.cpp @ 1aa66674
...
...
@@ -78,36 +78,37 @@ int main() try
     // Now that we have some data we want to train on it.  We are going to train a
-    // binary SVM with the RBF kernel to classify the data.  However, there are two
-    // parameters to the training.  These are the nu and gamma parameters.  Our choice
-    // for these parameters will influence how good the resulting decision function is.
-    // To test how good a particular choice of these parameters is we can use the
+    // binary SVM with the RBF kernel to classify the data.  However, there are
+    // three parameters to the training.  These are the SVM C parameters for each
+    // class and the RBF kernel's gamma parameter.  Our choice for these
+    // parameters will influence how good the resulting decision function is.  To
+    // test how good a particular choice of these parameters is we can use the
     // cross_validate_trainer() function to perform n-fold cross validation on our
-    // training data.  However, there is a problem with the way we have sampled our
-    // distribution above.  The problem is that there is a definite ordering to the
-    // samples.  That is, the first half of the samples look like they are from a
-    // different distribution than the second half.  This would screw up the cross
-    // validation process, but we can fix it by randomizing the order of the samples
-    // with the following function call.
+    // training data.  However, there is a problem with the way we have sampled
+    // our distribution above.  The problem is that there is a definite ordering
+    // to the samples.  That is, the first half of the samples look like they are
+    // from a different distribution than the second half.  This would screw up
+    // the cross validation process, but we can fix it by randomizing the order of
+    // the samples with the following function call.
     randomize_samples(samples, labels);
 
     // And now we get to the important bit.  Here we define a function,
     // cross_validation_score(), that will do the cross-validation we
     // mentioned and return a number indicating how good a particular setting
-    // of gamma and nu is.
-    auto cross_validation_score = [&](const double gamma, const double nu)
+    // of gamma, c1, and c2 is.
+    auto cross_validation_score = [&](const double gamma, const double c1, const double c2)
     {
         // Make a RBF SVM trainer and tell it what the parameters are supposed to be.
         typedef radial_basis_kernel<sample_type> kernel_type;
-        svm_nu_trainer<kernel_type> trainer;
+        svm_c_trainer<kernel_type> trainer;
         trainer.set_kernel(kernel_type(gamma));
-        trainer.set_nu(nu);
+        trainer.set_c_class1(c1);
+        trainer.set_c_class2(c2);
 
         // Finally, perform 10-fold cross validation and then print and return the results.
         matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);
-        cout << "gamma: " << setw(11) << gamma << "  nu: " << setw(11) << nu << "  cross validation accuracy: " << result;
+        cout << "gamma: " << setw(11) << gamma << "  c1: " << setw(11) << c1 << "  c2: " << setw(11) << c2 << "  cross validation accuracy: " << result;
 
         // Now return a number indicating how good the parameters are.  Bigger is
         // better in this example.  Here I'm returning the harmonic mean between the
...
...
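(Aside, not part of the diff: written out on its own, the new scoring function boils down to roughly the following sketch. The sample_type and the free-function form are assumptions for illustration; in the example it is a lambda capturing samples and labels.)

#include <dlib/svm.h>
#include <vector>
using namespace dlib;

typedef matrix<double, 2, 1> sample_type;              // assumed toy sample type
typedef radial_basis_kernel<sample_type> kernel_type;

// Roughly the new cross_validation_score(), as a free function.
// samples/labels would come from the example's data-generation code.
double cross_validation_score(const std::vector<sample_type>& samples,
                              const std::vector<double>& labels,
                              double gamma, double c1, double c2)
{
    svm_c_trainer<kernel_type> trainer;
    trainer.set_kernel(kernel_type(gamma));
    trainer.set_c_class1(c1);   // penalty for margin violations on +1 samples
    trainer.set_c_class2(c2);   // penalty for margin violations on -1 samples

    // result is a 1x2 matrix: the fraction of +1 and of -1 samples
    // classified correctly under 10-fold cross validation.
    matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);

    // Harmonic mean of the two per-class accuracies, as in the example.
    return 2*prod(result)/sum(result);
}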
@@ -119,33 +120,26 @@ int main() try
         return 2*prod(result)/sum(result);
     };
 
-    // The nu parameter has a maximum value that is dependent on the ratio of the +1 to -1
-    // labels in the training data.  This function finds that value.  The 0.999 is here
-    // because the maximum allowable nu is strictly less than the value returned by
-    // maximum_nu().  So shrinking the limit a little will prevent us from hitting it.
-    const double max_nu = 0.999*maximum_nu(labels);
-
     // And finally, we call this global optimizer that will search for the best parameters.
-    // It will call cross_validation_score() 50 times with different settings and return
+    // It will call cross_validation_score() 30 times with different settings and return
     // the best parameter setting it finds.  find_max_global() uses a global optimization
     // method based on a combination of non-parametric global function modeling and
     // quadratic trust region modeling to efficiently find a global maximizer.  It usually
     // does a good job with a relatively small number of calls to cross_validation_score().
     // In this example, you should observe that it finds settings that give perfect binary
-    // classification on the data.
+    // classification of the data.
     auto result = find_max_global(cross_validation_score, 
-                                  {1e-5, 1e-5},  // lower bound constraints on gamma and nu, respectively
-                                  {100, max_nu}, // upper bound constraints on gamma and nu, respectively
-                                  max_function_calls(50));
+                                  {1e-5, 1e-5, 1e-5}, // lower bound constraints on gamma, c1, and c2, respectively
+                                  {100, 1e6, 1e6},    // upper bound constraints on gamma, c1, and c2, respectively
+                                  max_function_calls(30));
 
     double best_gamma = result.x(0);
-    double best_nu = result.x(1);
+    double best_c1 = result.x(1);
+    double best_c2 = result.x(2);
 
     cout << " best cross-validation score: " << result.y << endl;
-    cout << " best gamma: " << best_gamma << "   best nu: " << best_nu << endl;
+    cout << " best gamma: " << best_gamma << "   best c1: " << best_c1 << "   best c2: " << best_c2 << endl;
 }
 catch (exception& e)
 {
...
...
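(Aside, not part of the diff: to see find_max_global()'s calling convention in isolation, here is a minimal sketch on a toy objective. The objective function, its bounds, and the maximizer location are made up for illustration; only the API pattern mirrors the example.)

#include <dlib/global_optimization.h>
#include <iostream>
using namespace dlib;

int main()
{
    // Toy objective: a smooth function whose maximum sits at (2, -1).
    auto toy_score = [](double x, double y)
    {
        return -(x-2)*(x-2) - (y+1)*(y+1);
    };

    // Bounds are given per argument, in declaration order, just like the
    // gamma/c1/c2 bounds above; 30 matches the commit's max_function_calls(30).
    auto result = find_max_global(toy_score,
                                  {-10, -10},   // lower bounds on x and y
                                  { 10,  10},   // upper bounds on x and y
                                  max_function_calls(30));

    std::cout << "best x: " << result.x(0) << "\n";
    std::cout << "best y: " << result.x(1) << "\n";
    std::cout << "best score: " << result.y << "\n";
}

find_max_global() returns a function_evaluation whose x holds the best arguments in declaration order and whose y is the best score found, which is why the example reads gamma, c1, and c2 out of result.x(0), result.x(1), and result.x(2).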