Commit 28d76d01
authored Nov 08, 2016 by Davis King

Made rls run a bit faster, especially if the new mode that allows the
regularization to decay away is activated.

parent c8c1abb7
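Concretely, the "new mode" mentioned in the commit message is the third constructor argument introduced below; a one-line sketch with example parameter values (full usage examples follow the diffs):

    dlib::rls filter(0.9, 1000, true);  // forget_factor, C, apply_forget_factor_to_C = true lets the regularization decay away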
Showing 2 changed files with 59 additions and 11 deletions (+59 -11)

    dlib/svm/rls.h           +43 -10
    dlib/svm/rls_abstract.h  +16 -1

dlib/svm/rls.h
@@ -20,7 +20,8 @@ namespace dlib
         explicit rls (
             double forget_factor_,
-            double C_ = 1000
+            double C_ = 1000,
+            bool apply_forget_factor_to_C_ = false
         )
         {
             // make sure requires clause is not broken
@@ -36,6 +37,7 @@ namespace dlib
             C = C_;
             forget_factor = forget_factor_;
+            apply_forget_factor_to_C = apply_forget_factor_to_C_;
         }

         rls (
@@ -43,6 +45,7 @@ namespace dlib
         {
             C = 1000;
             forget_factor = 1;
+            apply_forget_factor_to_C = false;
         }

         double get_c (
@@ -57,6 +60,12 @@ namespace dlib
             return forget_factor;
         }

+        bool should_apply_forget_factor_to_C (
+        ) const
+        {
+            return apply_forget_factor_to_C;
+        }
+
         template <typename EXP>
         void train (
             const matrix_exp<EXP>& x,
@@ -84,20 +93,25 @@ namespace dlib
             // multiply by forget factor and incorporate x*trans(x) into R.
             const double l = 1.0/forget_factor;
             const double temp = 1 + l*trans(x)*R*x;
-            matrix<double,0,1> tmp = R*x;
+            tmp = R*x;
             R = l*R - l*l*(tmp*trans(tmp))/temp;

             // Since we multiplied by the forget factor, we need to add (1-forget_factor) of the
             // identity matrix back in to keep the regularization alive.
-            add_eye_to_inv(R, (1-forget_factor)/C);
+            if (forget_factor != 1 && !apply_forget_factor_to_C)
+                add_eye_to_inv(R, (1-forget_factor)/C);

             // R should always be symmetric.  This line improves numeric stability of this algorithm.
-            R = 0.5*(R + trans(R));
+            if (cnt%100 == 0)
+                R = 0.5*(R + trans(R));
+            ++cnt;

             w = w + R*x*(y - trans(x)*w);
         }

         const matrix<double,0,1>& get_w (
         ) const
         {
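For reference, the update that this train() hunk implements is the standard recursive least squares recursion with forgetting factor lambda = get_forget_factor() (the code works with l = 1/lambda). The following is my own restatement of the code above in LaTeX, not text taken from the commit:

    R \leftarrow \frac{1}{\lambda} R
        \;-\; \frac{1}{\lambda^{2}} \, \frac{R\,x\,x^{\top}R}{1 + \frac{1}{\lambda}\,x^{\top}R\,x}

    w \leftarrow w + R\,x\,\bigl(y - x^{\top}w\bigr)

The speedup mentioned in the commit message comes from three small changes visible in this hunk: tmp is now a reused member rather than a matrix allocated on every call, the (1-forget_factor)/C regularization term is only re-added when it actually does something (and never in the new decay mode), and R is re-symmetrized once every 100 updates instead of on every update.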
@@ -145,25 +159,37 @@ namespace dlib
         friend inline void serialize (const rls& item, std::ostream& out)
         {
-            int version = 1;
+            int version = 2;
             serialize(version, out);
             serialize(item.w, out);
             serialize(item.R, out);
             serialize(item.C, out);
             serialize(item.forget_factor, out);
+            serialize(item.cnt, out);
+            serialize(item.apply_forget_factor_to_C, out);
         }

         friend inline void deserialize (rls& item, std::istream& in)
         {
             int version = 0;
             deserialize(version, in);
-            if (version != 1)
+            if (!(1 <= version && version <= 2))
                 throw dlib::serialization_error("Unknown version number found while deserializing rls object.");

-            deserialize(item.w, in);
-            deserialize(item.R, in);
-            deserialize(item.C, in);
-            deserialize(item.forget_factor, in);
+            if (version >= 1)
+            {
+                deserialize(item.w, in);
+                deserialize(item.R, in);
+                deserialize(item.C, in);
+                deserialize(item.forget_factor, in);
+            }
+            item.cnt = 0;
+            item.apply_forget_factor_to_C = false;
+            if (version >= 2)
+            {
+                deserialize(item.cnt, in);
+                deserialize(item.apply_forget_factor_to_C, in);
+            }
         }

     private:
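Because the stream now starts with version number 2, files written by this revision will not load in older dlib releases, while version 1 files produced before this commit still deserialize, with cnt and apply_forget_factor_to_C falling back to their defaults. A minimal round-trip sketch, assuming rls is pulled in through <dlib/svm.h> and using a file name chosen only for illustration:

    #include <fstream>
    #include <dlib/svm.h>   // assumed aggregate header providing dlib::rls and its serialize()/deserialize()

    int main()
    {
        dlib::rls filter(0.999, 1000, true);   // forget factor, C, apply_forget_factor_to_C

        // ... filter.train(x, y) on some data ...

        // Save the trained state; this writes the version 2 layout shown in the hunk above.
        std::ofstream fout("rls_state.dat", std::ios::binary);
        serialize(filter, fout);     // found via ADL on dlib::rls
        fout.close();

        // Restore into a fresh object.  Version 1 files produced by older code also load here.
        dlib::rls restored(0.999);
        std::ifstream fin("rls_state.dat", std::ios::binary);
        deserialize(restored, fin);
        return 0;
    }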
@@ -189,6 +215,13 @@ namespace dlib
         matrix<double> R;
         double C;
         double forget_factor;
+        int cnt = 0;
+        bool apply_forget_factor_to_C;
+
+        // This object is here only to avoid reallocation during training.  It doesn't
+        // logically contribute to the state of this object.
+        matrix<double,0,1> tmp;
     };

 // ----------------------------------------------------------------------------------------

dlib/svm/rls_abstract.h
@@ -37,7 +37,8 @@ namespace dlib
         explicit rls (
             double forget_factor,
-            double C = 1000
+            double C = 1000,
+            bool apply_forget_factor_to_C = false
         );
         /*!
             requires
@@ -47,6 +48,7 @@ namespace dlib
                 - #get_w().size() == 0
                 - #get_c() == C
                 - #get_forget_factor() == forget_factor
+                - #should_apply_forget_factor_to_C() == apply_forget_factor_to_C
         !*/

         rls (
@@ -56,6 +58,7 @@ namespace dlib
                 - #get_w().size() == 0
                 - #get_c() == 1000
                 - #get_forget_factor() == 1
+                - #should_apply_forget_factor_to_C() == false
         !*/

         double get_c (
@@ -80,6 +83,18 @@ namespace dlib
                   zero the faster old examples are forgotten.
         !*/

+        bool should_apply_forget_factor_to_C (
+        ) const;
+        /*!
+            ensures
+                - If this function returns false then it means we are optimizing the
+                  objective function discussed in the WHAT THIS OBJECT REPRESENTS section
+                  above.  However, if it returns true then we will allow the forget factor
+                  (get_forget_factor()) to be applied to the C value, which causes the
+                  algorithm to slowly increase C and turn into a textbook version of RLS
+                  without regularization.  The main reason you might want to do this is
+                  that it can make the algorithm run significantly faster.
+        !*/

         template <typename EXP>
         void train (
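To round out the new documentation, here is a small usage sketch of the added constructor argument: it fits the line y = 3*i + 7 with the regularization allowed to decay away. The data, parameter values, and include are my own choices for illustration, not part of the commit:

    #include <iostream>
    #include <dlib/svm.h>   // assumed aggregate header providing dlib::rls

    int main()
    {
        // forget_factor = 0.99, C = 1000, and the new flag set to true so the
        // regularization decays away and training runs a bit faster.
        dlib::rls learner(0.99, 1000, true);

        for (int i = 0; i < 1000; ++i)
        {
            dlib::matrix<double,2,1> x;
            x(0) = i;                 // feature
            x(1) = 1;                 // bias term
            const double y = 3*i + 7; // noiseless target
            learner.train(x, y);
        }

        // get_w() should now be close to (3, 7).
        std::cout << "learned w: " << dlib::trans(learner.get_w());
        return 0;
    }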