Commit 1372472d authored by Davis King's avatar Davis King

Added an optional hard limit on the number of iterations in the

structural SVM solver.
parent 65d838e6
...@@ -345,6 +345,7 @@ namespace dlib ...@@ -345,6 +345,7 @@ namespace dlib
svm_struct_controller_node ( svm_struct_controller_node (
) : ) :
eps(0.001), eps(0.001),
max_iterations(10000),
cache_based_eps(std::numeric_limits<double>::infinity()), cache_based_eps(std::numeric_limits<double>::infinity()),
verbose(false), verbose(false),
C(1) C(1)
...@@ -389,6 +390,16 @@ namespace dlib ...@@ -389,6 +390,16 @@ namespace dlib
double get_epsilon ( double get_epsilon (
) const { return eps; } ) const { return eps; }
unsigned long get_max_iterations (
) const { return max_iterations; }
void set_max_iterations (
unsigned long max_iter
)
{
max_iterations = max_iter;
}
void be_verbose ( void be_verbose (
) )
{ {
...@@ -515,6 +526,7 @@ namespace dlib ...@@ -515,6 +526,7 @@ namespace dlib
problem_type<matrix_type> problem(nodes); problem_type<matrix_type> problem(nodes);
problem.set_cache_based_epsilon(cache_based_eps); problem.set_cache_based_epsilon(cache_based_eps);
problem.set_epsilon(eps); problem.set_epsilon(eps);
problem.set_max_iterations(max_iterations);
if (verbose) if (verbose)
problem.be_verbose(); problem.be_verbose();
problem.set_c(C); problem.set_c(C);
...@@ -674,6 +686,7 @@ namespace dlib ...@@ -674,6 +686,7 @@ namespace dlib
std::vector<network_address> nodes; std::vector<network_address> nodes;
double eps; double eps;
unsigned long max_iterations;
double cache_based_eps; double cache_based_eps;
bool verbose; bool verbose;
double C; double C;
......
...@@ -64,6 +64,7 @@ namespace dlib ...@@ -64,6 +64,7 @@ namespace dlib
INITIAL VALUE INITIAL VALUE
- get_num_processing_nodes() == 0 - get_num_processing_nodes() == 0
- get_epsilon() == 0.001 - get_epsilon() == 0.001
- get_max_iterations() == 10000
- get_c() == 1 - get_c() == 1
- This object will not be verbose - This object will not be verbose
...@@ -182,6 +183,22 @@ namespace dlib ...@@ -182,6 +183,22 @@ namespace dlib
- #get_cache_based_epsilon() == eps - #get_cache_based_epsilon() == eps
!*/ !*/
void set_max_iterations (
unsigned long max_iter
);
/*!
ensures
- #get_max_iterations() == max_iter
!*/
unsigned long get_max_iterations (
);
/*!
ensures
- returns the maximum number of iterations the SVM optimizer is allowed to
run before it is required to stop and return a result.
!*/
void add_nuclear_norm_regularizer ( void add_nuclear_norm_regularizer (
long first_dimension, long first_dimension,
long rows, long rows,
......
...@@ -239,6 +239,7 @@ namespace dlib ...@@ -239,6 +239,7 @@ namespace dlib
CONVENTION CONVENTION
- C == get_c() - C == get_c()
- eps == get_epsilon() - eps == get_epsilon()
- max_iterations == get_max_iterations()
- if (skip_cache) then - if (skip_cache) then
- we won't use the oracle cache when we need to evaluate the separation - we won't use the oracle cache when we need to evaluate the separation
oracle. Instead, we will directly call the user supplied separation_oracle(). oracle. Instead, we will directly call the user supplied separation_oracle().
...@@ -259,6 +260,7 @@ namespace dlib ...@@ -259,6 +260,7 @@ namespace dlib
) : ) :
saved_current_risk_gap(0), saved_current_risk_gap(0),
eps(0.001), eps(0.001),
max_iterations(10000),
verbose(false), verbose(false),
skip_cache(true), skip_cache(true),
count_below_eps(0), count_below_eps(0),
...@@ -308,6 +310,16 @@ namespace dlib ...@@ -308,6 +310,16 @@ namespace dlib
const scalar_type get_epsilon ( const scalar_type get_epsilon (
) const { return eps; } ) const { return eps; }
unsigned long get_max_iterations (
) const { return max_iterations; }
void set_max_iterations (
unsigned long max_iter
)
{
max_iterations = max_iter;
}
void set_max_cache_size ( void set_max_cache_size (
unsigned long max_size unsigned long max_size
) )
...@@ -445,6 +457,9 @@ namespace dlib ...@@ -445,6 +457,9 @@ namespace dlib
cout << endl; cout << endl;
} }
if (num_iterations >= max_iterations)
return true;
saved_current_risk_gap = current_risk_gap; saved_current_risk_gap = current_risk_gap;
if (converged) if (converged)
...@@ -611,6 +626,7 @@ namespace dlib ...@@ -611,6 +626,7 @@ namespace dlib
mutable scalar_type saved_current_risk_gap; mutable scalar_type saved_current_risk_gap;
mutable matrix_type psi_true; mutable matrix_type psi_true;
scalar_type eps; scalar_type eps;
unsigned long max_iterations;
mutable bool verbose; mutable bool verbose;
......
...@@ -29,6 +29,7 @@ namespace dlib ...@@ -29,6 +29,7 @@ namespace dlib
INITIAL VALUE INITIAL VALUE
- get_epsilon() == 0.001 - get_epsilon() == 0.001
- get_max_iterations() == 10000
- get_max_cache_size() == 5 - get_max_cache_size() == 5
- get_c() == 1 - get_c() == 1
- get_cache_based_epsilon() == std::numeric_limits<scalar_type>::infinity() - get_cache_based_epsilon() == std::numeric_limits<scalar_type>::infinity()
...@@ -161,6 +162,22 @@ namespace dlib ...@@ -161,6 +162,22 @@ namespace dlib
- #get_cache_based_epsilon() == eps - #get_cache_based_epsilon() == eps
!*/ !*/
void set_max_iterations (
unsigned long max_iter
);
/*!
ensures
- #get_max_iterations() == max_iter
!*/
unsigned long get_max_iterations (
);
/*!
ensures
- returns the maximum number of iterations the SVM optimizer is allowed to
run before it is required to stop and return a result.
!*/
void set_max_cache_size ( void set_max_cache_size (
unsigned long max_size unsigned long max_size
); );
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment