Commit 06ce2078 authored by Davis King's avatar Davis King

Fixed numpy_image not working correctly for grayscale images.

parent 2b55b996
...@@ -33,11 +33,10 @@ namespace dlib ...@@ -33,11 +33,10 @@ namespace dlib
!*/ !*/
{ {
using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type; using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type;
constexpr size_t channels = pixel_traits<pixel_type>::num;
return obj.dtype().kind() == py::dtype::of<basic_pixel_type>().kind() && return obj.dtype().kind() == py::dtype::of<basic_pixel_type>().kind() &&
obj.itemsize() == sizeof(basic_pixel_type) && obj.itemsize() == sizeof(basic_pixel_type) &&
obj.ndim() == pixel_traits<pixel_type>::num; ((pixel_traits<pixel_type>::num==1) ? (obj.ndim()==2) : (obj.ndim()==3));
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -170,7 +169,10 @@ namespace dlib ...@@ -170,7 +169,10 @@ namespace dlib
{ {
using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type; using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type;
constexpr size_t channels = pixel_traits<pixel_type>::num; constexpr size_t channels = pixel_traits<pixel_type>::num;
*this = py::array_t<basic_pixel_type>({rows, cols, channels}); if (channels != 1)
*this = py::array_t<basic_pixel_type>({rows, cols, channels});
else
*this = py::array_t<basic_pixel_type>({rows, cols});
} }
private: private:
...@@ -256,7 +258,7 @@ namespace dlib ...@@ -256,7 +258,7 @@ namespace dlib
assert_correct_num_channels_in_image<pixel_type>(img); assert_correct_num_channels_in_image<pixel_type>(img);
using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type; using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type;
if (img.strides(2) != sizeof(basic_pixel_type)) if (img.ndim()==3 && img.strides(2) != sizeof(basic_pixel_type))
throw dlib::error("The stride of the 3rd dimension (the channel dimension) of the numpy array must be " + std::to_string(sizeof(basic_pixel_type))); throw dlib::error("The stride of the 3rd dimension (the channel dimension) of the numpy array must be " + std::to_string(sizeof(basic_pixel_type)));
if (img.strides(1) != sizeof(pixel_type)) if (img.strides(1) != sizeof(pixel_type))
throw dlib::error("The stride of the 2nd dimension (the columns dimension) of the numpy array must be " + std::to_string(sizeof(pixel_type))); throw dlib::error("The stride of the 2nd dimension (the columns dimension) of the numpy array must be " + std::to_string(sizeof(pixel_type)));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment