Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
1fbd1828
Commit
1fbd1828
authored
Nov 24, 2017
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Cleaned up the code a bit.
parent
0d9043bc
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
78 additions
and
67 deletions
+78
-67
global_function_search.cpp
dlib/global_optimization/global_function_search.cpp
+51
-39
global_function_search.h
dlib/global_optimization/global_function_search.h
+5
-5
global_function_search_abstract.h
dlib/global_optimization/global_function_search_abstract.h
+22
-23
No files found.
dlib/global_optimization/global_function_search.cpp
View file @
1fbd1828
...
...
@@ -11,7 +11,7 @@ namespace dlib
namespace
qopt_impl
{
void
fit_q
p
_mse
(
void
fit_q
uadratic_to_points
_mse
(
const
matrix
<
double
>&
X
,
const
matrix
<
double
,
0
,
1
>&
Y
,
matrix
<
double
>&
H
,
...
...
@@ -64,21 +64,21 @@ namespace dlib
// ----------------------------------------------------------------------------------------
void
fit_q
p
(
void
fit_q
uadratic_to_points
(
const
matrix
<
double
>&
X
,
const
matrix
<
double
,
0
,
1
>&
Y
,
matrix
<
double
>&
H
,
matrix
<
double
,
0
,
1
>&
g
,
double
&
c
)
/*!
requires
- X.size() > 0
/*!
requires
- X.size() > 0
- X.nc() == Y.size()
- X.nr()+1 <= X.nc()
<= (X.nr()+1)*(X.nr()+2)/2
ensures
- This function finds a quadratic function, Q(x), that interpolates the
given set of points. If there aren't enough points to uniquely define
- X.nr()+1 <= X.nc()
ensures
- This function finds a quadratic function, Q(x), that interpolates the
given set of points. If there aren't enough points to uniquely define
Q(x) then the Q(x) that fits the given points with the minimum Frobenius
norm hessian matrix is selected.
- To be precise:
...
...
@@ -87,16 +87,19 @@ namespace dlib
sum(squared(H))
such that:
Q(colm(X,i)) == Y(i), for all valid i
!*/
- If there are more points than necessary to constrain Q then the Q
that best interpolates the function in the mean squared sense is
found.
!*/
{
DLIB_CASSERT
(
X
.
size
()
>
0
);
DLIB_CASSERT
(
X
.
nc
()
==
Y
.
size
());
DLIB_CASSERT
(
X
.
nr
()
+
1
<=
X
.
nc
());
// && X.nc() <= (X.nr()+1)*(X.nr()+2)/2);
DLIB_CASSERT
(
X
.
nr
()
+
1
<=
X
.
nc
());
if
(
X
.
nc
()
>=
(
X
.
nr
()
+
1
)
*
(
X
.
nr
()
+
2
)
/
2
)
{
fit_q
p
_mse
(
X
,
Y
,
H
,
g
,
c
);
fit_q
uadratic_to_points
_mse
(
X
,
Y
,
H
,
g
,
c
);
return
;
}
...
...
@@ -180,7 +183,7 @@ namespace dlib
matrix
<
double
,
0
,
1
>
g
;
double
c
;
fit_q
p
(
X
,
Y
,
H
,
g
,
c
);
fit_q
uadratic_to_points
(
X
,
Y
,
H
,
g
,
c
);
matrix
<
double
,
0
,
1
>
p
;
...
...
@@ -198,7 +201,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
quad_interp_result
pick_next_sample_
quad_interp
(
quad_interp_result
pick_next_sample_
using_trust_region
(
const
std
::
vector
<
function_evaluation
>&
samples
,
double
&
radius
,
const
matrix
<
double
,
0
,
1
>&
lower
,
...
...
@@ -324,7 +327,7 @@ namespace dlib
// ------------------------------------------------------------------------------------
max_upper_bound_function
pick_next_sample_
max_upper_bound_function
(
max_upper_bound_function
pick_next_sample_
as_max_upper_bound
(
dlib
::
rand
&
rnd
,
const
upper_bound_function
&
ub
,
const
matrix
<
double
,
0
,
1
>&
lower
,
...
...
@@ -417,10 +420,10 @@ namespace dlib
{
upper_bound_function
tmp
(
ub
);
// we are going to add the
incomplete
evals into this and assume the
//
incomplete
evals are going to take y values equal to their nearest
// we are going to add the
outstanding
evals into this and assume the
//
outstanding
evals are going to take y values equal to their nearest
// neighbor complete evals.
for
(
auto
&
eval
:
incomplete
_evals
)
for
(
auto
&
eval
:
outstanding
_evals
)
{
function_evaluation
e
;
e
.
x
=
eval
.
x
;
...
...
@@ -454,6 +457,7 @@ namespace dlib
}
// end namespace gopt_impl
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
@@ -526,9 +530,9 @@ namespace dlib
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
info
->
m
);
// remove the evaluation request from the
incomplete
list.
auto
i
=
std
::
find
(
info
->
incomplete_evals
.
begin
(),
info
->
incomplete
_evals
.
end
(),
req
);
info
->
incomplete
_evals
.
erase
(
i
);
// remove the evaluation request from the
outstanding
list.
auto
i
=
std
::
find
(
info
->
outstanding_evals
.
begin
(),
info
->
outstanding
_evals
.
end
(),
req
);
info
->
outstanding
_evals
.
erase
(
i
);
}
}
...
...
@@ -545,10 +549,10 @@ namespace dlib
m_has_been_evaluated
=
true
;
// move the evaluation from
incomplete
to complete
auto
i
=
std
::
find
(
info
->
incomplete_evals
.
begin
(),
info
->
incomplete
_evals
.
end
(),
req
);
DLIB_CASSERT
(
i
!=
info
->
incomplete
_evals
.
end
());
info
->
incomplete
_evals
.
erase
(
i
);
// move the evaluation from
outstanding
to complete
auto
i
=
std
::
find
(
info
->
outstanding_evals
.
begin
(),
info
->
outstanding
_evals
.
end
(),
req
);
DLIB_CASSERT
(
i
!=
info
->
outstanding
_evals
.
end
());
info
->
outstanding
_evals
.
erase
(
i
);
info
->
ub
.
add
(
function_evaluation
(
req
.
x
,
y
));
...
...
@@ -582,6 +586,8 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
global_function_search
::
...
...
@@ -701,13 +707,13 @@ namespace dlib
outstanding_function_eval_request
new_req
;
new_req
.
request_id
=
next_request_id
++
;
new_req
.
x
=
make_random_vector
(
rnd
,
info
->
spec
.
lower
,
info
->
spec
.
upper
,
info
->
spec
.
is_integer_variable
);
info
->
incomplete
_evals
.
emplace_back
(
new_req
);
info
->
outstanding
_evals
.
emplace_back
(
new_req
);
return
function_evaluation_request
(
new_req
,
info
);
}
}
if
(
do_trust_region_step
&&
!
has_
incomplete
_trust_region_request
())
if
(
do_trust_region_step
&&
!
has_
outstanding
_trust_region_request
())
{
// find the currently best performing function, we will do a trust region
// step on it.
...
...
@@ -716,7 +722,7 @@ namespace dlib
// if we have enough points to do a trust region step
if
(
info
->
ub
.
num_points
()
>
dims
+
1
)
{
auto
tmp
=
pick_next_sample_
quad_interp
(
info
->
ub
.
get_points
(),
auto
tmp
=
pick_next_sample_
using_trust_region
(
info
->
ub
.
get_points
(),
info
->
radius
,
info
->
spec
.
lower
,
info
->
spec
.
upper
,
info
->
spec
.
is_integer_variable
);
//std::cout << "QP predicted improvement: "<< tmp.predicted_improvement << std::endl;
if
(
tmp
.
predicted_improvement
>
min_trust_region_epsilon
)
...
...
@@ -728,7 +734,7 @@ namespace dlib
new_req
.
was_trust_region_generated_request
=
true
;
new_req
.
anchor_objective_value
=
info
->
best_objective_value
;
new_req
.
predicted_improvement
=
tmp
.
predicted_improvement
;
info
->
incomplete
_evals
.
emplace_back
(
new_req
);
info
->
outstanding
_evals
.
emplace_back
(
new_req
);
return
function_evaluation_request
(
new_req
,
info
);
}
}
...
...
@@ -747,7 +753,7 @@ namespace dlib
// function with the largest upper bound for evaluation.
for
(
auto
&
info
:
functions
)
{
auto
tmp
=
pick_next_sample_
max_upper_bound_function
(
rnd
,
auto
tmp
=
pick_next_sample_
as_max_upper_bound
(
rnd
,
info
->
build_upper_bound_with_all_function_evals
(),
info
->
spec
.
lower
,
info
->
spec
.
upper
,
info
->
spec
.
is_integer_variable
,
num_random_samples
);
if
(
tmp
.
predicted_improvement
>
0
&&
tmp
.
upper_bound
>
best_upper_bound
)
...
...
@@ -764,7 +770,7 @@ namespace dlib
outstanding_function_eval_request
new_req
;
new_req
.
request_id
=
next_request_id
++
;
new_req
.
x
=
std
::
move
(
next_sample
);
best_funct
->
incomplete
_evals
.
emplace_back
(
new_req
);
best_funct
->
outstanding
_evals
.
emplace_back
(
new_req
);
return
function_evaluation_request
(
new_req
,
best_funct
);
}
}
...
...
@@ -776,7 +782,7 @@ namespace dlib
outstanding_function_eval_request
new_req
;
new_req
.
request_id
=
next_request_id
++
;
new_req
.
x
=
make_random_vector
(
rnd
,
info
->
spec
.
lower
,
info
->
spec
.
upper
,
info
->
spec
.
is_integer_variable
);
info
->
incomplete
_evals
.
emplace_back
(
new_req
);
info
->
outstanding
_evals
.
emplace_back
(
new_req
);
return
function_evaluation_request
(
new_req
,
info
);
}
...
...
@@ -839,9 +845,13 @@ namespace dlib
{
DLIB_CASSERT
(
0
<=
value
);
relative_noise_magnitude
=
value
;
// recreate all the upper bound functions with the new relative noise magnitude
for
(
auto
&
f
:
functions
)
f
->
ub
=
upper_bound_function
(
f
->
ub
.
get_points
(),
relative_noise_magnitude
);
if
(
m
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
m
);
// recreate all the upper bound functions with the new relative noise magnitude
for
(
auto
&
f
:
functions
)
f
->
ub
=
upper_bound_function
(
f
->
ub
.
get_points
(),
relative_noise_magnitude
);
}
}
// ----------------------------------------------------------------------------------------
...
...
@@ -881,8 +891,10 @@ namespace dlib
size_t
&
idx
)
const
{
auto
i
=
std
::
max_element
(
functions
.
begin
(),
functions
.
end
(),
[](
const
std
::
shared_ptr
<
gopt_impl
::
funct_info
>&
a
,
const
std
::
shared_ptr
<
gopt_impl
::
funct_info
>&
b
)
{
return
a
->
best_objective_value
<
b
->
best_objective_value
;
});
auto
compare
=
[](
const
std
::
shared_ptr
<
gopt_impl
::
funct_info
>&
a
,
const
std
::
shared_ptr
<
gopt_impl
::
funct_info
>&
b
)
{
return
a
->
best_objective_value
<
b
->
best_objective_value
;
};
auto
i
=
std
::
max_element
(
functions
.
begin
(),
functions
.
end
(),
compare
);
idx
=
std
::
distance
(
functions
.
begin
(),
i
);
return
*
i
;
...
...
@@ -891,12 +903,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
bool
global_function_search
::
has_
incomplete
_trust_region_request
(
has_
outstanding
_trust_region_request
(
)
const
{
for
(
auto
&
f
:
functions
)
{
for
(
auto
&
i
:
f
->
incomplete
_evals
)
for
(
auto
&
i
:
f
->
outstanding
_evals
)
{
if
(
i
.
was_trust_region_generated_request
)
return
true
;
...
...
dlib/global_optimization/global_function_search.h
View file @
1fbd1828
...
...
@@ -79,7 +79,7 @@ namespace dlib
size_t
function_idx
=
0
;
std
::
shared_ptr
<
std
::
mutex
>
m
;
upper_bound_function
ub
;
std
::
vector
<
outstanding_function_eval_request
>
incomplete
_evals
;
std
::
vector
<
outstanding_function_eval_request
>
outstanding
_evals
;
matrix
<
double
,
0
,
1
>
best_x
;
double
best_objective_value
=
-
std
::
numeric_limits
<
double
>::
infinity
();
double
radius
=
0
;
...
...
@@ -101,7 +101,7 @@ namespace dlib
function_evaluation_request
(
function_evaluation_request
&&
item
);
function_evaluation_request
&
operator
=
(
function_evaluation_request
&&
item
);
void
swap
(
function_evaluation_request
&
item
);
~
function_evaluation_request
(
);
size_t
function_idx
(
)
const
;
...
...
@@ -112,12 +112,12 @@ namespace dlib
bool
has_been_evaluated
(
)
const
;
~
function_evaluation_request
();
void
set
(
double
y
);
void
swap
(
function_evaluation_request
&
item
);
private
:
friend
class
global_function_search
;
...
...
@@ -218,7 +218,7 @@ namespace dlib
size_t
&
idx
)
const
;
bool
has_
incomplete
_trust_region_request
(
bool
has_
outstanding
_trust_region_request
(
)
const
;
...
...
dlib/global_optimization/global_function_search_abstract.h
View file @
1fbd1828
...
...
@@ -89,14 +89,6 @@ namespace dlib
moving from item causes item.has_been_evaluated() == true, TODO, clarify
!*/
void
swap
(
function_evaluation_request
&
item
);
/*!
ensures
- swaps the state of *this and item
!*/
~
function_evaluation_request
(
);
/*!
...
...
@@ -113,7 +105,6 @@ namespace dlib
bool
has_been_evaluated
(
)
const
;
void
set
(
double
y
);
...
...
@@ -124,6 +115,14 @@ namespace dlib
- #has_been_evaluated() == true
!*/
void
swap
(
function_evaluation_request
&
item
);
/*!
ensures
- swaps the state of *this and item
!*/
};
// ----------------------------------------------------------------------------------------
...
...
@@ -143,18 +142,6 @@ namespace dlib
- #num_functions() == 0
!*/
// This object can't be copied.
global_function_search
(
const
global_function_search
&
)
=
delete
;
global_function_search
&
operator
=
(
const
global_function_search
&
item
)
=
delete
;
global_function_search
(
global_function_search
&&
item
)
=
default
;
global_function_search
&
operator
=
(
global_function_search
&&
item
)
=
default
;
/*!
ensures
- moves the state of item into *this
- #item.num_functions() == 0
!*/
explicit
global_function_search
(
const
function_spec
&
function
);
...
...
@@ -169,13 +156,25 @@ namespace dlib
const
double
relative_noise_magnitude
=
0
.
001
);
size_t
num_functions
(
)
const
;
// This object can't be copied.
global_function_search
(
const
global_function_search
&
)
=
delete
;
global_function_search
&
operator
=
(
const
global_function_search
&
item
)
=
delete
;
global_function_search
(
global_function_search
&&
item
)
=
default
;
global_function_search
&
operator
=
(
global_function_search
&&
item
)
=
default
;
/*!
ensures
- moves the state of item into *this
- #item.num_functions() == 0
!*/
void
set_seed
(
time_t
seed
);
size_t
num_functions
(
)
const
;
void
get_function_evaluations
(
std
::
vector
<
function_spec
>&
specs
,
std
::
vector
<
std
::
vector
<
function_evaluation
>>&
function_evals
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment