Commit a9e1c9e4 authored by Davis King's avatar Davis King

Made add_loss_layer's batch operator() more general.

parent 846a5704
......@@ -1923,12 +1923,13 @@ namespace dlib
return temp_label;
}
template <typename iterable_type>
std::vector<label_type> operator() (
const std::vector<input_type>& data,
const iterable_type& data,
size_t batch_size = 128
)
{
std::vector<label_type> results(data.size());
std::vector<label_type> results(std::distance(data.begin(), data.end()));
auto o = results.begin();
for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size)
{
......
......@@ -688,13 +688,17 @@ namespace dlib
label_type.
!*/
template <typename iterable_type>
std::vector<label_type> operator() (
const std::vector<input_type>& data,
const iterable_type& data,
size_t batch_size = 128
);
/*!
requires
- batch_size > 0
- data must have a .begin() and .end() that supply iterators over a
sequence of input_type elements. E.g. data could have a type of
std::vector<input_type>
ensures
- runs all the objects in data through the network and returns their
predicted labels. This means this function returns a vector V such that:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment