Commit 7352ef0e authored by Davis King

merged

parents 2425dcaf ea9aae0f
@@ -164,7 +164,6 @@ namespace dlib
             // set the initial bias values to zero
             biases(params,filters.size()) = 0;
         }
-
         template <typename SUBNET>
@@ -346,17 +345,22 @@ namespace dlib
         static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size.");

         cont_(
+            num_con_outputs o
         ) :
             learning_rate_multiplier(1),
             weight_decay_multiplier(1),
             bias_learning_rate_multiplier(1),
             bias_weight_decay_multiplier(0),
+            num_filters_(o.num_outputs),
             padding_y_(_padding_y),
             padding_x_(_padding_x)
         {
+            DLIB_CASSERT(num_filters_ > 0);
         }

-        long num_filters() const { return _num_filters; }
+        cont_() : cont_(num_con_outputs(_num_filters)) {}
+
+        long num_filters() const { return num_filters_; }
         long nr() const { return _nr; }
         long nc() const { return _nc; }
         long stride_y() const { return _stride_y; }
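Taken together, this hunk swaps the compile-time _num_filters template argument for a runtime num_filters_ member: the new num_con_outputs constructor chooses the filter count at run time, while the added default constructor delegates to it so existing code keeps the compile-time value. A minimal sketch of both paths (the template arguments below are illustrative, not taken from the patch):

    #include <dlib/dnn.h>
    using namespace dlib;

    int main()
    {
        // Default construction keeps the compile-time count: num_filters() == 16.
        cont_<16,4,4,2,2,1,1> a;

        // The new constructor overrides it at run time: num_filters() == 32.
        cont_<16,4,4,2,2,1,1> b(num_con_outputs(32));

        return (a.num_filters() == 16 && b.num_filters() == 32) ? 0 : 1;
    }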
@@ -364,6 +368,14 @@ namespace dlib
         long padding_y() const { return padding_y_; }
         long padding_x() const { return padding_x_; }

+        void set_num_filters(long num)
+        {
+            DLIB_CASSERT(num > 0);
+            DLIB_CASSERT(get_layer_params().size() == 0,
+                "You can't change the number of filters in cont_ if the parameter tensor has already been allocated.");
+            num_filters_ = num;
+        }
+
         double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
         double get_weight_decay_multiplier () const { return weight_decay_multiplier; }
         void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
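set_num_filters() handles the remaining case, changing the count after construction; the second DLIB_CASSERT stops it from running once setup() has allocated the parameter tensor. A hedged usage sketch:

    #include <dlib/dnn.h>
    using namespace dlib;

    int main()
    {
        cont_<16,4,4,2,2,1,1> upsample;

        // Fine here: get_layer_params().size() == 0 until setup() runs,
        // i.e. before training starts or a forward pass happens.
        upsample.set_num_filters(64);

        return upsample.num_filters() == 64 ? 0 : 1;
    }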
@@ -402,6 +414,7 @@ namespace dlib
             weight_decay_multiplier(item.weight_decay_multiplier),
             bias_learning_rate_multiplier(item.bias_learning_rate_multiplier),
             bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
+            num_filters_(item.num_filters_),
             padding_y_(item.padding_y_),
             padding_x_(item.padding_x_)
         {
@@ -427,6 +440,7 @@ namespace dlib
             weight_decay_multiplier = item.weight_decay_multiplier;
             bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
             bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
+            num_filters_ = item.num_filters_;
             return *this;
         }
@@ -434,18 +448,18 @@ namespace dlib
         void setup (const SUBNET& sub)
         {
             long num_inputs = _nr*_nc*sub.get_output().k();
-            long num_outputs = _num_filters;
+            long num_outputs = num_filters_;
             // allocate params for the filters and also for the filter bias values.
-            params.set_size(num_inputs*_num_filters + _num_filters);
+            params.set_size(num_inputs*num_filters_ + num_filters_);

             dlib::rand rnd(std::rand());
             randomize_parameters(params, num_inputs+num_outputs, rnd);

-            filters = alias_tensor(sub.get_output().k(),_num_filters, _nr, _nc);
-            biases = alias_tensor(1,_num_filters);
+            filters = alias_tensor(sub.get_output().k(), num_filters_, _nr, _nc);
+            biases = alias_tensor(1,num_filters_);

             // set the initial bias values to zero
-            biases(params,_num_filters) = 0;
+            biases(params,filters.size()) = 0;
         }

         template <typename SUBNET>
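Note the last change in this hunk is more than the rename: the bias block starts where the filter weights end, at offset filters.size() == num_inputs*num_filters, so the old offset of _num_filters pointed into the middle of the filter weights. A standalone sketch of the aliasing layout (sizes are illustrative):

    #include <dlib/dnn.h>
    using namespace dlib;

    int main()
    {
        const long num_inputs  = 3*4*4;  // input k * filter nr * filter nc
        const long num_filters = 8;

        // One flat parameter tensor: all filter weights first, biases last.
        resizable_tensor params;
        params.set_size(num_inputs*num_filters + num_filters);

        alias_tensor filters(3, num_filters, 4, 4);  // size == num_inputs*num_filters
        alias_tensor biases(1, num_filters);

        // The bias alias begins right after the filters, hence filters.size().
        biases(params, filters.size()) = 0;

        return (filters.size() + biases.size() == params.size()) ? 0 : 1;
    }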
@@ -484,7 +498,7 @@ namespace dlib
         {
             serialize("cont_1", out);
             serialize(item.params, out);
-            serialize(_num_filters, out);
+            serialize(item.num_filters_, out);
             serialize(_nr, out);
             serialize(_nc, out);
             serialize(_stride_y, out);
@@ -503,7 +517,6 @@ namespace dlib
         {
             std::string version;
             deserialize(version, in);
-            long num_filters;
             long nr;
             long nc;
             int stride_y;
@@ -511,7 +524,7 @@ namespace dlib
             if (version == "cont_1")
             {
                 deserialize(item.params, in);
-                deserialize(num_filters, in);
+                deserialize(item.num_filters_, in);
                 deserialize(nr, in);
                 deserialize(nc, in);
                 deserialize(stride_y, in);
@@ -526,14 +539,6 @@ namespace dlib
                 deserialize(item.bias_weight_decay_multiplier, in);
                 if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
                 if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
-                if (num_filters != _num_filters)
-                {
-                    std::ostringstream sout;
-                    sout << "Wrong num_filters found while deserializing dlib::con_" << std::endl;
-                    sout << "expected " << _num_filters << " but found " << num_filters << std::endl;
-                    throw serialization_error(sout.str());
-                }
-
                 if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
                 if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
                 if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
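Since num_filters_ now travels with the serialized layer, the old "Wrong num_filters" check is gone: loading adopts the stored count even when it differs from the _num_filters template argument. A hedged round-trip sketch:

    #include <dlib/dnn.h>
    #include <sstream>
    using namespace dlib;

    int main()
    {
        using layer_t = cont_<16,4,4,2,2,1,1>;

        layer_t saved(num_con_outputs(32));  // runtime count != the template's 16
        std::ostringstream out;
        serialize(saved, out);

        layer_t loaded;                      // starts at the compile-time 16 ...
        std::istringstream in(out.str());
        deserialize(loaded, in);             // ... and adopts the stored 32
                                             // instead of throwing as before

        return loaded.num_filters() == 32 ? 0 : 1;
    }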
@@ -592,6 +597,7 @@ namespace dlib
         double weight_decay_multiplier;
         double bias_learning_rate_multiplier;
         double bias_weight_decay_multiplier;
+        long num_filters_;
         int padding_y_;
         int padding_x_;
@@ -941,6 +941,24 @@ namespace dlib
                 - #get_bias_weight_decay_multiplier() == 0
         !*/

+        cont_(
+            num_con_outputs o
+        );
+        /*!
+            ensures
+                - #num_filters() == o.num_outputs
+                - #nr() == _nr
+                - #nc() == _nc
+                - #stride_y() == _stride_y
+                - #stride_x() == _stride_x
+                - #padding_y() == _padding_y
+                - #padding_x() == _padding_x
+                - #get_learning_rate_multiplier() == 1
+                - #get_weight_decay_multiplier() == 1
+                - #get_bias_learning_rate_multiplier() == 1
+                - #get_bias_weight_decay_multiplier() == 0
+        !*/
+
         long num_filters(
         ) const;
         /*!
@@ -950,6 +968,19 @@ namespace dlib
                 of filters.
         !*/

+        void set_num_filters(
+            long num
+        );
+        /*!
+            requires
+                - num > 0
+                - get_layer_params().size() == 0
+                  (i.e. You can't change the number of filters in cont_ if the parameter
+                  tensor has already been allocated.)
+            ensures
+                - #num_filters() == num
+        !*/
+
         long nr(
         ) const;
         /*!
@@ -166,7 +166,7 @@ boost::python::list chinese_whispers_clustering(boost::python::list descriptors,
             edges.push_back(sample_pair(i,j));
         }
     }
-    const auto num_clusters = chinese_whispers(edges, labels);
+    chinese_whispers(edges, labels);
     for (size_t i = 0; i < labels.size(); ++i)
     {
         clusters.append(labels[i]);
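The removed variable was simply unused: chinese_whispers returns how many clusters it found, but the binding only needs the per-node labels. If the count is wanted, it can be taken from the return value or rederived from the labels:

    #include <dlib/clustering.h>
    #include <algorithm>
    #include <vector>
    using namespace dlib;

    int main()
    {
        // Toy graph: {0,1,2} are chained together, {3,4} form a second group.
        std::vector<sample_pair> edges = {
            sample_pair(0,1), sample_pair(1,2), sample_pair(3,4)
        };

        std::vector<unsigned long> labels;
        const unsigned long num_clusters = chinese_whispers(edges, labels);

        // Labels are contiguous ids in [0, num_clusters), so the count can
        // also be recovered after the fact:
        const unsigned long derived = labels.empty() ? 0 :
            *std::max_element(labels.begin(), labels.end()) + 1;

        return num_clusters == derived ? 0 : 1;
    }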
@@ -242,7 +242,6 @@ boost::python::list get_face_chips (
     boost::python::list chips_list;

-    int num_faces = faces.size();

     std::vector<chip_details> dets;
     for (auto& f : faces)
         dets.push_back(get_face_chip_details(f, size, padding));
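The deleted num_faces was likewise dead code; what matters is building one chip_details per face and extracting the aligned chips. Roughly the same flow in plain C++ (the function name and defaults here are made up for illustration):

    #include <dlib/image_processing.h>
    #include <dlib/image_transforms.h>
    #include <dlib/array.h>
    #include <vector>
    using namespace dlib;

    // Each full_object_detection holds face landmarks (5 or 68 points, e.g.
    // from a shape_predictor), as get_face_chip_details requires.
    std::vector<matrix<rgb_pixel>> extract_aligned_chips(
        const matrix<rgb_pixel>& img,
        const std::vector<full_object_detection>& faces,
        unsigned long size = 150,
        double padding = 0.25)
    {
        std::vector<chip_details> dets;
        for (const auto& f : faces)
            dets.push_back(get_face_chip_details(f, size, padding));

        // One warped, cropped chip per detail entry.
        dlib::array<matrix<rgb_pixel>> chips;
        extract_image_chips(img, dets, chips);

        std::vector<matrix<rgb_pixel>> out;
        for (unsigned long i = 0; i < chips.size(); ++i)
            out.push_back(std::move(chips[i]));
        return out;
    }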
@@ -253,9 +252,9 @@ boost::python::list get_face_chips (
     {
         boost::python::list img;

-        for(int row=0; row<size; row++) {
+        for(size_t row=0; row<size; row++) {
             boost::python::list row_list;
-            for(int col=0; col<size; col++) {
+            for(size_t col=0; col<size; col++) {
                 rgb_pixel pixel = chip(row, col);
                 boost::python::list item;
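Lastly, the int-to-size_t loop counters silence mixed signed/unsigned comparisons against the unsigned size bound. Illustrative only:

    #include <cstddef>

    void visit_rows(std::size_t size)
    {
        // for (int row = 0; row < size; row++) would draw -Wsign-compare and
        // misbehave for sizes above INT_MAX; an unsigned counter matches the bound.
        for (std::size_t row = 0; row < size; row++)
        {
            // ... per-row work ...
        }
    }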