Commit c91a0473 authored by Davis King's avatar Davis King

more cleanup

parent 1eee5ccd
...@@ -6,6 +6,9 @@ ...@@ -6,6 +6,9 @@
namespace dlib namespace dlib
{ {
// ----------------------------------------------------------------------------------------
namespace qopt_impl namespace qopt_impl
{ {
void fit_qp_mse( void fit_qp_mse(
...@@ -320,6 +323,8 @@ namespace dlib ...@@ -320,6 +323,8 @@ namespace dlib
double upper_bound = 0; double upper_bound = 0;
}; };
// ------------------------------------------------------------------------------------
max_upper_bound_function pick_next_sample_max_upper_bound_function ( max_upper_bound_function pick_next_sample_max_upper_bound_function (
dlib::rand& rnd, dlib::rand& rnd,
const upper_bound_function& ub, const upper_bound_function& ub,
...@@ -362,7 +367,11 @@ namespace dlib ...@@ -362,7 +367,11 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
function_spec::function_spec(const matrix<double,0,1>& lower_, const matrix<double,0,1>& upper_) : lower(lower_), upper(upper_) function_spec::function_spec(
const matrix<double,0,1>& lower_,
const matrix<double,0,1>& upper_
) :
lower(lower_), upper(upper_)
{ {
DLIB_CASSERT(lower.size() == upper.size()); DLIB_CASSERT(lower.size() == upper.size());
for (size_t i = 0; i < lower.size(); ++i) for (size_t i = 0; i < lower.size(); ++i)
...@@ -374,7 +383,14 @@ namespace dlib ...@@ -374,7 +383,14 @@ namespace dlib
is_integer_variable.assign(lower.size(), false); is_integer_variable.assign(lower.size(), false);
} }
function_spec::function_spec(const matrix<double,0,1>& lower, const matrix<double,0,1>& upper, std::vector<bool> is_integer) : function_spec(std::move(lower),std::move(upper)) // ----------------------------------------------------------------------------------------
function_spec::function_spec(
const matrix<double,0,1>& lower,
const matrix<double,0,1>& upper,
std::vector<bool> is_integer
) :
function_spec(std::move(lower),std::move(upper))
{ {
is_integer_variable = std::move(is_integer); is_integer_variable = std::move(is_integer);
DLIB_CASSERT(lower.size() == (long)is_integer_variable.size()); DLIB_CASSERT(lower.size() == (long)is_integer_variable.size());
...@@ -416,6 +432,8 @@ namespace dlib ...@@ -416,6 +432,8 @@ namespace dlib
return tmp; return tmp;
} }
// ------------------------------------------------------------------------------------
double funct_info::find_nn ( double funct_info::find_nn (
const std::vector<function_evaluation>& evals, const std::vector<function_evaluation>& evals,
const matrix<double,0,1>& x const matrix<double,0,1>& x
...@@ -435,11 +453,14 @@ namespace dlib ...@@ -435,11 +453,14 @@ namespace dlib
return best_y; return best_y;
} }
} } // end namespace gopt_impl
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
function_evaluation_request::function_evaluation_request(function_evaluation_request&& item) function_evaluation_request::function_evaluation_request(
function_evaluation_request&& item
)
{ {
m_has_been_evaluated = item.m_has_been_evaluated; m_has_been_evaluated = item.m_has_been_evaluated;
req = item.req; req = item.req;
...@@ -449,21 +470,31 @@ namespace dlib ...@@ -449,21 +470,31 @@ namespace dlib
item.m_has_been_evaluated = true; item.m_has_been_evaluated = true;
} }
// ----------------------------------------------------------------------------------------
function_evaluation_request& function_evaluation_request:: function_evaluation_request& function_evaluation_request::
operator=(function_evaluation_request&& item) operator=(
function_evaluation_request&& item
)
{ {
function_evaluation_request(std::move(item)).swap(*this); function_evaluation_request(std::move(item)).swap(*this);
return *this; return *this;
} }
// ----------------------------------------------------------------------------------------
void function_evaluation_request:: void function_evaluation_request::
swap(function_evaluation_request& item) swap(
function_evaluation_request& item
)
{ {
std::swap(m_has_been_evaluated, item.m_has_been_evaluated); std::swap(m_has_been_evaluated, item.m_has_been_evaluated);
std::swap(req, item.req); std::swap(req, item.req);
std::swap(info, item.info); std::swap(info, item.info);
} }
// ----------------------------------------------------------------------------------------
size_t function_evaluation_request:: size_t function_evaluation_request::
function_idx ( function_idx (
) const ) const
...@@ -478,6 +509,8 @@ namespace dlib ...@@ -478,6 +509,8 @@ namespace dlib
return req.x; return req.x;
} }
// ----------------------------------------------------------------------------------------
bool function_evaluation_request:: bool function_evaluation_request::
has_been_evaluated ( has_been_evaluated (
) const ) const
...@@ -485,6 +518,8 @@ namespace dlib ...@@ -485,6 +518,8 @@ namespace dlib
return m_has_been_evaluated; return m_has_been_evaluated;
} }
// ----------------------------------------------------------------------------------------
function_evaluation_request:: function_evaluation_request::
~function_evaluation_request() ~function_evaluation_request()
{ {
...@@ -498,16 +533,12 @@ namespace dlib ...@@ -498,16 +533,12 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
void function_evaluation_request:: void function_evaluation_request::
set ( set (
double y double y
) )
/*!
requires
- has_been_evaluated() == false
ensures
- #has_been_evaluated() == true
!*/
{ {
DLIB_CASSERT(has_been_evaluated() == false); DLIB_CASSERT(has_been_evaluated() == false);
std::lock_guard<std::mutex> lock(*info->m); std::lock_guard<std::mutex> lock(*info->m);
...@@ -559,6 +590,8 @@ namespace dlib ...@@ -559,6 +590,8 @@ namespace dlib
const function_spec& function const function_spec& function
) : global_function_search(std::vector<function_spec>(1,function)) {} ) : global_function_search(std::vector<function_spec>(1,function)) {}
// ----------------------------------------------------------------------------------------
global_function_search:: global_function_search::
global_function_search( global_function_search(
const std::vector<function_spec>& functions_ const std::vector<function_spec>& functions_
...@@ -570,12 +603,15 @@ namespace dlib ...@@ -570,12 +603,15 @@ namespace dlib
functions.emplace_back(std::make_shared<gopt_impl::funct_info>(functions_[i],i,m)); functions.emplace_back(std::make_shared<gopt_impl::funct_info>(functions_[i],i,m));
} }
// ----------------------------------------------------------------------------------------
global_function_search:: global_function_search::
global_function_search( global_function_search(
const std::vector<function_spec>& functions_, const std::vector<function_spec>& functions_,
const std::vector<std::vector<function_evaluation>>& initial_function_evals, const std::vector<std::vector<function_evaluation>>& initial_function_evals,
const double relative_noise_magnitude_ const double relative_noise_magnitude_
) : global_function_search(functions_) ) :
global_function_search(functions_)
{ {
DLIB_CASSERT(functions_.size() == initial_function_evals.size()); DLIB_CASSERT(functions_.size() == initial_function_evals.size());
DLIB_CASSERT(relative_noise_magnitude >= 0); DLIB_CASSERT(relative_noise_magnitude >= 0);
...@@ -586,13 +622,17 @@ namespace dlib ...@@ -586,13 +622,17 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
size_t global_function_search:: size_t global_function_search::
num_functions() const num_functions(
) const
{ {
return functions.size(); return functions.size();
} }
// ----------------------------------------------------------------------------------------
void global_function_search:: void global_function_search::
set_seed ( set_seed (
time_t seed time_t seed
...@@ -601,6 +641,8 @@ namespace dlib ...@@ -601,6 +641,8 @@ namespace dlib
rnd = dlib::rand(seed); rnd = dlib::rand(seed);
} }
// ----------------------------------------------------------------------------------------
void global_function_search:: void global_function_search::
get_function_evaluations ( get_function_evaluations (
std::vector<function_spec>& specs, std::vector<function_spec>& specs,
...@@ -617,6 +659,8 @@ namespace dlib ...@@ -617,6 +659,8 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
void global_function_search:: void global_function_search::
get_best_function_eval ( get_best_function_eval (
matrix<double,0,1>& x, matrix<double,0,1>& x,
...@@ -634,6 +678,8 @@ namespace dlib ...@@ -634,6 +678,8 @@ namespace dlib
x = info.best_x; x = info.best_x;
} }
// ----------------------------------------------------------------------------------------
function_evaluation_request global_function_search:: function_evaluation_request global_function_search::
get_next_x ( get_next_x (
) )
...@@ -734,9 +780,16 @@ namespace dlib ...@@ -734,9 +780,16 @@ namespace dlib
} }
// ----------------------------------------------------------------------------------------
double global_function_search:: double global_function_search::
get_pure_random_search_probability ( get_pure_random_search_probability (
) const { return pure_random_search_probability; } ) const
{
return pure_random_search_probability;
}
// ----------------------------------------------------------------------------------------
void global_function_search:: void global_function_search::
set_pure_random_search_probability ( set_pure_random_search_probability (
...@@ -747,9 +800,16 @@ namespace dlib ...@@ -747,9 +800,16 @@ namespace dlib
pure_random_search_probability = prob; pure_random_search_probability = prob;
} }
// ----------------------------------------------------------------------------------------
double global_function_search:: double global_function_search::
get_solver_epsilon ( get_solver_epsilon (
) const { return min_trust_region_epsilon; } ) const
{
return min_trust_region_epsilon;
}
// ----------------------------------------------------------------------------------------
void global_function_search:: void global_function_search::
set_solver_epsilon ( set_solver_epsilon (
...@@ -760,6 +820,8 @@ namespace dlib ...@@ -760,6 +820,8 @@ namespace dlib
min_trust_region_epsilon = eps; min_trust_region_epsilon = eps;
} }
// ----------------------------------------------------------------------------------------
double global_function_search:: double global_function_search::
get_relative_noise_magnitude ( get_relative_noise_magnitude (
) const ) const
...@@ -767,6 +829,8 @@ namespace dlib ...@@ -767,6 +829,8 @@ namespace dlib
return relative_noise_magnitude; return relative_noise_magnitude;
} }
// ----------------------------------------------------------------------------------------
void global_function_search:: void global_function_search::
set_relative_noise_magnitude ( set_relative_noise_magnitude (
double value double value
...@@ -779,6 +843,8 @@ namespace dlib ...@@ -779,6 +843,8 @@ namespace dlib
f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude); f->ub = upper_bound_function(f->ub.get_points(), relative_noise_magnitude);
} }
// ----------------------------------------------------------------------------------------
size_t global_function_search:: size_t global_function_search::
get_monte_carlo_upper_bound_sample_num ( get_monte_carlo_upper_bound_sample_num (
) const ) const
...@@ -786,6 +852,8 @@ namespace dlib ...@@ -786,6 +852,8 @@ namespace dlib
return num_random_samples; return num_random_samples;
} }
// ----------------------------------------------------------------------------------------
void global_function_search:: void global_function_search::
set_monte_carlo_upper_bound_sample_num ( set_monte_carlo_upper_bound_sample_num (
size_t num size_t num
...@@ -795,16 +863,22 @@ namespace dlib ...@@ -795,16 +863,22 @@ namespace dlib
num_random_samples = num; num_random_samples = num;
} }
// ----------------------------------------------------------------------------------------
std::shared_ptr<gopt_impl::funct_info> global_function_search:: std::shared_ptr<gopt_impl::funct_info> global_function_search::
best_function() const best_function(
) const
{ {
size_t idx = 0; size_t idx = 0;
return best_function(idx); return best_function(idx);
} }
// ----------------------------------------------------------------------------------------
std::shared_ptr<gopt_impl::funct_info> global_function_search:: std::shared_ptr<gopt_impl::funct_info> global_function_search::
best_function(size_t& idx) const best_function(
size_t& idx
) const
{ {
auto i = std::max_element(functions.begin(), functions.end(), auto i = std::max_element(functions.begin(), functions.end(),
[](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b) { return a->best_objective_value < b->best_objective_value; }); [](const std::shared_ptr<gopt_impl::funct_info>& a, const std::shared_ptr<gopt_impl::funct_info>& b) { return a->best_objective_value < b->best_objective_value; });
...@@ -813,6 +887,8 @@ namespace dlib ...@@ -813,6 +887,8 @@ namespace dlib
return *i; return *i;
} }
// ----------------------------------------------------------------------------------------
bool global_function_search:: bool global_function_search::
has_incomplete_trust_region_request ( has_incomplete_trust_region_request (
) const ) const
......
...@@ -16,9 +16,16 @@ namespace dlib ...@@ -16,9 +16,16 @@ namespace dlib
struct function_spec struct function_spec
{ {
function_spec(const matrix<double,0,1>& lower_, const matrix<double,0,1>& upper_); function_spec(
const matrix<double,0,1>& lower_,
const matrix<double,0,1>& upper_
);
function_spec(const matrix<double,0,1>& lower, const matrix<double,0,1>& upper, std::vector<bool> is_integer); function_spec(
const matrix<double,0,1>& lower,
const matrix<double,0,1>& upper,
std::vector<bool> is_integer
);
matrix<double,0,1> lower; matrix<double,0,1> lower;
matrix<double,0,1> upper; matrix<double,0,1> upper;
...@@ -48,7 +55,12 @@ namespace dlib ...@@ -48,7 +55,12 @@ namespace dlib
funct_info(const funct_info&) = delete; funct_info(const funct_info&) = delete;
funct_info& operator=(const funct_info&) = delete; funct_info& operator=(const funct_info&) = delete;
funct_info(const function_spec& spec, size_t function_idx, const std::shared_ptr<std::mutex>& m) : spec(spec), function_idx(function_idx), m(m) funct_info(
const function_spec& spec,
size_t function_idx,
const std::shared_ptr<std::mutex>& m
) :
spec(spec), function_idx(function_idx), m(m)
{ {
best_x = zeros_matrix(spec.lower); best_x = zeros_matrix(spec.lower);
} }
...@@ -81,12 +93,11 @@ namespace dlib ...@@ -81,12 +93,11 @@ namespace dlib
public: public:
function_evaluation_request() = delete; function_evaluation_request() = delete;
function_evaluation_request(const function_evaluation_request&) = delete; function_evaluation_request(const function_evaluation_request&) = delete;
function_evaluation_request& operator=(const function_evaluation_request&) = delete; function_evaluation_request& operator=(const function_evaluation_request&) = delete;
function_evaluation_request(function_evaluation_request&& item);
function_evaluation_request(function_evaluation_request&& item);
function_evaluation_request& operator=(function_evaluation_request&& item); function_evaluation_request& operator=(function_evaluation_request&& item);
void swap(function_evaluation_request& item); void swap(function_evaluation_request& item);
...@@ -151,7 +162,8 @@ namespace dlib ...@@ -151,7 +162,8 @@ namespace dlib
global_function_search(const global_function_search&) = delete; global_function_search(const global_function_search&) = delete;
global_function_search& operator=(const global_function_search& item) = delete; global_function_search& operator=(const global_function_search& item) = delete;
size_t num_functions() const; size_t num_functions(
) const;
void set_seed ( void set_seed (
time_t seed time_t seed
...@@ -201,9 +213,12 @@ namespace dlib ...@@ -201,9 +213,12 @@ namespace dlib
private: private:
std::shared_ptr<gopt_impl::funct_info> best_function() const; std::shared_ptr<gopt_impl::funct_info> best_function(
) const;
std::shared_ptr<gopt_impl::funct_info> best_function(size_t& idx) const; std::shared_ptr<gopt_impl::funct_info> best_function(
size_t& idx
) const;
bool has_incomplete_trust_region_request ( bool has_incomplete_trust_region_request (
) const; ) const;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment