Commit b377f752 authored by Davis King's avatar Davis King

Made it so you can set the number of output filters for con_ layers at runtime.

parent 8eb9e295
......@@ -21,6 +21,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
struct num_con_outputs
{
    // Tiny wrapper type: carries a runtime filter count into con_'s
    // constructor so the overload taking it is unambiguous and
    // self-documenting at call sites (e.g. con_(num_con_outputs(32))).
    num_con_outputs(unsigned long n) : num_outputs{n} {}
    unsigned long num_outputs;
};
template <
long _num_filters,
long _nr,
......@@ -43,16 +49,22 @@ namespace dlib
static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size.");
// Construct with a runtime-chosen number of output filters.  All other
// geometry (filter size, stride, padding) stays fixed by the template
// parameters.
//
// Fix: the scraped diff left the pre-change initializer-list tail
// ("padding_x_(_padding_x)" followed by an empty body "{}") interleaved
// with its replacement, producing a duplicate member initializer and a
// premature constructor body.  This is the post-change constructor only.
con_(
    num_con_outputs o
) :
    learning_rate_multiplier(1),
    weight_decay_multiplier(1),
    bias_learning_rate_multiplier(1),
    bias_weight_decay_multiplier(0),
    padding_y_(_padding_y),
    padding_x_(_padding_x),
    num_filters_(o.num_outputs)
{
    // A convolution with zero output filters is meaningless.
    DLIB_CASSERT(num_filters_ > 0);
}
long num_filters() const { return _num_filters; }
con_() : con_(num_con_outputs(_num_filters)) {}
long num_filters() const { return num_filters_; }
long nr() const { return _nr; }              // filter rows (compile-time constant)
long nc() const { return _nc; }              // filter columns (compile-time constant)
long stride_y() const { return _stride_y; }  // vertical stride (compile-time constant)
......@@ -60,6 +72,14 @@ namespace dlib
long padding_y() const { return padding_y_; }  // row padding; initialized from _padding_y
long padding_x() const { return padding_x_; }  // column padding; initialized from _padding_x
// Change the number of output filters at runtime.  Must be called before
// the layer's parameter tensor is allocated (setup() sizes that tensor
// from num_filters_), hence the emptiness check below.
void set_num_filters(long num)
{
    DLIB_CASSERT(num > 0);
    DLIB_CASSERT(get_layer_params().size() == 0,
        "You can't change the number of filters in con_ if the parameter tensor has already been allocated.");
    num_filters_ = num;
}
// Accessors/mutator for the per-layer multiplier fields (both default to 1
// in the constructor).
double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
double get_weight_decay_multiplier () const { return weight_decay_multiplier; }
void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
......@@ -130,15 +150,15 @@ namespace dlib
void setup (const SUBNET& sub)
{
long num_inputs = _nr*_nc*sub.get_output().k();
long num_outputs = _num_filters;
long num_outputs = num_filters_;
// allocate params for the filters and also for the filter bias values.
params.set_size(num_inputs*_num_filters + _num_filters);
params.set_size(num_inputs*num_filters_ + num_filters_);
dlib::rand rnd(std::rand());
randomize_parameters(params, num_inputs+num_outputs, rnd);
filters = alias_tensor(_num_filters, sub.get_output().k(), _nr, _nc);
biases = alias_tensor(1,_num_filters);
filters = alias_tensor(num_filters_, sub.get_output().k(), _nr, _nc);
biases = alias_tensor(1,num_filters_);
// set the initial bias values to zero
biases(params,filters.size()) = 0;
......@@ -182,7 +202,7 @@ namespace dlib
{
serialize("con_4", out);
serialize(item.params, out);
serialize(_num_filters, out);
serialize(item.num_filters_, out);
serialize(_nr, out);
serialize(_nc, out);
serialize(_stride_y, out);
......@@ -201,7 +221,6 @@ namespace dlib
{
std::string version;
deserialize(version, in);
long num_filters;
long nr;
long nc;
int stride_y;
......@@ -209,7 +228,7 @@ namespace dlib
if (version == "con_4")
{
deserialize(item.params, in);
deserialize(num_filters, in);
deserialize(item.num_filters_, in);
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
......@@ -224,14 +243,6 @@ namespace dlib
deserialize(item.bias_weight_decay_multiplier, in);
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
if (num_filters != _num_filters)
{
std::ostringstream sout;
sout << "Wrong num_filters found while deserializing dlib::con_" << std::endl;
sout << "expected " << _num_filters << " but found " << num_filters << std::endl;
throw serialization_error(sout.str());
}
if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
......@@ -247,7 +258,7 @@ namespace dlib
friend std::ostream& operator<<(std::ostream& out, const con_& item)
{
out << "con\t ("
<< "num_filters="<<_num_filters
<< "num_filters="<<item.num_filters_
<< ", nr="<<_nr
<< ", nc="<<_nc
<< ", stride_y="<<_stride_y
......@@ -265,7 +276,7 @@ namespace dlib
friend void to_xml(const con_& item, std::ostream& out)
{
out << "<con"
<< " num_filters='"<<_num_filters<<"'"
<< " num_filters='"<<item.num_filters_<<"'"
<< " nr='"<<_nr<<"'"
<< " nc='"<<_nc<<"'"
<< " stride_y='"<<_stride_y<<"'"
......@@ -290,6 +301,7 @@ namespace dlib
double weight_decay_multiplier;
double bias_learning_rate_multiplier;
double bias_weight_decay_multiplier;
long num_filters_;  // runtime filter count; set in the ctor or via set_num_filters()
// These are here only because older versions of con (which you might encounter
// serialized to disk) used different padding settings.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment