Commit 83217d76 authored by Davis King's avatar Davis King

Added the option to learn non-negative weights to the svm_multiclass_linear_trainer.

parent 8e6b5a40
...@@ -177,7 +177,8 @@ namespace dlib ...@@ -177,7 +177,8 @@ namespace dlib
num_threads(4), num_threads(4),
C(1), C(1),
eps(0.001), eps(0.001),
verbose(false) verbose(false),
learn_nonnegative_weights(false)
{ {
} }
...@@ -243,6 +244,16 @@ namespace dlib ...@@ -243,6 +244,16 @@ namespace dlib
return kernel_type(); return kernel_type();
} }
bool learns_nonnegative_weights (
) const { return learn_nonnegative_weights; }
void set_learns_nonnegative_weights (
bool value
)
{
learn_nonnegative_weights = value;
}
void set_c ( void set_c (
scalar_type C_ scalar_type C_
) )
...@@ -297,7 +308,13 @@ namespace dlib ...@@ -297,7 +308,13 @@ namespace dlib
problem.set_c(C); problem.set_c(C);
problem.set_epsilon(eps); problem.set_epsilon(eps);
svm_objective = solver(problem, weights); unsigned long num_nonnegative = 0;
if (learn_nonnegative_weights)
{
num_nonnegative = problem.get_num_dimensions();
}
svm_objective = solver(problem, weights, num_nonnegative);
trained_function_type df; trained_function_type df;
...@@ -315,6 +332,7 @@ namespace dlib ...@@ -315,6 +332,7 @@ namespace dlib
scalar_type eps; scalar_type eps;
bool verbose; bool verbose;
oca solver; oca solver;
bool learn_nonnegative_weights;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -32,6 +32,7 @@ namespace dlib ...@@ -32,6 +32,7 @@ namespace dlib
INITIAL VALUE INITIAL VALUE
- get_num_threads() == 4 - get_num_threads() == 4
- learns_nonnegative_weights() == false
- get_epsilon() == 0.001 - get_epsilon() == 0.001
- get_c() == 1 - get_c() == 1
- this object will not be verbose unless be_verbose() is called - this object will not be verbose unless be_verbose() is called
...@@ -155,6 +156,26 @@ namespace dlib ...@@ -155,6 +156,26 @@ namespace dlib
generalization. generalization.
!*/ !*/
bool learns_nonnegative_weights (
) const;
/*!
ensures
- The output of training is a set of weights and bias values that together
define the behavior of a multiclass_linear_decision_function object. If
learns_nonnegative_weights() == true then the resulting weights and bias
values will always have non-negative values. That is, if this function
returns true then all the numbers in the multiclass_linear_decision_function
objects output by train() will be non-negative.
!*/
void set_learns_nonnegative_weights (
bool value
);
/*!
ensures
- #learns_nonnegative_weights() == value
!*/
trained_function_type train ( trained_function_type train (
const std::vector<sample_type>& all_samples, const std::vector<sample_type>& all_samples,
const std::vector<label_type>& all_labels const std::vector<label_type>& all_labels
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment