Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
b1909d5c
Commit
b1909d5c
authored
Nov 08, 2016
by
Davis King
Browse files
Options
Browse Files
Download
Plain Diff
merged
parents
08a89c80
28d76d01
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
96 additions
and
23 deletions
+96
-23
mex_wrapper.cpp
dlib/matlab/mex_wrapper.cpp
+37
-12
rls.h
dlib/svm/rls.h
+43
-10
rls_abstract.h
dlib/svm/rls_abstract.h
+16
-1
No files found.
dlib/matlab/mex_wrapper.cpp
View file @
b1909d5c
...
@@ -368,6 +368,31 @@ namespace mex_binding
...
@@ -368,6 +368,31 @@ namespace mex_binding
struct
is_column_major_matrix
<
matrix
<
T
,
num_rows
,
num_cols
,
mem_manager
,
column_major_layout
>
>
struct
is_column_major_matrix
<
matrix
<
T
,
num_rows
,
num_cols
,
mem_manager
,
column_major_layout
>
>
{
static
const
bool
value
=
true
;
};
{
static
const
bool
value
=
true
;
};
// -------------------------------------------------------
string
escape_percent
(
const
string
&
str
)
{
string
temp
;
for
(
auto
c
:
str
)
{
if
(
c
!=
'%'
)
{
temp
+=
c
;
}
else
{
temp
+=
c
;
temp
+=
c
;
}
}
return
temp
;
}
string
escape_percent
(
const
std
::
ostringstream
&
sout
)
{
return
escape_percent
(
sout
.
str
());
}
// -------------------------------------------------------
// -------------------------------------------------------
template
<
template
<
...
@@ -386,14 +411,14 @@ namespace mex_binding
...
@@ -386,14 +411,14 @@ namespace mex_binding
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
sout
<<
"Argument "
<<
arg_idx
+
1
<<
" expects a matrix with "
<<
matrix_type
::
NR
<<
" rows but got one with "
<<
src
.
nc
();
sout
<<
"Argument "
<<
arg_idx
+
1
<<
" expects a matrix with "
<<
matrix_type
::
NR
<<
" rows but got one with "
<<
src
.
nc
();
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
sout
.
str
(
).
c_str
());
escape_percent
(
sout
).
c_str
());
}
}
if
(
matrix_type
::
NC
!=
0
&&
matrix_type
::
NC
!=
src
.
nr
())
if
(
matrix_type
::
NC
!=
0
&&
matrix_type
::
NC
!=
src
.
nr
())
{
{
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
sout
<<
"Argument "
<<
arg_idx
+
1
<<
" expects a matrix with "
<<
matrix_type
::
NC
<<
" columns but got one with "
<<
src
.
nr
();
sout
<<
"Argument "
<<
arg_idx
+
1
<<
" expects a matrix with "
<<
matrix_type
::
NC
<<
" columns but got one with "
<<
src
.
nr
();
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
sout
.
str
(
).
c_str
());
escape_percent
(
sout
).
c_str
());
}
}
...
@@ -429,7 +454,7 @@ namespace mex_binding
...
@@ -429,7 +454,7 @@ namespace mex_binding
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
sout
<<
"mex_function has some bug in it related to processing input argument "
<<
arg_idx
+
1
;
sout
<<
"mex_function has some bug in it related to processing input argument "
<<
arg_idx
+
1
;
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
sout
.
str
(
).
c_str
());
escape_percent
(
sout
).
c_str
());
}
}
...
@@ -451,7 +476,7 @@ namespace mex_binding
...
@@ -451,7 +476,7 @@ namespace mex_binding
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
sout
<<
"Error, input argument "
<<
arg_idx
+
1
<<
" must be a non-negative number."
;
sout
<<
"Error, input argument "
<<
arg_idx
+
1
<<
" must be a non-negative number."
;
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
sout
.
str
(
).
c_str
());
escape_percent
(
sout
).
c_str
());
}
}
else
else
{
{
...
@@ -473,7 +498,7 @@ namespace mex_binding
...
@@ -473,7 +498,7 @@ namespace mex_binding
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
sout
<<
"mex_function has some bug in it related to processing input argument "
<<
arg_idx
+
1
;
sout
<<
"mex_function has some bug in it related to processing input argument "
<<
arg_idx
+
1
;
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
sout
.
str
(
).
c_str
());
escape_percent
(
sout
).
c_str
());
}
}
...
@@ -500,7 +525,7 @@ namespace mex_binding
...
@@ -500,7 +525,7 @@ namespace mex_binding
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
sout
<<
"mex_function has some bug in it related to processing input argument "
<<
arg_idx
+
1
;
sout
<<
"mex_function has some bug in it related to processing input argument "
<<
arg_idx
+
1
;
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
sout
.
str
(
).
c_str
());
escape_percent
(
sout
).
c_str
());
}
}
...
@@ -567,7 +592,7 @@ namespace mex_binding
...
@@ -567,7 +592,7 @@ namespace mex_binding
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
sout
<<
"mex_function has some bug in it related to processing input argument "
<<
arg_idx
+
1
;
sout
<<
"mex_function has some bug in it related to processing input argument "
<<
arg_idx
+
1
;
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
sout
.
str
(
).
c_str
());
escape_percent
(
sout
).
c_str
());
}
}
// -------------------------------------------------------
// -------------------------------------------------------
...
@@ -584,7 +609,7 @@ namespace mex_binding
...
@@ -584,7 +609,7 @@ namespace mex_binding
std
::
ostringstream
sout
;
std
::
ostringstream
sout
;
sout
<<
"mex_function has some bug in it related to processing input argument "
<<
arg_idx
+
1
;
sout
<<
"mex_function has some bug in it related to processing input argument "
<<
arg_idx
+
1
;
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
sout
.
str
(
).
c_str
());
escape_percent
(
sout
).
c_str
());
}
}
template
<
typename
MM
>
template
<
typename
MM
>
...
@@ -2913,7 +2938,7 @@ namespace mex_binding
...
@@ -2913,7 +2938,7 @@ namespace mex_binding
<<
" and "
<<
expected_nrhs
<<
" input arguments, got "
<<
nrhs
<<
"."
;
<<
" and "
<<
expected_nrhs
<<
" input arguments, got "
<<
nrhs
<<
"."
;
mexErrMsgIdAndTxt
(
"mex_function:nrhs"
,
mexErrMsgIdAndTxt
(
"mex_function:nrhs"
,
sout
.
str
(
).
c_str
());
escape_percent
(
sout
).
c_str
());
}
}
if
(
nlhs
>
expected_nlhs
)
if
(
nlhs
>
expected_nlhs
)
...
@@ -2922,7 +2947,7 @@ namespace mex_binding
...
@@ -2922,7 +2947,7 @@ namespace mex_binding
sout
<<
"Expected at most "
<<
expected_nlhs
<<
" output arguments, got "
<<
nlhs
<<
"."
;
sout
<<
"Expected at most "
<<
expected_nlhs
<<
" output arguments, got "
<<
nlhs
<<
"."
;
mexErrMsgIdAndTxt
(
"mex_function:nlhs"
,
mexErrMsgIdAndTxt
(
"mex_function:nlhs"
,
sout
.
str
(
).
c_str
());
escape_percent
(
sout
).
c_str
());
}
}
call_mex_function_helper
<
sig_traits
<
funct
>::
num_args
>
helper
;
call_mex_function_helper
<
sig_traits
<
funct
>::
num_args
>
helper
;
...
@@ -4988,7 +5013,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
...
@@ -4988,7 +5013,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
catch
(
mex_binding
::
invalid_args_exception
&
e
)
catch
(
mex_binding
::
invalid_args_exception
&
e
)
{
{
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
mexErrMsgIdAndTxt
(
"mex_function:validate_and_populate_arg"
,
(
"Input"
+
e
.
msg
).
c_str
());
mex_binding
::
escape_percent
(
"Input"
+
e
.
msg
).
c_str
());
}
}
catch
(
mex_binding
::
user_hit_ctrl_c
&
)
catch
(
mex_binding
::
user_hit_ctrl_c
&
)
{
{
...
@@ -4997,7 +5022,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
...
@@ -4997,7 +5022,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
catch
(
std
::
exception
&
e
)
catch
(
std
::
exception
&
e
)
{
{
mexErrMsgIdAndTxt
(
"mex_function:error"
,
mexErrMsgIdAndTxt
(
"mex_function:error"
,
e
.
what
());
mex_binding
::
escape_percent
(
e
.
what
()).
c_str
());
}
}
cout
<<
flush
;
cout
<<
flush
;
...
...
dlib/svm/rls.h
View file @
b1909d5c
...
@@ -20,7 +20,8 @@ namespace dlib
...
@@ -20,7 +20,8 @@ namespace dlib
explicit
rls
(
explicit
rls
(
double
forget_factor_
,
double
forget_factor_
,
double
C_
=
1000
double
C_
=
1000
,
bool
apply_forget_factor_to_C_
=
false
)
)
{
{
// make sure requires clause is not broken
// make sure requires clause is not broken
...
@@ -36,6 +37,7 @@ namespace dlib
...
@@ -36,6 +37,7 @@ namespace dlib
C
=
C_
;
C
=
C_
;
forget_factor
=
forget_factor_
;
forget_factor
=
forget_factor_
;
apply_forget_factor_to_C
=
apply_forget_factor_to_C_
;
}
}
rls
(
rls
(
...
@@ -43,6 +45,7 @@ namespace dlib
...
@@ -43,6 +45,7 @@ namespace dlib
{
{
C
=
1000
;
C
=
1000
;
forget_factor
=
1
;
forget_factor
=
1
;
apply_forget_factor_to_C
=
false
;
}
}
double
get_c
(
double
get_c
(
...
@@ -57,6 +60,12 @@ namespace dlib
...
@@ -57,6 +60,12 @@ namespace dlib
return
forget_factor
;
return
forget_factor
;
}
}
bool
should_apply_forget_factor_to_C
(
)
const
{
return
apply_forget_factor_to_C
;
}
template
<
typename
EXP
>
template
<
typename
EXP
>
void
train
(
void
train
(
const
matrix_exp
<
EXP
>&
x
,
const
matrix_exp
<
EXP
>&
x
,
...
@@ -84,20 +93,25 @@ namespace dlib
...
@@ -84,20 +93,25 @@ namespace dlib
// multiply by forget factor and incorporate x*trans(x) into R.
// multiply by forget factor and incorporate x*trans(x) into R.
const
double
l
=
1
.
0
/
forget_factor
;
const
double
l
=
1
.
0
/
forget_factor
;
const
double
temp
=
1
+
l
*
trans
(
x
)
*
R
*
x
;
const
double
temp
=
1
+
l
*
trans
(
x
)
*
R
*
x
;
matrix
<
double
,
0
,
1
>
tmp
=
R
*
x
;
tmp
=
R
*
x
;
R
=
l
*
R
-
l
*
l
*
(
tmp
*
trans
(
tmp
))
/
temp
;
R
=
l
*
R
-
l
*
l
*
(
tmp
*
trans
(
tmp
))
/
temp
;
// Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
// Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
// identity matrix back in to keep the regularization alive.
// identity matrix back in to keep the regularization alive.
add_eye_to_inv
(
R
,
(
1
-
forget_factor
)
/
C
);
if
(
forget_factor
!=
1
&&
!
apply_forget_factor_to_C
)
add_eye_to_inv
(
R
,
(
1
-
forget_factor
)
/
C
);
// R should always be symmetric. This line improves numeric stability of this algorithm.
// R should always be symmetric. This line improves numeric stability of this algorithm.
R
=
0
.
5
*
(
R
+
trans
(
R
));
if
(
cnt
%
100
==
0
)
R
=
0
.
5
*
(
R
+
trans
(
R
));
++
cnt
;
w
=
w
+
R
*
x
*
(
y
-
trans
(
x
)
*
w
);
w
=
w
+
R
*
x
*
(
y
-
trans
(
x
)
*
w
);
}
}
const
matrix
<
double
,
0
,
1
>&
get_w
(
const
matrix
<
double
,
0
,
1
>&
get_w
(
)
const
)
const
{
{
...
@@ -145,25 +159,37 @@ namespace dlib
...
@@ -145,25 +159,37 @@ namespace dlib
friend
inline
void
serialize
(
const
rls
&
item
,
std
::
ostream
&
out
)
friend
inline
void
serialize
(
const
rls
&
item
,
std
::
ostream
&
out
)
{
{
int
version
=
1
;
int
version
=
2
;
serialize
(
version
,
out
);
serialize
(
version
,
out
);
serialize
(
item
.
w
,
out
);
serialize
(
item
.
w
,
out
);
serialize
(
item
.
R
,
out
);
serialize
(
item
.
R
,
out
);
serialize
(
item
.
C
,
out
);
serialize
(
item
.
C
,
out
);
serialize
(
item
.
forget_factor
,
out
);
serialize
(
item
.
forget_factor
,
out
);
serialize
(
item
.
cnt
,
out
);
serialize
(
item
.
apply_forget_factor_to_C
,
out
);
}
}
friend
inline
void
deserialize
(
rls
&
item
,
std
::
istream
&
in
)
friend
inline
void
deserialize
(
rls
&
item
,
std
::
istream
&
in
)
{
{
int
version
=
0
;
int
version
=
0
;
deserialize
(
version
,
in
);
deserialize
(
version
,
in
);
if
(
version
!=
1
)
if
(
!
(
1
<=
version
&&
version
<=
2
)
)
throw
dlib
::
serialization_error
(
"Unknown version number found while deserializing rls object."
);
throw
dlib
::
serialization_error
(
"Unknown version number found while deserializing rls object."
);
deserialize
(
item
.
w
,
in
);
if
(
version
>=
1
)
deserialize
(
item
.
R
,
in
);
{
deserialize
(
item
.
C
,
in
);
deserialize
(
item
.
w
,
in
);
deserialize
(
item
.
forget_factor
,
in
);
deserialize
(
item
.
R
,
in
);
deserialize
(
item
.
C
,
in
);
deserialize
(
item
.
forget_factor
,
in
);
}
item
.
cnt
=
0
;
item
.
apply_forget_factor_to_C
=
false
;
if
(
version
>=
2
)
{
deserialize
(
item
.
cnt
,
in
);
deserialize
(
item
.
apply_forget_factor_to_C
,
in
);
}
}
}
private
:
private
:
...
@@ -189,6 +215,13 @@ namespace dlib
...
@@ -189,6 +215,13 @@ namespace dlib
matrix
<
double
>
R
;
matrix
<
double
>
R
;
double
C
;
double
C
;
double
forget_factor
;
double
forget_factor
;
int
cnt
=
0
;
bool
apply_forget_factor_to_C
;
// This object is here only to avoid reallocation during training. It don't
// logically contribute to the state of this object.
matrix
<
double
,
0
,
1
>
tmp
;
};
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
dlib/svm/rls_abstract.h
View file @
b1909d5c
...
@@ -37,7 +37,8 @@ namespace dlib
...
@@ -37,7 +37,8 @@ namespace dlib
explicit
rls
(
explicit
rls
(
double
forget_factor
,
double
forget_factor
,
double
C
=
1000
double
C
=
1000
,
bool
apply_forget_factor_to_C
=
false
);
);
/*!
/*!
requires
requires
...
@@ -47,6 +48,7 @@ namespace dlib
...
@@ -47,6 +48,7 @@ namespace dlib
- #get_w().size() == 0
- #get_w().size() == 0
- #get_c() == C
- #get_c() == C
- #get_forget_factor() == forget_factor
- #get_forget_factor() == forget_factor
- #should_apply_forget_factor_to_C() == apply_forget_factor_to_C
!*/
!*/
rls
(
rls
(
...
@@ -56,6 +58,7 @@ namespace dlib
...
@@ -56,6 +58,7 @@ namespace dlib
- #get_w().size() == 0
- #get_w().size() == 0
- #get_c() == 1000
- #get_c() == 1000
- #get_forget_factor() == 1
- #get_forget_factor() == 1
- #should_apply_forget_factor_to_C() == false
!*/
!*/
double
get_c
(
double
get_c
(
...
@@ -80,6 +83,18 @@ namespace dlib
...
@@ -80,6 +83,18 @@ namespace dlib
zero the faster old examples are forgotten.
zero the faster old examples are forgotten.
!*/
!*/
bool
should_apply_forget_factor_to_C
(
)
const
;
/*!
ensures
- If this function returns false then it means we are optimizing the
objective function discussed in the WHAT THIS OBJECT REPRESENTS section
above. However, if it returns true then we will allow the forget factor
(get_forget_factor()) to be applied to the C value which causes the
algorithm to slowly increase C and convert into a textbook version of RLS
without regularization. The main reason you might want to do this is
because it can make the algorithm run significantly faster.
!*/
template
<
typename
EXP
>
template
<
typename
EXP
>
void
train
(
void
train
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment