Commit a9e1c9e4 authored by Davis King's avatar Davis King

Made add_loss_layer's batch operator() more general.

parent 846a5704
...@@ -1923,12 +1923,13 @@ namespace dlib ...@@ -1923,12 +1923,13 @@ namespace dlib
return temp_label; return temp_label;
} }
template <typename iterable_type>
std::vector<label_type> operator() ( std::vector<label_type> operator() (
const std::vector<input_type>& data, const iterable_type& data,
size_t batch_size = 128 size_t batch_size = 128
) )
{ {
std::vector<label_type> results(data.size()); std::vector<label_type> results(std::distance(data.begin(), data.end()));
auto o = results.begin(); auto o = results.begin();
for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size) for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size)
{ {
......
...@@ -688,13 +688,17 @@ namespace dlib ...@@ -688,13 +688,17 @@ namespace dlib
label_type. label_type.
!*/ !*/
template <typename iterable_type>
std::vector<label_type> operator() ( std::vector<label_type> operator() (
const std::vector<input_type>& data, const iterable_type& data,
size_t batch_size = 128 size_t batch_size = 128
); );
/*! /*!
requires requires
- batch_size > 0 - batch_size > 0
- data must have a .begin() and .end() that supply iterators over a
sequence of input_type elements. E.g. data could have a type of
std::vector<input_type>
ensures ensures
- runs all the objects in data through the network and returns their - runs all the objects in data through the network and returns their
predicted labels. This means this function returns a vector V such that: predicted labels. This means this function returns a vector V such that:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment