Commit 03ec260c authored Jan 18, 2013 by Davis King
reformatted comments.
parent 91e8594b
Showing 1 changed file with 66 additions and 63 deletions
examples/svm_ex.cpp  +66 -63
@@ -27,19 +27,19 @@ using namespace dlib;

int main()
{
    // The svm functions use column vectors to contain a lot of the data on which they
    // operate. So the first thing we do here is declare a convenient typedef.

    // This typedef declares a matrix with 2 rows and 1 column. It will be the object that
    // contains each of our 2 dimensional samples. (Note that if you wanted more than 2
    // features in this vector you can simply change the 2 to something else. Or if you
    // don't know how many features you want until runtime then you can put a 0 here and
    // use the matrix.set_size() member function)
    typedef matrix<double, 2, 1> sample_type;

    // This is a typedef for the type of kernel we are going to use in this example. In
    // this case I have selected the radial basis kernel that can operate on our 2D
    // sample_type objects
    typedef radial_basis_kernel<sample_type> kernel_type;
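As a side note on the comment above: if the number of features is only known at runtime, a dynamically sized column vector works the same way. A minimal sketch, not part of this diff, with illustrative names:

    // Hypothetical example: a column vector whose size is chosen at runtime.
    typedef matrix<double, 0, 1> dyn_sample_type;
    dyn_sample_type samp;
    samp.set_size(2);   // samp now behaves like the 2x1 sample_type above
    samp(0) = 3.0;
    samp(1) = -1.5;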
@@ -47,9 +47,9 @@ int main()

    std::vector<sample_type> samples;
    std::vector<double> labels;

    // Now let's put some data into our samples and labels objects. We do this by looping
    // over a bunch of points and labeling them according to their distance from the
    // origin.
    for (int r = -20; r <= 20; ++r)
    {
        for (int c = -20; c <= 20; ++c)
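The body of this double loop is collapsed in the diff view. Going by the comment, each (r, c) point becomes a sample whose label depends on its distance from the origin; a plausible sketch, where the radius threshold of 10 is an assumption rather than something shown in this diff:

        {
            sample_type samp;
            samp(0) = r;
            samp(1) = c;
            samples.push_back(samp);
            // Assumed rule: points within radius 10 of the origin get +1, the rest -1.
            labels.push_back(std::sqrt((double)(r*r + c*c)) <= 10 ? +1.0 : -1.0);
        }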
@@ -69,11 +69,11 @@ int main()

    }

    // Here we normalize all the samples by subtracting their mean and dividing by their
    // standard deviation. This is generally a good idea since it often heads off
    // numerical stability problems and also prevents one large feature from smothering
    // others. Doing this doesn't matter much in this example, so I'm just doing it here
    // so you can see an easy way to accomplish it with the library.
    vector_normalizer<sample_type> normalizer;
    // let the normalizer learn the mean and standard deviation of the samples
    normalizer.train(samples);
@@ -82,19 +82,20 @@ int main()

        samples[i] = normalizer(samples[i]);

    // Now that we have some data we want to train on it. However, there are two
    // parameters to the training. These are the nu and gamma parameters. Our choice for
    // these parameters will influence how good the resulting decision function is. To
    // test how good a particular choice of these parameters is we can use the
    // cross_validate_trainer() function to perform n-fold cross validation on our training
    // data. However, there is a problem with the way we have sampled our distribution
    // above. The problem is that there is a definite ordering to the samples. That is,
    // the first half of the samples look like they are from a different distribution than
    // the second half. This would screw up the cross validation process but we can fix it
    // by randomizing the order of the samples with the following function call.
    randomize_samples(samples, labels);

    // The nu parameter has a maximum value that is dependent on the ratio of the +1 to -1
    // labels in the training data. This function finds that value.
    const double max_nu = maximum_nu(labels);
@@ -102,8 +103,8 @@ int main()

    svm_nu_trainer<kernel_type> trainer;

    // Now we loop over some different nu and gamma values to see how good they are. Note
    // that this is a very simple way to try out a few possible parameter choices. You
    // should look at the model_selection_ex.cpp program for examples of more sophisticated
    // strategies for determining good parameter choices.
    cout << "doing cross validation" << endl;
    for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
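The matching inner loop over nu is also collapsed. Since max_nu was computed above, it plausibly sweeps nu geometrically up to that bound; a sketch under that assumption:

        for (double nu = 0.00001; nu < max_nu; nu *= 5)
        {
            // Assumed: the kernel is rebuilt with the current gamma before each run.
            trainer.set_kernel(kernel_type(gamma));
            // ... followed by the set_nu() and cross validation calls in the next hunk
        }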
@@ -115,29 +116,31 @@ int main()

            trainer.set_nu(nu);
            cout << "gamma: " << gamma << "    nu: " << nu;

            // Print out the cross validation accuracy for 3-fold cross validation using
            // the current gamma and nu. cross_validate_trainer() returns a row vector.
            // The first element of the vector is the fraction of +1 training examples
            // correctly classified and the second number is the fraction of -1 training
            // examples correctly classified.
            cout << "     cross validation accuracy: "
                 << cross_validate_trainer(trainer, samples, labels, 3);
        }
    }

    // From looking at the output of the above loop it turns out that a good value for nu
    // and gamma for this problem is 0.15625 for both. So that is what we will use.

    // Now we train on the full set of data and obtain the resulting decision function. We
    // use the value of 0.15625 for nu and gamma. The decision function will return values
    // >= 0 for samples it predicts are in the +1 class and numbers < 0 for samples it
    // predicts to be in the -1 class.
    trainer.set_kernel(kernel_type(0.15625));
    trainer.set_nu(0.15625);
    typedef decision_function<kernel_type> dec_funct_type;
    typedef normalized_function<dec_funct_type> funct_type;

    // Here we are making an instance of the normalized_function object. This object
    // provides a convenient way to store the vector normalization information along with
    // the decision function we are going to learn.
    funct_type learned_function;
    learned_function.normalizer = normalizer;  // save normalization information
    learned_function.function = trainer.train(samples, labels);  // perform the actual SVM training and save the results
@@ -166,8 +169,8 @@ int main()

    cout << "This sample should be < 0 and it is classified as a " << learned_function(sample) << endl;

    // We can also train a decision function that reports a well conditioned probability
    // instead of just a number > 0 for the +1 class and < 0 for the -1 class. An example
    // of doing that follows:
    typedef probabilistic_decision_function<kernel_type> probabilistic_funct_type;
    typedef normalized_function<probabilistic_funct_type> pfunct_type;
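The code that builds the learned_pfunct object used below is collapsed in this diff. Assuming it mirrors the learned_function setup above and uses dlib's train_probabilistic_decision_function(), it would look roughly like:

    pfunct_type learned_pfunct;
    learned_pfunct.normalizer = normalizer;
    // The trailing 3 requests 3-fold cross validation while calibrating the probability output.
    learned_pfunct.function = train_probabilistic_decision_function(trainer, samples, labels, 3);
    // learned_pfunct(sample) now returns the probability that sample belongs to the +1 class.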
@@ -200,8 +203,9 @@ int main()

    // Another thing that is worth knowing is that just about everything in dlib is
    // serializable. So for example, you can save the learned_pfunct object to disk and
    // recall it later like so:
    ofstream fout("saved_function.dat", ios::binary);
    serialize(learned_pfunct, fout);
    fout.close();
@@ -210,27 +214,27 @@ int main()

    ifstream fin("saved_function.dat", ios::binary);
    deserialize(learned_pfunct, fin);

    // Note that there is also an example program that comes with dlib called the
    // file_to_code_ex.cpp example. It is a simple program that takes a file and outputs a
    // piece of C++ code that is able to fully reproduce the file's contents in the form of
    // a std::string object. So you can use that along with std::istringstream to save
    // learned decision functions inside your actual C++ code files if you want.
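To make that concrete: once file_to_code_ex.cpp has produced a std::string holding the file's bytes, you can deserialize straight from it. A sketch, where saved_function_data is a hypothetical name for that string (requires #include <sstream>):

    std::istringstream sin(saved_function_data);  // saved_function_data: hypothetical std::string
    deserialize(learned_pfunct, sin);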
    // Lastly, note that the decision functions we trained above involved well over 200
    // basis vectors. Support vector machines in general tend to find decision functions
    // that involve a lot of basis vectors. This is significant because the more basis
    // vectors in a decision function, the longer it takes to classify new examples. So
    // dlib provides the ability to find an approximation to the normal output of a trainer
    // using fewer basis vectors.

    // Here we determine the cross validation accuracy when we approximate the output using
    // only 10 basis vectors. To do this we use the reduced2() function. It takes a
    // trainer object and the number of basis vectors to use and returns a new trainer
    // object that applies the necessary post processing during the creation of decision
    // function objects.
    cout << "\ncross validation accuracy with only 10 support vectors: "
         << cross_validate_trainer(reduced2(trainer, 10), samples, labels, 3);
@@ -238,9 +242,8 @@ int main()

    cout << "cross validation accuracy with all the original support vectors: "
         << cross_validate_trainer(trainer, samples, labels, 3);

    // When you run this program you should see that, for this problem, you can reduce the
    // number of basis vectors down to 10 without hurting the cross validation accuracy.

    // To get the reduced decision function out we would just do this:
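The line that follows is collapsed in the diff; given the comment, it plausibly reuses the reduced2() trainer from above, along the lines of:

    // Sketch: train with the reduced trainer and keep the smaller decision function.
    learned_function.function = reduced2(trainer, 10).train(samples, labels);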