Commit ea460afd authored by Davis King's avatar Davis King

Fixed the distributed version of the structural svm solver to work with the

recent changes to the core solver.  Also added support for the nuclear norm
regularization and cache refinement options.
parent 43fdc03f
...@@ -65,32 +65,32 @@ namespace dlib ...@@ -65,32 +65,32 @@ namespace dlib
typedef typename matrix_type::type scalar_type; typedef typename matrix_type::type scalar_type;
matrix_type current_solution; matrix_type current_solution;
scalar_type cur_risk_lower_bound; scalar_type saved_current_risk_gap;
double eps;
bool skip_cache; bool skip_cache;
bool converged;
friend void swap (oracle_request& a, oracle_request& b) friend void swap (oracle_request& a, oracle_request& b)
{ {
a.current_solution.swap(b.current_solution); a.current_solution.swap(b.current_solution);
std::swap(a.cur_risk_lower_bound, b.cur_risk_lower_bound); std::swap(a.saved_current_risk_gap, b.saved_current_risk_gap);
std::swap(a.eps, b.eps);
std::swap(a.skip_cache, b.skip_cache); std::swap(a.skip_cache, b.skip_cache);
std::swap(a.converged, b.converged);
} }
friend void serialize (const oracle_request& item, std::ostream& out) friend void serialize (const oracle_request& item, std::ostream& out)
{ {
serialize(item.current_solution, out); serialize(item.current_solution, out);
dlib::serialize(item.cur_risk_lower_bound, out); dlib::serialize(item.saved_current_risk_gap, out);
dlib::serialize(item.eps, out);
dlib::serialize(item.skip_cache, out); dlib::serialize(item.skip_cache, out);
dlib::serialize(item.converged, out);
} }
friend void deserialize (oracle_request& item, std::istream& in) friend void deserialize (oracle_request& item, std::istream& in)
{ {
deserialize(item.current_solution, in); deserialize(item.current_solution, in);
dlib::deserialize(item.cur_risk_lower_bound, in); dlib::deserialize(item.saved_current_risk_gap, in);
dlib::deserialize(item.eps, in);
dlib::deserialize(item.skip_cache, in); dlib::deserialize(item.skip_cache, in);
dlib::deserialize(item.converged, in);
} }
}; };
...@@ -270,8 +270,9 @@ namespace dlib ...@@ -270,8 +270,9 @@ namespace dlib
feature_vector_type ftemp; feature_vector_type ftemp;
for (long i = begin; i < end; ++i) for (long i = begin; i < end; ++i)
{ {
self.cache[i].separation_oracle_cached(req.skip_cache, self.cache[i].separation_oracle_cached(req.converged,
req.cur_risk_lower_bound, req.skip_cache,
req.saved_current_risk_gap,
req.current_solution, req.current_solution,
loss, loss,
ftemp); ftemp);
...@@ -292,8 +293,9 @@ namespace dlib ...@@ -292,8 +293,9 @@ namespace dlib
for (long i = begin; i < end; ++i) for (long i = begin; i < end; ++i)
{ {
scalar_type loss_temp; scalar_type loss_temp;
self.cache[i].separation_oracle_cached(req.skip_cache, self.cache[i].separation_oracle_cached(req.converged,
req.cur_risk_lower_bound, req.skip_cache,
req.saved_current_risk_gap,
req.current_solution, req.current_solution,
loss_temp, loss_temp,
ftemp); ftemp);
...@@ -343,17 +345,39 @@ namespace dlib ...@@ -343,17 +345,39 @@ namespace dlib
svm_struct_controller_node ( svm_struct_controller_node (
) : ) :
eps(0.001), eps(0.001),
cache_based_eps(std::numeric_limits<double>::infinity()),
verbose(false), verbose(false),
C(1) C(1)
{} {}
double get_cache_based_epsilon (
) const
{
return cache_based_eps;
}
void set_cache_based_epsilon (
double eps_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(eps_ > 0,
"\t void svm_struct_controller_node::set_cache_based_epsilon()"
<< "\n\t eps_ must be greater than 0"
<< "\n\t eps_: " << eps_
<< "\n\t this: " << this
);
cache_based_eps = eps_;
}
void set_epsilon ( void set_epsilon (
double eps_ double eps_
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(eps_ > 0, DLIB_ASSERT(eps_ > 0,
"\t void structural_svm_problem::set_epsilon()" "\t void svm_struct_controller_node::set_epsilon()"
<< "\n\t eps_ must be greater than 0" << "\n\t eps_ must be greater than 0"
<< "\n\t eps_: " << eps_ << "\n\t eps_: " << eps_
<< "\n\t this: " << this << "\n\t this: " << this
...@@ -377,6 +401,41 @@ namespace dlib ...@@ -377,6 +401,41 @@ namespace dlib
verbose = false; verbose = false;
} }
void add_nuclear_norm_regularizer (
long first_dimension,
long rows,
long cols,
double regularization_strength
)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 <= first_dimension &&
0 <= rows && 0 <= cols &&
0 < regularization_strength,
"\t void svm_struct_controller_node::add_nuclear_norm_regularizer()"
<< "\n\t Invalid arguments were given to this function."
<< "\n\t first_dimension: " << first_dimension
<< "\n\t rows: " << rows
<< "\n\t cols: " << cols
<< "\n\t regularization_strength: " << regularization_strength
<< "\n\t this: " << this
);
impl::nuclear_norm_regularizer temp;
temp.first_dimension = first_dimension;
temp.nr = rows;
temp.nc = cols;
temp.regularization_strength = regularization_strength;
nuclear_norm_regularizers.push_back(temp);
}
unsigned long num_nuclear_norm_regularizers (
) const { return nuclear_norm_regularizers.size(); }
void clear_nuclear_norm_regularizers (
) { nuclear_norm_regularizers.clear(); }
double get_c ( double get_c (
) const { return C; } ) const { return C; }
...@@ -386,7 +445,7 @@ namespace dlib ...@@ -386,7 +445,7 @@ namespace dlib
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(C_ > 0, DLIB_ASSERT(C_ > 0,
"\t void structural_svm_problem::set_c()" "\t void svm_struct_controller_node::set_c()"
<< "\n\t C_ must be greater than 0" << "\n\t C_ must be greater than 0"
<< "\n\t C_: " << C_ << "\n\t C_: " << C_
<< "\n\t this: " << this << "\n\t this: " << this
...@@ -401,7 +460,7 @@ namespace dlib ...@@ -401,7 +460,7 @@ namespace dlib
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(addr.port != 0, DLIB_ASSERT(addr.port != 0,
"\t void structural_svm_problem::add_processing_node()" "\t void svm_struct_controller_node::add_processing_node()"
<< "\n\t Invalid inputs were given to this function" << "\n\t Invalid inputs were given to this function"
<< "\n\t addr.host_address: " << addr.host_address << "\n\t addr.host_address: " << addr.host_address
<< "\n\t addr.port: " << addr.port << "\n\t addr.port: " << addr.port
...@@ -453,7 +512,20 @@ namespace dlib ...@@ -453,7 +512,20 @@ namespace dlib
<< "\n\t this: " << this << "\n\t this: " << this
); );
problem_type<matrix_type> problem(nodes,eps,verbose,C); problem_type<matrix_type> problem(nodes);
problem.set_cache_based_epsilon(cache_based_eps);
problem.set_epsilon(eps);
if (verbose)
problem.be_verbose();
problem.set_c(C);
for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i)
{
problem.add_nuclear_norm_regularizer(
nuclear_norm_regularizers[i].first_dimension,
nuclear_norm_regularizers[i].nr,
nuclear_norm_regularizers[i].nc,
nuclear_norm_regularizers[i].regularization_strength);
}
return solver(problem, w); return solver(problem, w);
} }
...@@ -470,25 +542,17 @@ namespace dlib ...@@ -470,25 +542,17 @@ namespace dlib
private: private:
template <typename matrix_type_> template <typename matrix_type_>
class problem_type : public oca_problem<matrix_type_> class problem_type : public structural_svm_problem<matrix_type_>
{ {
public: public:
typedef typename matrix_type_::type scalar_type; typedef typename matrix_type_::type scalar_type;
typedef matrix_type_ matrix_type; typedef matrix_type_ matrix_type;
problem_type ( problem_type (
const std::vector<network_address>& nodes_, const std::vector<network_address>& nodes_
double eps_,
bool verbose_,
double C_
) : ) :
nodes(nodes_), nodes(nodes_),
eps(eps_),
verbose(verbose_),
C(C_),
in(3), in(3),
cur_risk_lower_bound(0),
skip_cache(true),
num_dims(0) num_dims(0)
{ {
...@@ -529,69 +593,14 @@ namespace dlib ...@@ -529,69 +593,14 @@ namespace dlib
num_dims = temp.template get<long>(); num_dims = temp.template get<long>();
} }
} }
} }
// These functions are just here because the structural_svm_problem requires
// them, but since we are overloading get_risk() they are never called so they
virtual bool risk_has_lower_bound ( // don't matter.
scalar_type& lower_bound virtual long get_num_samples () const {return 0;}
) const virtual void get_truth_joint_feature_vector ( long , matrix_type& ) const {}
{ virtual void separation_oracle ( const long , const matrix_type& , scalar_type& , matrix_type& ) const {}
lower_bound = 0;
return true;
}
virtual bool optimization_status (
scalar_type current_objective_value,
scalar_type current_error_gap,
scalar_type current_risk_value,
scalar_type current_risk_gap,
unsigned long num_cutting_planes,
unsigned long num_iterations
) const
{
if (verbose)
{
using namespace std;
cout << "objective: " << current_objective_value << endl;
cout << "objective gap: " << current_error_gap << endl;
cout << "risk: " << current_risk_value << endl;
cout << "risk gap: " << current_risk_gap << endl;
cout << "num planes: " << num_cutting_planes << endl;
cout << "iter: " << num_iterations << endl;
cout << endl;
}
cur_risk_lower_bound = std::max<scalar_type>(current_risk_value - current_risk_gap, 0);
bool should_stop = false;
if (current_risk_gap < eps)
should_stop = true;
if (should_stop && !skip_cache)
{
// Instead of stopping we shouldn't use the cache on the next iteration. This way
// we can be sure to have the best solution rather than assuming the cache is up-to-date
// enough.
should_stop = false;
skip_cache = true;
}
else
{
skip_cache = false;
}
return should_stop;
}
virtual scalar_type get_c (
) const
{
return C;
}
virtual long get_num_dimensions ( virtual long get_num_dimensions (
) const ) const
...@@ -614,9 +623,9 @@ namespace dlib ...@@ -614,9 +623,9 @@ namespace dlib
for (unsigned long i = 0; i < out_pipes.size(); ++i) for (unsigned long i = 0; i < out_pipes.size(); ++i)
{ {
temp_out.template get<oracle_request<matrix_type> >().current_solution = w; temp_out.template get<oracle_request<matrix_type> >().current_solution = w;
temp_out.template get<oracle_request<matrix_type> >().eps = eps; temp_out.template get<oracle_request<matrix_type> >().saved_current_risk_gap = this->saved_current_risk_gap;
temp_out.template get<oracle_request<matrix_type> >().cur_risk_lower_bound = cur_risk_lower_bound; temp_out.template get<oracle_request<matrix_type> >().skip_cache = this->skip_cache;
temp_out.template get<oracle_request<matrix_type> >().skip_cache = skip_cache; temp_out.template get<oracle_request<matrix_type> >().converged = this->converged;
out_pipes[i]->enqueue(temp_out); out_pipes[i]->enqueue(temp_out);
} }
...@@ -641,12 +650,18 @@ namespace dlib ...@@ -641,12 +650,18 @@ namespace dlib
subgradient /= num; subgradient /= num;
total_loss /= num; total_loss /= num;
risk = total_loss + dot(subgradient,w); risk = total_loss + dot(subgradient,w);
if (this->nuclear_norm_regularizers.size() != 0)
{
matrix_type grad;
double obj;
this->compute_nuclear_norm_parts(w, grad, obj);
risk += obj;
subgradient += grad;
}
} }
std::vector<network_address> nodes; std::vector<network_address> nodes;
double eps;
mutable bool verbose;
double C;
typedef type_safe_union<impl::oracle_request<matrix_type> > tsu_out; typedef type_safe_union<impl::oracle_request<matrix_type> > tsu_out;
typedef type_safe_union<impl::oracle_response<matrix_type>, long> tsu_in; typedef type_safe_union<impl::oracle_response<matrix_type>, long> tsu_in;
...@@ -654,16 +669,15 @@ namespace dlib ...@@ -654,16 +669,15 @@ namespace dlib
std::vector<shared_ptr<pipe<tsu_out> > > out_pipes; std::vector<shared_ptr<pipe<tsu_out> > > out_pipes;
mutable pipe<tsu_in> in; mutable pipe<tsu_in> in;
std::vector<shared_ptr<bridge> > bridges; std::vector<shared_ptr<bridge> > bridges;
mutable scalar_type cur_risk_lower_bound;
mutable bool skip_cache;
long num_dims; long num_dims;
}; };
std::vector<network_address> nodes; std::vector<network_address> nodes;
double eps; double eps;
mutable bool verbose; double cache_based_eps;
bool verbose;
double C; double C;
std::vector<impl::nuclear_norm_regularizer> nuclear_norm_regularizers;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -47,6 +47,8 @@ namespace dlib ...@@ -47,6 +47,8 @@ namespace dlib
- Note that the following parameters within the given problem are ignored: - Note that the following parameters within the given problem are ignored:
- problem.get_c() - problem.get_c()
- problem.get_epsilon() - problem.get_epsilon()
- problem.get_cache_based_epsilon()
- problem.num_nuclear_norm_regularizers()
- weather the problem is verbose or not - weather the problem is verbose or not
Instead, they are defined by the svm_struct_controller_node. Note, however, Instead, they are defined by the svm_struct_controller_node. Note, however,
that the problem.get_max_cache_size() parameter is meaningful and controls that the problem.get_max_cache_size() parameter is meaningful and controls
...@@ -145,6 +147,87 @@ namespace dlib ...@@ -145,6 +147,87 @@ namespace dlib
optimal value". optimal value".
!*/ !*/
double get_cache_based_epsilon (
) const;
/*!
ensures
- if (get_max_cache_size() != 0) then
- The solver will not stop when the average sample risk is within
get_epsilon() of its optimal value. Instead, it will keep running
but will run the optimizer completely on the cache until the average
sample risk is within #get_cache_based_epsilon() of its optimal
value. This means that it will perform this additional refinement in
the solution accuracy without making any additional calls to the
separation_oracle(). This is useful when using a nuclear norm
regularization term because it allows you to quickly solve the
optimization problem to a high precision, which in the case of a
nuclear norm regularized problem means that many of the learned
matrices will be low rank or very close to low rank due to the
nuclear norm regularizer. This may not happen without solving the
problem to a high accuracy or their ranks may be difficult to
determine, so the extra accuracy given by the cache based refinement
is very useful. Finally, note that we include the nuclear norm term
as part of the "risk" for the purposes of determining when to stop.
- else
- The value of #get_cache_based_epsilon() has no effect.
!*/
void set_cache_based_epsilon (
double eps
);
/*!
requires
- eps > 0
ensures
- #get_cache_based_epsilon() == eps
!*/
void add_nuclear_norm_regularizer (
long first_dimension,
long rows,
long cols,
double regularization_strength
);
/*!
requires
- 0 <= first_dimension < number of dimensions in problem
- 0 <= rows
- 0 <= cols
- first_dimension+rows*cols <= number of dimensions in problem
- 0 < regularization_strength
ensures
- Adds a nuclear norm regularization term to the optimization problem
solved by this object. That is, instead of solving:
Minimize: h(w) == 0.5*dot(w,w) + C*R(w)
this object will solve:
Minimize: h(w) == 0.5*dot(w,w) + C*R(w) + regularization_strength*nuclear_norm_of(part of w)
where "part of w" is the part of w indicated by the arguments to this
function. In particular, the part of w included in the nuclear norm is
exactly the matrix reshape(rowm(w, range(first_dimension, first_dimension+rows*cols-1)), rows, cols).
Therefore, if you think of the w vector as being the concatenation of a
bunch of matrices then you can use multiple calls to add_nuclear_norm_regularizer()
to add nuclear norm regularization terms to any of the matrices packed into w.
- #num_nuclear_norm_regularizers() == num_nuclear_norm_regularizers() + 1
!*/
unsigned long num_nuclear_norm_regularizers (
) const;
/*!
ensures
- returns the number of nuclear norm regularizers that are currently a part
of this optimization problem. That is, returns the number of times
add_nuclear_norm_regularizer() has been called since the last call to
clear_nuclear_norm_regularizers() or object construction, whichever is
most recent.
!*/
void clear_nuclear_norm_regularizers (
);
/*!
ensures
- #num_nuclear_norm_regularizers() == 0
!*/
void be_verbose ( void be_verbose (
); );
/*! /*!
......
...@@ -14,6 +14,19 @@ ...@@ -14,6 +14,19 @@
namespace dlib namespace dlib
{ {
// ----------------------------------------------------------------------------------------
namespace impl
{
struct nuclear_norm_regularizer
{
long first_dimension;
long nr;
long nc;
double regularization_strength;
};
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -345,7 +358,7 @@ namespace dlib ...@@ -345,7 +358,7 @@ namespace dlib
<< "\n\t this: " << this << "\n\t this: " << this
); );
nuclear_norm_regularizer temp; impl::nuclear_norm_regularizer temp;
temp.first_dimension = first_dimension; temp.first_dimension = first_dimension;
temp.nr = rows; temp.nr = rows;
temp.nc = cols; temp.nc = cols;
...@@ -464,45 +477,6 @@ namespace dlib ...@@ -464,45 +477,6 @@ namespace dlib
return false; return false;
} }
void compute_nuclear_norm_parts(
const matrix_type& m,
matrix_type& grad,
scalar_type& obj
) const
{
obj = 0;
grad.set_size(m.size());
grad = 0;
matrix<double> u,v,w,f;
nuclear_norm_part = 0;
for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i)
{
const long nr = nuclear_norm_regularizers[i].nr;
const long nc = nuclear_norm_regularizers[i].nc;
const long size = nr*nc;
const long idx = nuclear_norm_regularizers[i].first_dimension;
const double strength = nuclear_norm_regularizers[i].regularization_strength;
f = matrix_cast<double>(reshape(rowm(m, range(idx, idx+size-1)), nr, nc));
svd3(f, u,w,v);
w = round_zeros(w, std::max(1e-9,max(w)*1e-7));
const double norm = sum(w);
obj += strength*norm;
nuclear_norm_part += strength*norm/C;
w = w>0;
f = u*diagm(w)*trans(v);
set_rowm(grad, range(idx, idx+size-1)) = matrix_cast<double>(strength*reshape_to_column_vector(f));
}
obj /= C;
grad /= C;
}
virtual void get_risk ( virtual void get_risk (
matrix_type& w, matrix_type& w,
scalar_type& risk, scalar_type& risk,
...@@ -566,6 +540,46 @@ namespace dlib ...@@ -566,6 +540,46 @@ namespace dlib
} }
protected: protected:
void compute_nuclear_norm_parts(
const matrix_type& m,
matrix_type& grad,
scalar_type& obj
) const
{
obj = 0;
grad.set_size(m.size());
grad = 0;
matrix<double> u,v,w,f;
nuclear_norm_part = 0;
for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i)
{
const long nr = nuclear_norm_regularizers[i].nr;
const long nc = nuclear_norm_regularizers[i].nc;
const long size = nr*nc;
const long idx = nuclear_norm_regularizers[i].first_dimension;
const double strength = nuclear_norm_regularizers[i].regularization_strength;
f = matrix_cast<double>(reshape(rowm(m, range(idx, idx+size-1)), nr, nc));
svd3(f, u,w,v);
w = round_zeros(w, std::max(1e-9,max(w)*1e-7));
const double norm = sum(w);
obj += strength*norm;
nuclear_norm_part += strength*norm/C;
w = w>0;
f = u*diagm(w)*trans(v);
set_rowm(grad, range(idx, idx+size-1)) = matrix_cast<double>(strength*reshape_to_column_vector(f));
}
obj /= C;
grad /= C;
}
void separation_oracle_cached ( void separation_oracle_cached (
const long idx, const long idx,
const matrix_type& current_solution, const matrix_type& current_solution,
...@@ -580,16 +594,8 @@ namespace dlib ...@@ -580,16 +594,8 @@ namespace dlib
loss, loss,
psi); psi);
} }
private:
struct nuclear_norm_regularizer std::vector<impl::nuclear_norm_regularizer> nuclear_norm_regularizers;
{
long first_dimension;
long nr;
long nc;
double regularization_strength;
};
std::vector<nuclear_norm_regularizer> nuclear_norm_regularizers;
mutable scalar_type saved_current_risk_gap; mutable scalar_type saved_current_risk_gap;
mutable matrix_type psi_true; mutable matrix_type psi_true;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment