Commit 846a5704 authored by Davis King's avatar Davis King

Added an overload of operator() that lets you easily run a network on an

entire std::vector of objects.
parent 93ab80c7
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <utility> #include <utility>
#include <tuple> #include <tuple>
#include <cmath> #include <cmath>
#include <vector>
#include "tensor_tools.h" #include "tensor_tools.h"
...@@ -1922,6 +1923,21 @@ namespace dlib ...@@ -1922,6 +1923,21 @@ namespace dlib
return temp_label; return temp_label;
} }
std::vector<label_type> operator() (
const std::vector<input_type>& data,
size_t batch_size = 128
)
{
std::vector<label_type> results(data.size());
auto o = results.begin();
for (auto i = data.begin(); i < data.end(); i+=batch_size, o+=batch_size)
{
auto end = std::min(i+batch_size, data.end());
(*this)(i, end, o);
}
return results;
}
template <typename label_iterator> template <typename label_iterator>
double compute_loss ( double compute_loss (
const tensor& x, const tensor& x,
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <tuple> #include <tuple>
#include <vector>
#include "../rand.h" #include "../rand.h"
...@@ -687,6 +688,25 @@ namespace dlib ...@@ -687,6 +688,25 @@ namespace dlib
label_type. label_type.
!*/ !*/
std::vector<label_type> operator() (
const std::vector<input_type>& data,
size_t batch_size = 128
);
/*!
requires
- batch_size > 0
ensures
- runs all the objects in data through the network and returns their
predicted labels. This means this function returns a vector V such that:
- V.size() == data.size()
- for all valid i: V[i] == the predicted label of data[i].
- Elements of data are run through the network in batches of batch_size
items. Using a batch_size > 1 can be faster because it better exploits
the available hardware parallelism.
- loss_details().to_label() is used to convert the network output into a
label_type.
!*/
// ------------- // -------------
template <typename label_iterator> template <typename label_iterator>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment