Commit e2e26f3b authored by Davis King

Relaxed the preconditions so that layers are allowed to output tensors
that contain a different number of samples than their input tensors.
parent fd867dd8
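In practical terms, the hunks below change the loss-layer interface: to_label() now receives the tensor originally given to the network, so the label range is sized from input_tensor.num_samples() rather than from the sub-network's output, and intermediate layers no longer have to preserve the number of samples. A minimal before/after sketch of the declaration (this only restates the signature change visible in the diff; the surrounding loss-layer class is omitted):

```cpp
// Before this commit: labels were derived from the sub-network alone, and
// every layer had to preserve num_samples().
template <typename SUB_TYPE, typename label_iterator>
void to_label (
    const SUB_TYPE& sub,
    label_iterator iter
) const;

// After this commit: the original input tensor is passed through as well,
// so intermediate layers are free to change the number of samples.
template <typename SUB_TYPE, typename label_iterator>
void to_label (
    const tensor& input_tensor,
    const SUB_TYPE& sub,
    label_iterator iter
) const;
```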
@@ -1059,7 +1059,7 @@ namespace dlib
         {
             subnetwork.forward(x);
             const dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
-            loss.to_label(wsub, obegin);
+            loss.to_label(x, wsub, obegin);
         }
         template <typename input_iterator, typename output_iterator>
...
@@ -309,7 +309,6 @@ namespace dlib
               return forward(temp_tensor);
             - The return value from this function is also available in #get_output().
               i.e. this function returns #get_output().
-            - #get_output().num_samples() == std::distance(ibegin,iend)*sample_expansion_factor.
             - have_same_dimensions(#get_gradient_input(), #get_output()) == true.
             - All elements of #get_gradient_input() are set to 0.
               i.e. calling this function clears out #get_gradient_input() and ensures
@@ -341,7 +340,6 @@ namespace dlib
               layer_details().forward(subnet(), get_output());
             - The return value from this function is also available in #get_output().
               i.e. this function returns #get_output().
-            - #get_output().num_samples() == x.num_samples().
             - have_same_dimensions(#get_gradient_input(), #get_output()) == true
             - All elements of #get_gradient_input() are set to 0.
               i.e. calling this function clears out #get_gradient_input() and ensures
@@ -382,7 +380,6 @@ namespace dlib
     /*!
         requires
             - forward(x) was called to forward propagate x though the network.
-            - x.num_samples() == get_output().num_samples()
             - get_gradient_input() has been set equal to the gradient of this network's
               output with respect to some loss function.
             - This instance of solvers has only ever been used with this network.  That
...
@@ -163,7 +163,6 @@ namespace dlib
               output into #output.  In particular, forward() can use any of the outputs
               in sub (e.g. sub.get_output(), sub.subnet().get_output(), etc.) to
               compute whatever it wants.
-             - #output.num_samples() == sub.get_output().num_samples()
     !*/

     template <typename SUBNET>
...
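To make the removed requirement concrete, here is a hedged sketch, not part of this commit, of a computational layer whose forward() emits half as many samples as it receives by averaging adjacent pairs of samples. The class name avg_pairs_ is hypothetical, and setup(), backward(), parameter access, and serialization are omitted; the point is only that #output.num_samples() no longer has to equal sub.get_output().num_samples().

```cpp
#include <dlib/dnn.h>

// Hypothetical layer: averages adjacent pairs of input samples, so its output
// tensor has half as many samples as its input.  Under the old contract this
// was disallowed because forward() had to preserve num_samples().
class avg_pairs_
{
public:
    template <typename SUBNET>
    void forward (const SUBNET& sub, dlib::resizable_tensor& output)
    {
        const dlib::tensor& in = sub.get_output();
        output.set_size(in.num_samples()/2, in.k(), in.nr(), in.nc());

        const float* src = in.host();
        float* dst = output.host();
        const long stride = in.k()*in.nr()*in.nc();
        for (long s = 0; s < output.num_samples(); ++s)
            for (long j = 0; j < stride; ++j)
                dst[s*stride + j] = 0.5f*(src[(2*s)*stride + j] + src[(2*s+1)*stride + j]);
    }

    // setup(), backward(), get_layer_params(), and serialize()/deserialize()
    // are omitted for brevity; a usable layer still has to provide them.
};
```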
@@ -24,6 +24,7 @@ namespace dlib
             typename label_iterator
             >
         void to_label (
+            const tensor& input_tensor,
             const SUB_TYPE& sub,
             label_iterator iter
         ) const
@@ -32,7 +33,7 @@ namespace dlib
             DLIB_CASSERT(output_tensor.nr() == 1 &&
                          output_tensor.nc() == 1 &&
                          output_tensor.k() == 1,"");
-            DLIB_CASSERT(output_tensor.num_samples()%sample_expansion_factor == 0,"");
+            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples(),"");
             const float* out_data = output_tensor.host();
             for (long i = 0; i < output_tensor.num_samples(); ++i)
...
@@ -56,6 +56,7 @@ namespace dlib
             typename label_iterator
             >
         void to_label (
+            const tensor& input_tensor,
             const SUB_TYPE& sub,
             label_iterator iter
         ) const;
@@ -63,19 +64,19 @@ namespace dlib
         requires
             - SUBNET implements the SUBNET interface defined at the top of
               layers_abstract.h.
-            - sub.get_output().num_samples()%sample_expansion_factor == 0.
-            - All outputs in each layer of sub have the same number of samples.  That
-              is, for all valid i:
-                - sub.get_output().num_samples() == layer<i>(sub).get_output().num_samples()
+            - input_tensor was given as input to the network sub and the outputs are
+              now visible in layer<i>(sub).get_output(), for all valid i.
+            - input_tensor.num_samples() > 0
+            - input_tensor.num_samples()%sample_expansion_factor == 0.
             - iter == an iterator pointing to the beginning of a range of
-              sub.get_output().num_samples()/sample_expansion_factor elements.
-              Moreover, they must be label_type elements.
+              input_tensor.num_samples()/sample_expansion_factor elements.  Moreover,
+              they must be label_type elements.
         ensures
             - Converts the output of the provided network to label_type objects and
               stores the results into the range indicated by iter.  In particular, for
-              all valid i and j, it will be the case that:
-                *(iter+i/sample_expansion_factor) is the output corresponding to the
-                ith sample in layer<j>(sub).get_output().
+              all valid i, it will be the case that:
+                *(iter+i/sample_expansion_factor) is the element corresponding to the
+                output of sub for the ith sample in input_tensor.
     !*/

     template <
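As a concrete reading of the contract above, the following is a hedged sketch of a to_label() that satisfies the new requires/ensures clauses for a loss layer producing one scalar output per sample. It mirrors the pattern from the implementation hunk earlier and assumes, for illustration only, that label_type is float and sample_expansion_factor is 1.

```cpp
template <typename SUB_TYPE, typename label_iterator>
void to_label (
    const dlib::tensor& input_tensor,
    const SUB_TYPE& sub,
    label_iterator iter
) const
{
    const dlib::tensor& output_tensor = sub.get_output();
    // One scalar network output per input sample.
    DLIB_CASSERT(output_tensor.nr() == 1 &&
                 output_tensor.nc() == 1 &&
                 output_tensor.k()  == 1, "");
    DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples(), "");

    // *(iter+i) becomes the output of sub for the ith sample of input_tensor.
    const float* out_data = output_tensor.host();
    for (long i = 0; i < output_tensor.num_samples(); ++i)
        *iter++ = out_data[i];
}
```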
@@ -96,15 +97,14 @@ namespace dlib
             - input_tensor.num_samples() > 0
             - input_tensor.num_samples()%sample_expansion_factor == 0.
             - for all valid i:
-                - layer<i>(sub).get_output().num_samples() == input_tensor.num_samples().
                 - layer<i>(sub).get_gradient_input() has the same dimensions as
                   layer<i>(sub).get_output().
             - truth == an iterator pointing to the beginning of a range of
-              input_tensor.num_samples()/sample_expansion_factor elements.  In
-              particular, they must be label_type elements.
-            - for all valid i and j:
+              input_tensor.num_samples()/sample_expansion_factor elements.  Moreover,
+              they must be label_type elements.
+            - for all valid i:
                 - *(truth+i/sample_expansion_factor) is the label of the ith sample in
-                  layer<j>(sub).get_output().
+                  input_tensor.
         ensures
             - This function computes a loss function that describes how well the output
               of sub matches the expected labels given by truth.  Let's write the loss
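For the loss-and-gradient side of the interface, here is a hedged sketch under the same assumptions as the to_label sketch above (one scalar output per sample, float label_type, sample_expansion_factor of 1); the member name compute_loss is itself an assumption for illustration, not taken from this commit. It implements a simple mean squared error against the labels in truth and writes the corresponding gradient into sub.get_gradient_input():

```cpp
template <typename const_label_iterator, typename SUBNET>
double compute_loss (
    const dlib::tensor& input_tensor,
    const_label_iterator truth,
    SUBNET& sub
) const
{
    const dlib::tensor& output_tensor = sub.get_output();
    dlib::tensor& grad = sub.get_gradient_input();
    DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples(), "");

    const double scale = 1.0/output_tensor.num_samples();
    double loss = 0;
    const float* out_data = output_tensor.host();
    float* g = grad.host();
    for (long i = 0; i < output_tensor.num_samples(); ++i)
    {
        const float y = *truth++;                // label of the ith sample in input_tensor
        const float err = out_data[i] - y;
        loss += scale*err*err;                   // accumulate mean squared error
        g[i] = static_cast<float>(2*scale*err);  // d(loss)/d(output_i)
    }
    return loss;
}
```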
@@ -154,6 +154,7 @@ namespace dlib
             typename label_iterator
             >
         void to_label (
+            const tensor& input_tensor,
             const SUB_TYPE& sub,
             label_iterator iter
         ) const;
...