Commit 70619d2f authored by Davis King's avatar Davis King

Made input_layer() work in a more reasonable and general way.

parent a105c616
...@@ -2617,12 +2617,43 @@ namespace dlib ...@@ -2617,12 +2617,43 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace dimpl
{
template <typename T>
T& get_input_details (
T& net
)
{
return net;
}
template <typename T, bool is_first, typename enabled>
auto get_input_details (
dimpl::subnet_wrapper<T,is_first,enabled>& net
) -> decltype(net.layer_details())&
{
return net.layer_details();
}
template <typename T, bool is_first, typename enabled>
auto get_input_details (
const dimpl::subnet_wrapper<T,is_first,enabled>& net
) -> decltype(net.layer_details())&
{
return net.layer_details();
}
}
template <typename net_type> template <typename net_type>
auto input_layer ( auto input_layer (
net_type& net net_type& net
) -> decltype(layer<net_type::num_layers-1>(net))& ) -> decltype(dimpl::get_input_details(layer<net_type::num_layers-1>(net)))&
{ {
return layer<net_type::num_layers-1>(net); // Calling input_layer() on a subnet_wrapper is a little funny since the behavior of
// .subnet() returns another subnet_wrapper rather than an input details object as it
// does in add_layer.
return dimpl::get_input_details(layer<net_type::num_layers-1>(net));
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -1396,6 +1396,7 @@ namespace dlib ...@@ -1396,6 +1396,7 @@ namespace dlib
- returns the input later of the given network object. Specifically, this - returns the input later of the given network object. Specifically, this
function is equivalent to calling: function is equivalent to calling:
layer<net_type::num_layers-1>(net); layer<net_type::num_layers-1>(net);
That is, you get the input layer details object for the network.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -725,7 +725,7 @@ namespace dlib ...@@ -725,7 +725,7 @@ namespace dlib
{ {
dpoint p = output_tensor_to_input_tensor(net, point(c,r)); dpoint p = output_tensor_to_input_tensor(net, point(c,r));
drectangle rect = centered_drect(p, options.detector_width, options.detector_height); drectangle rect = centered_drect(p, options.detector_width, options.detector_height);
rect = input_layer(net).layer_details().tensor_space_to_image_space(input_tensor,rect); rect = input_layer(net).tensor_space_to_image_space(input_tensor,rect);
dets_accum.push_back(intermediate_detection(rect, score, r*output_tensor.nc() + c)); dets_accum.push_back(intermediate_detection(rect, score, r*output_tensor.nc() + c));
} }
...@@ -743,7 +743,7 @@ namespace dlib ...@@ -743,7 +743,7 @@ namespace dlib
) const ) const
{ {
using namespace std; using namespace std;
if (!input_layer(net).layer_details().image_contained_point(input_tensor,center(rect))) if (!input_layer(net).image_contained_point(input_tensor,center(rect)))
{ {
std::ostringstream sout; std::ostringstream sout;
sout << "Encountered a truth rectangle located at " << rect << " that is outside the image." << endl; sout << "Encountered a truth rectangle located at " << rect << " that is outside the image." << endl;
...@@ -757,12 +757,12 @@ namespace dlib ...@@ -757,12 +757,12 @@ namespace dlib
// it means the box can't be matched by the sliding window. But picking the // it means the box can't be matched by the sliding window. But picking the
// max causes the right error message to be selected in the logic below. // max causes the right error message to be selected in the logic below.
const double scale = std::max(options.detector_width/(double)rect.width(), options.detector_height/(double)rect.height()); const double scale = std::max(options.detector_width/(double)rect.width(), options.detector_height/(double)rect.height());
const rectangle mapped_rect = input_layer(net).layer_details().image_space_to_tensor_space(input_tensor, std::min(1.0,scale), rect); const rectangle mapped_rect = input_layer(net).image_space_to_tensor_space(input_tensor, std::min(1.0,scale), rect);
// compute the detection window that we would use at this position. // compute the detection window that we would use at this position.
point tensor_p = center(mapped_rect); point tensor_p = center(mapped_rect);
rectangle det_window = centered_rect(tensor_p, options.detector_width,options.detector_height); rectangle det_window = centered_rect(tensor_p, options.detector_width,options.detector_height);
det_window = input_layer(net).layer_details().tensor_space_to_image_space(input_tensor, det_window); det_window = input_layer(net).tensor_space_to_image_space(input_tensor, det_window);
// make sure the rect can actually be represented by the image pyramid we are // make sure the rect can actually be represented by the image pyramid we are
// using. // using.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment