Commit 2a27b690 authored Jun 07, 2018 by Davis King
Added auto_train_rbf_classifier()

parent c14dca07
Showing 6 changed files with 166 additions and 0 deletions (+166 -0)
dlib/CMakeLists.txt        +1   -0
dlib/all/source.cpp        +1   -0
dlib/svm.h                 +1   -0
dlib/svm/auto.cpp          +102 -0
dlib/svm/auto.h            +25  -0
dlib/svm/auto_abstract.h   +36  -0
dlib/CMakeLists.txt

@@ -243,6 +243,7 @@ if (NOT TARGET dlib)
     global_optimization/global_function_search.cpp
     filtering/kalman_filter.cpp
     test_for_odr_violations.cpp
+    svm/auto.cpp
   )
dlib/all/source.cpp

@@ -90,6 +90,7 @@
 #include "../data_io/mnist.cpp"
 #include "../global_optimization/global_function_search.cpp"
 #include "../filtering/kalman_filter.cpp"
+#include "../svm/auto.cpp"
 #define DLIB_ALL_SOURCE_END
dlib/svm.h

@@ -54,6 +54,7 @@
 #include "svm/active_learning.h"
 #include "svm/svr_linear_trainer.h"
 #include "svm/sequence_segmenter.h"
+#include "svm/auto.h"
 #endif // DLIB_SVm_HEADER
dlib/svm/auto.cpp (new file, mode 100644)

// Copyright (C) 2018  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_AUTO_LEARnING_CPP_
#define DLIB_AUTO_LEARnING_CPP_

#include "auto.h"
#include "../global_optimization.h"
#include "svm_c_trainer.h"

#include <iostream>
#include <thread>

namespace dlib
{
    normalized_function<decision_function<radial_basis_kernel<matrix<double,0,1>>>> auto_train_rbf_classifier (
        std::vector<matrix<double,0,1>> x,
        std::vector<double> y,
        const std::chrono::nanoseconds max_runtime,
        bool be_verbose
    )
    {
        const auto num_positive_training_samples = sum(mat(y) > 0);
        const auto num_negative_training_samples = sum(mat(y) < 0);
        DLIB_CASSERT(num_positive_training_samples >= 6 && num_negative_training_samples >= 6,
            "You must provide at least 6 examples of each class to this training routine.");

        // make sure requires clause is not broken
        DLIB_CASSERT(is_binary_classification_problem(x,y) == true,
            "\t decision_function svm_c_trainer::train(x,y)"
            << "\n\t invalid inputs were given to this function"
            << "\n\t x.size(): " << x.size()
            << "\n\t y.size(): " << y.size()
            << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
        );

        randomize_samples(x,y);

        vector_normalizer<matrix<double,0,1>> normalizer;
        // let the normalizer learn the mean and standard deviation of the samples
        normalizer.train(x);
        for (auto& samp : x)
            samp = normalizer(samp);

        normalized_function<decision_function<radial_basis_kernel<matrix<double,0,1>>>> df;
        df.normalizer = normalizer;

        typedef radial_basis_kernel<matrix<double,0,1>> kernel_type;

        std::mutex m;
        auto cross_validation_score = [&](const double gamma, const double c1, const double c2)
        {
            svm_c_trainer<kernel_type> trainer;
            trainer.set_kernel(kernel_type(gamma));
            trainer.set_c_class1(c1);
            trainer.set_c_class2(c2);

            // Finally, perform 6-fold cross validation and then print and return the results.
            matrix<double> result = cross_validate_trainer(trainer, x, y, 6);
            if (be_verbose)
            {
                std::lock_guard<std::mutex> lock(m);
                std::cout << "gamma: " << std::setw(11) << gamma
                          << "  c1: " << std::setw(11) << c1
                          << "  c2: " << std::setw(11) << c2
                          << "  cross validation accuracy: " << result << std::flush;
            }

            // return the f1 score plus a penalty for picking large parameter settings
            // since those are, a priori less likely to generalize.
            return 2*prod(result)/sum(result) - std::max(c1,c2)/1e12 - gamma/1e8;
        };

        std::cout << "Searching for best RBF-SVM training parameters..." << std::endl;
        auto result = find_max_global(default_thread_pool(), cross_validation_score,
            {1e-5, 1e-5, 1e-5},  // lower bound constraints on gamma, c1, and c2, respectively
            {100,  1e6,  1e6},   // upper bound constraints on gamma, c1, and c2, respectively
            max_runtime);

        double best_gamma = result.x(0);
        double best_c1 = result.x(1);
        double best_c2 = result.x(2);

        std::cout << " best cross-validation score: " << result.y << std::endl;
        std::cout << " best gamma: " << best_gamma << "   best c1: " << best_c1 << "   best c2: " << best_c2 << std::endl;

        svm_c_trainer<kernel_type> trainer;
        trainer.set_kernel(kernel_type(best_gamma));
        trainer.set_c_class1(best_c1);
        trainer.set_c_class2(best_c2);

        std::cout << "Training final classifier with best parameters..." << std::endl;
        df.function = trainer.train(x,y);

        return df;
    }
}

#endif // DLIB_AUTO_LEARnING_CPP_
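For readers unfamiliar with the global optimizer driving the hyperparameter search above, here is a minimal sketch (not part of this commit) of the same find_max_global() call pattern used in auto.cpp, applied to a toy two-parameter objective. The toy function, its bounds, and the 2-second budget are made up for illustration; the bound vectors line up element-wise with the lambda's arguments, just like {gamma, c1, c2} above.

// Illustrative sketch only: same find_max_global() call pattern as auto.cpp.
#include <dlib/global_optimization.h>
#include <dlib/threads.h>
#include <iostream>

int main()
{
    using namespace dlib;

    // A smooth toy objective with its maximum at a == 0.3, b == 2.0.
    auto score = [](double a, double b)
    {
        return -(a-0.3)*(a-0.3) - (b-2.0)*(b-2.0);
    };

    auto result = find_max_global(default_thread_pool(), score,
        {0.0, 0.0},              // lower bounds on a and b
        {1.0, 5.0},              // upper bounds on a and b
        std::chrono::seconds(2)  // max_runtime, same role as in auto.cpp
    );

    // result.x holds the best parameters found, result.y the best score.
    std::cout << "best a: " << result.x(0)
              << "  best b: " << result.x(1)
              << "  best score: " << result.y << std::endl;
    return 0;
}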
dlib/svm/auto.h (new file, mode 100644)

// Copyright (C) 2018  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_AUTO_LEARnING_Hh_
#define DLIB_AUTO_LEARnING_Hh_

#include "auto_abstract.h"
#include "../algs.h"
#include "function.h"
#include "kernel.h"
#include <chrono>
#include <vector>

namespace dlib
{
    normalized_function<decision_function<radial_basis_kernel<matrix<double,0,1>>>> auto_train_rbf_classifier (
        std::vector<matrix<double,0,1>> x,
        std::vector<double> y,
        const std::chrono::nanoseconds max_runtime,
        bool be_verbose = true
    );
}

#endif // DLIB_AUTO_LEARnING_Hh_
dlib/svm/auto_abstract.h (new file, mode 100644)

// Copyright (C) 2018  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#undef DLIB_AUTO_LEARnING_ABSTRACT_Hh_
#ifdef DLIB_AUTO_LEARnING_ABSTRACT_Hh_

#include "kernel_abstract.h"
#include "function_abstract.h"
#include <chrono>
#include <vector>

namespace dlib
{
    normalized_function<decision_function<radial_basis_kernel<matrix<double,0,1>>>> auto_train_rbf_classifier (
        std::vector<matrix<double,0,1>> x,
        std::vector<double> y,
        const std::chrono::nanoseconds max_runtime,
        bool be_verbose = true
    );
    /*!
        requires
            - is_binary_classification_problem(x,y) == true
            - y contains at least 6 examples of each class.
        ensures
            - This routine trains a radial basis function SVM on the given binary
              classification training data.  It uses the svm_c_trainer to do this.  It
              also uses find_max_global() and 6-fold cross-validation to automatically
              determine the best settings of the SVM's hyper parameters.
            - The hyperparameter search will run for about max_runtime and will print
              messages to the screen as it runs if be_verbose==true.
    !*/
}

#endif // DLIB_AUTO_LEARnING_ABSTRACT_Hh_
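To illustrate the requires/ensures contract above, here is a hedged usage sketch (not part of this commit). It assumes dlib is built with the files added here so that #include <dlib/svm.h> exposes auto_train_rbf_classifier(); the toy data, the 30-second budget, and the output file name are made up for illustration.

// Illustrative usage sketch only; assumes this commit is part of the dlib build.
#include <dlib/svm.h>
#include <cmath>
#include <iostream>
#include <vector>

int main()
{
    using namespace dlib;
    typedef matrix<double,0,1> sample_type;

    // Toy 2-D binary problem: label +1 inside a disc of radius 5, -1 outside.
    std::vector<sample_type> x;
    std::vector<double> y;
    for (int r = -10; r <= 10; ++r)
    {
        for (int c = -10; c <= 10; ++c)
        {
            sample_type samp(2);
            samp(0) = r;
            samp(1) = c;
            x.push_back(samp);
            y.push_back(std::sqrt(double(r*r + c*c)) <= 5 ? +1.0 : -1.0);
        }
    }

    // Search for gamma, c1, and c2 for up to 30 seconds, printing progress.
    auto df = auto_train_rbf_classifier(x, y, std::chrono::seconds(30), true);

    // df is a normalized_function wrapping a decision_function, so it can be
    // evaluated on new samples and serialized like other dlib models.
    sample_type samp(2);
    samp(0) = 1;
    samp(1) = 2;
    std::cout << "decision value (sign gives predicted class): " << df(samp) << std::endl;
    serialize("rbf_classifier.dat") << df;
    return 0;
}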