Commit 1fbd1828 authored by Davis King's avatar Davis King

Cleaned up the code a bit.

parent 0d9043bc
...@@ -11,7 +11,7 @@ namespace dlib ...@@ -11,7 +11,7 @@ namespace dlib
namespace qopt_impl namespace qopt_impl
{ {
void fit_qp_mse( void fit_quadratic_to_points_mse(
const matrix<double>& X, const matrix<double>& X,
const matrix<double,0,1>& Y, const matrix<double,0,1>& Y,
matrix<double>& H, matrix<double>& H,
...@@ -64,21 +64,21 @@ namespace dlib ...@@ -64,21 +64,21 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void fit_qp( void fit_quadratic_to_points(
const matrix<double>& X, const matrix<double>& X,
const matrix<double,0,1>& Y, const matrix<double,0,1>& Y,
matrix<double>& H, matrix<double>& H,
matrix<double,0,1>& g, matrix<double,0,1>& g,
double& c double& c
) )
/*! /*!
requires requires
- X.size() > 0 - X.size() > 0
- X.nc() == Y.size() - X.nc() == Y.size()
- X.nr()+1 <= X.nc() <= (X.nr()+1)*(X.nr()+2)/2 - X.nr()+1 <= X.nc()
ensures ensures
- This function finds a quadratic function, Q(x), that interpolates the - This function finds a quadratic function, Q(x), that interpolates the
given set of points. If there aren't enough points to uniquely define given set of points. If there aren't enough points to uniquely define
Q(x) then the Q(x) that fits the given points with the minimum Frobenius Q(x) then the Q(x) that fits the given points with the minimum Frobenius
norm hessian matrix is selected. norm hessian matrix is selected.
- To be precise: - To be precise:
...@@ -87,16 +87,19 @@ namespace dlib ...@@ -87,16 +87,19 @@ namespace dlib
sum(squared(H)) sum(squared(H))
such that: such that:
Q(colm(X,i)) == Y(i), for all valid i Q(colm(X,i)) == Y(i), for all valid i
!*/ - If there are more points than necessary to constrain Q then the Q
that best interpolates the function in the mean squared sense is
found.
!*/
{ {
DLIB_CASSERT(X.size() > 0); DLIB_CASSERT(X.size() > 0);
DLIB_CASSERT(X.nc() == Y.size()); DLIB_CASSERT(X.nc() == Y.size());
DLIB_CASSERT(X.nr()+1 <= X.nc());// && X.nc() <= (X.nr()+1)*(X.nr()+2)/2); DLIB_CASSERT(X.nr()+1 <= X.nc());
if (X.nc() >= (X.nr()+1)*(X.nr()+2)/2) if (X.nc() >= (X.nr()+1)*(X.nr()+2)/2)
{ {
fit_qp_mse(X,Y,H,g,c); fit_quadratic_to_points_mse(X,Y,H,g,c);
return; return;
} }
...@@ -180,7 +183,7 @@ namespace dlib ...@@ -180,7 +183,7 @@ namespace dlib
matrix<double,0,1> g; matrix<double,0,1> g;
double c; double c;
fit_qp(X, Y, H, g, c); fit_quadratic_to_points(X, Y, H, g, c);
matrix<double,0,1> p; matrix<double,0,1> p;
...@@ -198,7 +201,7 @@ namespace dlib ...@@ -198,7 +201,7 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
quad_interp_result pick_next_sample_quad_interp ( quad_interp_result pick_next_sample_using_trust_region (
const std::vector<function_evaluation>& samples, const std::vector<function_evaluation>& samples,
double& radius, double& radius,
const matrix<double,0,1>& lower, const matrix<double,0,1>& lower,
...@@ -324,7 +327,7 @@ namespace dlib ...@@ -324,7 +327,7 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
max_upper_bound_function pick_next_sample_max_upper_bound_function ( max_upper_bound_function pick_next_sample_as_max_upper_bound (
dlib::rand& rnd, dlib::rand& rnd,
const upper_bound_function& ub, const upper_bound_function& ub,
const matrix<double,0,1>& lower, const matrix<double,0,1>& lower,
...@@ -417,10 +420,10 @@ namespace dlib ...@@ -417,10 +420,10 @@ namespace dlib
{ {
upper_bound_function tmp(ub); upper_bound_function tmp(ub);
// we are going to add the incomplete evals into this and assume the // we are going to add the outstanding evals into this and assume the
// incomplete evals are going to take y values equal to their nearest // outstanding evals are going to take y values equal to their nearest
// neighbor complete evals. // neighbor complete evals.
for (auto& eval : incomplete_evals) for (auto& eval : outstanding_evals)
{ {
function_evaluation e; function_evaluation e;
e.x = eval.x; e.x = eval.x;
...@@ -454,6 +457,7 @@ namespace dlib ...@@ -454,6 +457,7 @@ namespace dlib
} // end namespace gopt_impl } // end namespace gopt_impl
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -526,9 +530,9 @@ namespace dlib ...@@ -526,9 +530,9 @@ namespace dlib
{ {
std::lock_guard<std::mutex> lock(*info->m); std::lock_guard<std::mutex> lock(*info->m);
// remove the evaluation request from the incomplete list. // remove the evaluation request from the outstanding list.
auto i = std::find(info->incomplete_evals.begin(), info->incomplete_evals.end(), req); auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req);
info->incomplete_evals.erase(i); info->outstanding_evals.erase(i);
} }
} }
...@@ -545,10 +549,10 @@ namespace dlib ...@@ -545,10 +549,10 @@ namespace dlib
m_has_been_evaluated = true; m_has_been_evaluated = true;
// move the evaluation from incomplete to complete // move the evaluation from outstanding to complete
auto i = std::find(info->incomplete_evals.begin(), info->incomplete_evals.end(), req); auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req);
DLIB_CASSERT(i != info->incomplete_evals.end()); DLIB_CASSERT(i != info->outstanding_evals.end());
info->incomplete_evals.erase(i); info->outstanding_evals.erase(i);
info->ub.add(function_evaluation(req.x,y)); info->ub.add(function_evaluation(req.x,y));
...@@ -582,6 +586,8 @@ namespace dlib ...@@ -582,6 +586,8 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
global_function_search:: global_function_search::
...@@ -701,13 +707,13 @@ namespace dlib ...@@ -701,13 +707,13 @@ namespace dlib
outstanding_function_eval_request new_req; outstanding_function_eval_request new_req;
new_req.request_id = next_request_id++; new_req.request_id = next_request_id++;
new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable); new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
info->incomplete_evals.emplace_back(new_req); info->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req,info); return function_evaluation_request(new_req,info);
} }
} }
if (do_trust_region_step && !has_incomplete_trust_region_request()) if (do_trust_region_step && !has_outstanding_trust_region_request())
{ {
// find the currently best performing function, we will do a trust region // find the currently best performing function, we will do a trust region
// step on it. // step on it.
...@@ -716,7 +722,7 @@ namespace dlib ...@@ -716,7 +722,7 @@ namespace dlib
// if we have enough points to do a trust region step // if we have enough points to do a trust region step
if (info->ub.num_points() > dims+1) if (info->ub.num_points() > dims+1)
{ {
auto tmp = pick_next_sample_quad_interp(info->ub.get_points(), auto tmp = pick_next_sample_using_trust_region(info->ub.get_points(),
info->radius, info->spec.lower, info->spec.upper, info->spec.is_integer_variable); info->radius, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
//std::cout << "QP predicted improvement: "<< tmp.predicted_improvement << std::endl; //std::cout << "QP predicted improvement: "<< tmp.predicted_improvement << std::endl;
if (tmp.predicted_improvement > min_trust_region_epsilon) if (tmp.predicted_improvement > min_trust_region_epsilon)
...@@ -728,7 +734,7 @@ namespace dlib ...@@ -728,7 +734,7 @@ namespace dlib
new_req.was_trust_region_generated_request = true; new_req.was_trust_region_generated_request = true;
new_req.anchor_objective_value = info->best_objective_value; new_req.anchor_objective_value = info->best_objective_value;
new_req.predicted_improvement = tmp.predicted_improvement; new_req.predicted_improvement = tmp.predicted_improvement;
info->incomplete_evals.emplace_back(new_req); info->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req, info); return function_evaluation_request(new_req, info);
} }
} }
...@@ -747,7 +753,7 @@ namespace dlib ...@@ -747,7 +753,7 @@ namespace dlib
// function with the largest upper bound for evaluation. // function with the largest upper bound for evaluation.
for (auto& info : functions) for (auto& info : functions)
{ {
auto tmp = pick_next_sample_max_upper_bound_function(rnd, auto tmp = pick_next_sample_as_max_upper_bound(rnd,
info->build_upper_bound_with_all_function_evals(), info->spec.lower, info->spec.upper, info->build_upper_bound_with_all_function_evals(), info->spec.lower, info->spec.upper,
info->spec.is_integer_variable, num_random_samples); info->spec.is_integer_variable, num_random_samples);
if (tmp.predicted_improvement > 0 && tmp.upper_bound > best_upper_bound) if (tmp.predicted_improvement > 0 && tmp.upper_bound > best_upper_bound)
...@@ -764,7 +770,7 @@ namespace dlib ...@@ -764,7 +770,7 @@ namespace dlib
outstanding_function_eval_request new_req; outstanding_function_eval_request new_req;
new_req.request_id = next_request_id++; new_req.request_id = next_request_id++;
new_req.x = std::move(next_sample); new_req.x = std::move(next_sample);
best_funct->incomplete_evals.emplace_back(new_req); best_funct->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req, best_funct); return function_evaluation_request(new_req, best_funct);
} }
} }
...@@ -776,7 +782,7 @@ namespace dlib ...@@ -776,7 +782,7 @@ namespace dlib
outstanding_function_eval_request new_req; outstanding_function_eval_request new_req;
new_req.request_id = next_request_id++; new_req.request_id = next_request_id++;
new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable); new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
info->incomplete_evals.emplace_back(new_req); info->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req, info); return function_evaluation_request(new_req, info);
} }
...@@ -839,9 +845,13 @@ namespace dlib ...@@ -839,9 +845,13 @@ namespace dlib
{ {
DLIB_CASSERT(0 <= value); DLIB_CASSERT(0 <= value);
relative_noise_magnitude = value; relative_noise_magnitude = value;
// recreate all the upper bound functions with the new relative noise magnitude if (m)
for (auto& f : functions) {
f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude); std::lock_guard<std::mutex> lock(*m);
// recreate all the upper bound functions with the new relative noise magnitude
for (auto& f : functions)
f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude);
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -881,8 +891,10 @@ namespace dlib ...@@ -881,8 +891,10 @@ namespace dlib
size_t& idx size_t& idx
) const ) const
{ {
auto i = std::max_element(functions.begin(), functions.end(), auto compare = [](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b)
[](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b) { return a->best_objective_value < b->best_objective_value; }); { return a->best_objective_value < b->best_objective_value; };
auto i = std::max_element(functions.begin(), functions.end(), compare);
idx = std::distance(functions.begin(),i); idx = std::distance(functions.begin(),i);
return *i; return *i;
...@@ -891,12 +903,12 @@ namespace dlib ...@@ -891,12 +903,12 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
bool global_function_search:: bool global_function_search::
has_incomplete_trust_region_request ( has_outstanding_trust_region_request (
) const ) const
{ {
for (auto& f : functions) for (auto& f : functions)
{ {
for (auto& i : f->incomplete_evals) for (auto& i : f->outstanding_evals)
{ {
if (i.was_trust_region_generated_request) if (i.was_trust_region_generated_request)
return true; return true;
......
...@@ -79,7 +79,7 @@ namespace dlib ...@@ -79,7 +79,7 @@ namespace dlib
size_t function_idx = 0; size_t function_idx = 0;
std::shared_ptr<std::mutex> m; std::shared_ptr<std::mutex> m;
upper_bound_function ub; upper_bound_function ub;
std::vector<outstanding_function_eval_request> incomplete_evals; std::vector<outstanding_function_eval_request> outstanding_evals;
matrix<double,0,1> best_x; matrix<double,0,1> best_x;
double best_objective_value = -std::numeric_limits<double>::infinity(); double best_objective_value = -std::numeric_limits<double>::infinity();
double radius = 0; double radius = 0;
...@@ -101,7 +101,7 @@ namespace dlib ...@@ -101,7 +101,7 @@ namespace dlib
function_evaluation_request(function_evaluation_request&& item); function_evaluation_request(function_evaluation_request&& item);
function_evaluation_request& operator=(function_evaluation_request&& item); function_evaluation_request& operator=(function_evaluation_request&& item);
void swap(function_evaluation_request& item); ~function_evaluation_request();
size_t function_idx ( size_t function_idx (
) const; ) const;
...@@ -112,12 +112,12 @@ namespace dlib ...@@ -112,12 +112,12 @@ namespace dlib
bool has_been_evaluated ( bool has_been_evaluated (
) const; ) const;
~function_evaluation_request();
void set ( void set (
double y double y
); );
void swap(function_evaluation_request& item);
private: private:
friend class global_function_search; friend class global_function_search;
...@@ -218,7 +218,7 @@ namespace dlib ...@@ -218,7 +218,7 @@ namespace dlib
size_t& idx size_t& idx
) const; ) const;
bool has_incomplete_trust_region_request ( bool has_outstanding_trust_region_request (
) const; ) const;
......
...@@ -89,14 +89,6 @@ namespace dlib ...@@ -89,14 +89,6 @@ namespace dlib
moving from item causes item.has_been_evaluated() == true, TODO, clarify moving from item causes item.has_been_evaluated() == true, TODO, clarify
!*/ !*/
void swap(
function_evaluation_request& item
);
/*!
ensures
- swaps the state of *this and item
!*/
~function_evaluation_request( ~function_evaluation_request(
); );
/*! /*!
...@@ -113,7 +105,6 @@ namespace dlib ...@@ -113,7 +105,6 @@ namespace dlib
bool has_been_evaluated ( bool has_been_evaluated (
) const; ) const;
void set ( void set (
double y double y
); );
...@@ -124,6 +115,14 @@ namespace dlib ...@@ -124,6 +115,14 @@ namespace dlib
- #has_been_evaluated() == true - #has_been_evaluated() == true
!*/ !*/
void swap(
function_evaluation_request& item
);
/*!
ensures
- swaps the state of *this and item
!*/
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -143,18 +142,6 @@ namespace dlib ...@@ -143,18 +142,6 @@ namespace dlib
- #num_functions() == 0 - #num_functions() == 0
!*/ !*/
// This object can't be copied.
global_function_search(const global_function_search&) = delete;
global_function_search& operator=(const global_function_search& item) = delete;
global_function_search(global_function_search&& item) = default;
global_function_search& operator=(global_function_search&& item) = default;
/*!
ensures
- moves the state of item into *this
- #item.num_functions() == 0
!*/
explicit global_function_search( explicit global_function_search(
const function_spec& function const function_spec& function
); );
...@@ -169,13 +156,25 @@ namespace dlib ...@@ -169,13 +156,25 @@ namespace dlib
const double relative_noise_magnitude = 0.001 const double relative_noise_magnitude = 0.001
); );
size_t num_functions( // This object can't be copied.
) const; global_function_search(const global_function_search&) = delete;
global_function_search& operator=(const global_function_search& item) = delete;
global_function_search(global_function_search&& item) = default;
global_function_search& operator=(global_function_search&& item) = default;
/*!
ensures
- moves the state of item into *this
- #item.num_functions() == 0
!*/
void set_seed ( void set_seed (
time_t seed time_t seed
); );
size_t num_functions(
) const;
void get_function_evaluations ( void get_function_evaluations (
std::vector<function_spec>& specs, std::vector<function_spec>& specs,
std::vector<std::vector<function_evaluation>>& function_evals std::vector<std::vector<function_evaluation>>& function_evals
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment