Commit 1fbd1828 authored by Davis King's avatar Davis King

Cleaned up the code a bit.

parent 0d9043bc
......@@ -11,7 +11,7 @@ namespace dlib
namespace qopt_impl
{
void fit_qp_mse(
void fit_quadratic_to_points_mse(
const matrix<double>& X,
const matrix<double,0,1>& Y,
matrix<double>& H,
......@@ -64,7 +64,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
void fit_qp(
void fit_quadratic_to_points(
const matrix<double>& X,
const matrix<double,0,1>& Y,
matrix<double>& H,
......@@ -75,7 +75,7 @@ namespace dlib
requires
- X.size() > 0
- X.nc() == Y.size()
- X.nr()+1 <= X.nc() <= (X.nr()+1)*(X.nr()+2)/2
- X.nr()+1 <= X.nc()
ensures
- This function finds a quadratic function, Q(x), that interpolates the
given set of points. If there aren't enough points to uniquely define
......@@ -87,16 +87,19 @@ namespace dlib
sum(squared(H))
such that:
Q(colm(X,i)) == Y(i), for all valid i
- If there are more points than necessary to constrain Q then the Q
that best interpolates the function in the mean squared sense is
found.
!*/
{
DLIB_CASSERT(X.size() > 0);
DLIB_CASSERT(X.nc() == Y.size());
DLIB_CASSERT(X.nr()+1 <= X.nc());// && X.nc() <= (X.nr()+1)*(X.nr()+2)/2);
DLIB_CASSERT(X.nr()+1 <= X.nc());
if (X.nc() >= (X.nr()+1)*(X.nr()+2)/2)
{
fit_qp_mse(X,Y,H,g,c);
fit_quadratic_to_points_mse(X,Y,H,g,c);
return;
}
......@@ -180,7 +183,7 @@ namespace dlib
matrix<double,0,1> g;
double c;
fit_qp(X, Y, H, g, c);
fit_quadratic_to_points(X, Y, H, g, c);
matrix<double,0,1> p;
......@@ -198,7 +201,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
quad_interp_result pick_next_sample_quad_interp (
quad_interp_result pick_next_sample_using_trust_region (
const std::vector<function_evaluation>& samples,
double& radius,
const matrix<double,0,1>& lower,
......@@ -324,7 +327,7 @@ namespace dlib
// ------------------------------------------------------------------------------------
max_upper_bound_function pick_next_sample_max_upper_bound_function (
max_upper_bound_function pick_next_sample_as_max_upper_bound (
dlib::rand& rnd,
const upper_bound_function& ub,
const matrix<double,0,1>& lower,
......@@ -417,10 +420,10 @@ namespace dlib
{
upper_bound_function tmp(ub);
// we are going to add the incomplete evals into this and assume the
// incomplete evals are going to take y values equal to their nearest
// we are going to add the outstanding evals into this and assume the
// outstanding evals are going to take y values equal to their nearest
// neighbor complete evals.
for (auto& eval : incomplete_evals)
for (auto& eval : outstanding_evals)
{
function_evaluation e;
e.x = eval.x;
......@@ -454,6 +457,7 @@ namespace dlib
} // end namespace gopt_impl
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -526,9 +530,9 @@ namespace dlib
{
std::lock_guard<std::mutex> lock(*info->m);
// remove the evaluation request from the incomplete list.
auto i = std::find(info->incomplete_evals.begin(), info->incomplete_evals.end(), req);
info->incomplete_evals.erase(i);
// remove the evaluation request from the outstanding list.
auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req);
info->outstanding_evals.erase(i);
}
}
......@@ -545,10 +549,10 @@ namespace dlib
m_has_been_evaluated = true;
// move the evaluation from incomplete to complete
auto i = std::find(info->incomplete_evals.begin(), info->incomplete_evals.end(), req);
DLIB_CASSERT(i != info->incomplete_evals.end());
info->incomplete_evals.erase(i);
// move the evaluation from outstanding to complete
auto i = std::find(info->outstanding_evals.begin(), info->outstanding_evals.end(), req);
DLIB_CASSERT(i != info->outstanding_evals.end());
info->outstanding_evals.erase(i);
info->ub.add(function_evaluation(req.x,y));
......@@ -582,6 +586,8 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
global_function_search::
......@@ -701,13 +707,13 @@ namespace dlib
outstanding_function_eval_request new_req;
new_req.request_id = next_request_id++;
new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
info->incomplete_evals.emplace_back(new_req);
info->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req,info);
}
}
if (do_trust_region_step && !has_incomplete_trust_region_request())
if (do_trust_region_step && !has_outstanding_trust_region_request())
{
// find the currently best performing function, we will do a trust region
// step on it.
......@@ -716,7 +722,7 @@ namespace dlib
// if we have enough points to do a trust region step
if (info->ub.num_points() > dims+1)
{
auto tmp = pick_next_sample_quad_interp(info->ub.get_points(),
auto tmp = pick_next_sample_using_trust_region(info->ub.get_points(),
info->radius, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
//std::cout << "QP predicted improvement: "<< tmp.predicted_improvement << std::endl;
if (tmp.predicted_improvement > min_trust_region_epsilon)
......@@ -728,7 +734,7 @@ namespace dlib
new_req.was_trust_region_generated_request = true;
new_req.anchor_objective_value = info->best_objective_value;
new_req.predicted_improvement = tmp.predicted_improvement;
info->incomplete_evals.emplace_back(new_req);
info->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req, info);
}
}
......@@ -747,7 +753,7 @@ namespace dlib
// function with the largest upper bound for evaluation.
for (auto& info : functions)
{
auto tmp = pick_next_sample_max_upper_bound_function(rnd,
auto tmp = pick_next_sample_as_max_upper_bound(rnd,
info->build_upper_bound_with_all_function_evals(), info->spec.lower, info->spec.upper,
info->spec.is_integer_variable, num_random_samples);
if (tmp.predicted_improvement > 0 && tmp.upper_bound > best_upper_bound)
......@@ -764,7 +770,7 @@ namespace dlib
outstanding_function_eval_request new_req;
new_req.request_id = next_request_id++;
new_req.x = std::move(next_sample);
best_funct->incomplete_evals.emplace_back(new_req);
best_funct->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req, best_funct);
}
}
......@@ -776,7 +782,7 @@ namespace dlib
outstanding_function_eval_request new_req;
new_req.request_id = next_request_id++;
new_req.x = make_random_vector(rnd, info->spec.lower, info->spec.upper, info->spec.is_integer_variable);
info->incomplete_evals.emplace_back(new_req);
info->outstanding_evals.emplace_back(new_req);
return function_evaluation_request(new_req, info);
}
......@@ -839,10 +845,14 @@ namespace dlib
{
DLIB_CASSERT(0 <= value);
relative_noise_magnitude = value;
if (m)
{
std::lock_guard<std::mutex> lock(*m);
// recreate all the upper bound functions with the new relative noise magnitude
for (auto& f : functions)
f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude);
}
}
// ----------------------------------------------------------------------------------------
......@@ -881,8 +891,10 @@ namespace dlib
size_t& idx
) const
{
auto i = std::max_element(functions.begin(), functions.end(),
[](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b) { return a->best_objective_value < b->best_objective_value; });
auto compare = [](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b)
{ return a->best_objective_value < b->best_objective_value; };
auto i = std::max_element(functions.begin(), functions.end(), compare);
idx = std::distance(functions.begin(),i);
return *i;
......@@ -891,12 +903,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
bool global_function_search::
has_incomplete_trust_region_request (
has_outstanding_trust_region_request (
) const
{
for (auto& f : functions)
{
for (auto& i : f->incomplete_evals)
for (auto& i : f->outstanding_evals)
{
if (i.was_trust_region_generated_request)
return true;
......
......@@ -79,7 +79,7 @@ namespace dlib
size_t function_idx = 0;
std::shared_ptr<std::mutex> m;
upper_bound_function ub;
std::vector<outstanding_function_eval_request> incomplete_evals;
std::vector<outstanding_function_eval_request> outstanding_evals;
matrix<double,0,1> best_x;
double best_objective_value = -std::numeric_limits<double>::infinity();
double radius = 0;
......@@ -101,7 +101,7 @@ namespace dlib
function_evaluation_request(function_evaluation_request&& item);
function_evaluation_request& operator=(function_evaluation_request&& item);
void swap(function_evaluation_request& item);
~function_evaluation_request();
size_t function_idx (
) const;
......@@ -112,12 +112,12 @@ namespace dlib
bool has_been_evaluated (
) const;
~function_evaluation_request();
void set (
double y
);
void swap(function_evaluation_request& item);
private:
friend class global_function_search;
......@@ -218,7 +218,7 @@ namespace dlib
size_t& idx
) const;
bool has_incomplete_trust_region_request (
bool has_outstanding_trust_region_request (
) const;
......
......@@ -89,14 +89,6 @@ namespace dlib
moving from item causes item.has_been_evaluated() == true, TODO, clarify
!*/
void swap(
function_evaluation_request& item
);
/*!
ensures
- swaps the state of *this and item
!*/
~function_evaluation_request(
);
/*!
......@@ -113,7 +105,6 @@ namespace dlib
bool has_been_evaluated (
) const;
void set (
double y
);
......@@ -124,6 +115,14 @@ namespace dlib
- #has_been_evaluated() == true
!*/
void swap(
function_evaluation_request& item
);
/*!
ensures
- swaps the state of *this and item
!*/
};
// ----------------------------------------------------------------------------------------
......@@ -143,18 +142,6 @@ namespace dlib
- #num_functions() == 0
!*/
// This object can't be copied.
global_function_search(const global_function_search&) = delete;
global_function_search& operator=(const global_function_search& item) = delete;
global_function_search(global_function_search&& item) = default;
global_function_search& operator=(global_function_search&& item) = default;
/*!
ensures
- moves the state of item into *this
- #item.num_functions() == 0
!*/
explicit global_function_search(
const function_spec& function
);
......@@ -169,13 +156,25 @@ namespace dlib
const double relative_noise_magnitude = 0.001
);
size_t num_functions(
) const;
// This object can't be copied.
global_function_search(const global_function_search&) = delete;
global_function_search& operator=(const global_function_search& item) = delete;
global_function_search(global_function_search&& item) = default;
global_function_search& operator=(global_function_search&& item) = default;
/*!
ensures
- moves the state of item into *this
- #item.num_functions() == 0
!*/
void set_seed (
time_t seed
);
size_t num_functions(
) const;
void get_function_evaluations (
std::vector<function_spec>& specs,
std::vector<std::vector<function_evaluation>>& function_evals
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment