Commit 68112e35 authored by Davis King's avatar Davis King

Improved the error messages related to numpy_image usage and also

improved the way overload resolution works in pybind11 for these objects.
parent 06ce2078
...@@ -24,7 +24,7 @@ namespace dlib ...@@ -24,7 +24,7 @@ namespace dlib
typename pixel_type typename pixel_type
> >
bool is_image ( bool is_image (
const py::array& obj const py::array& img
) )
/*! /*!
ensures ensures
...@@ -33,10 +33,14 @@ namespace dlib ...@@ -33,10 +33,14 @@ namespace dlib
!*/ !*/
{ {
using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type; using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type;
const size_t expected_channels = pixel_traits<pixel_type>::num;
return obj.dtype().kind() == py::dtype::of<basic_pixel_type>().kind() && const bool has_correct_number_of_dims = (img.ndim()==2 && expected_channels==1) ||
obj.itemsize() == sizeof(basic_pixel_type) && (img.ndim()==3 && img.shape(2)==expected_channels);
((pixel_traits<pixel_type>::num==1) ? (obj.ndim()==2) : (obj.ndim()==3));
return img.dtype().kind() == py::dtype::of<basic_pixel_type>().kind() &&
img.itemsize() == sizeof(basic_pixel_type) &&
has_correct_number_of_dims;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -51,7 +55,7 @@ namespace dlib ...@@ -51,7 +55,7 @@ namespace dlib
const size_t expected_channels = pixel_traits<pixel_type>::num; const size_t expected_channels = pixel_traits<pixel_type>::num;
if (expected_channels == 1) if (expected_channels == 1)
{ {
if (img.ndim() != 2) if (!(img.ndim() == 2 || (img.ndim()==3&&img.shape(2)==1)))
throw dlib::error("Expected a 2D numpy array, but instead got one with " + std::to_string(img.ndim()) + " dimensions."); throw dlib::error("Expected a 2D numpy array, but instead got one with " + std::to_string(img.ndim()) + " dimensions.");
} }
else else
...@@ -113,7 +117,7 @@ namespace dlib ...@@ -113,7 +117,7 @@ namespace dlib
template < template <
typename pixel_type typename pixel_type
> >
class numpy_image : public py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type> class numpy_image : public py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>
{ {
/*! /*!
REQUIREMENTS ON pixel_type REQUIREMENTS ON pixel_type
...@@ -135,28 +139,50 @@ namespace dlib ...@@ -135,28 +139,50 @@ namespace dlib
numpy_image( numpy_image(
py::array& img py::array& img
) : py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type>(img) ) : py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>(img)
{ {
assert_is_image<pixel_type>(img); assert_is_image<pixel_type>(img);
} }
numpy_image(
const numpy_image& img
) : py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>(img)
{
}
numpy_image& operator= ( numpy_image& operator= (
const py::object& rhs const py::object& rhs
) )
{ {
assert_is_image<pixel_type>(rhs);
*this = rhs.cast<py::array>(); *this = rhs.cast<py::array>();
return *this; return *this;
} }
numpy_image& operator= ( numpy_image& operator= (
const py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type>& rhs const py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>& rhs
) )
{ {
assert_is_image<pixel_type>(rhs); assert_is_image<pixel_type>(rhs);
py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type>::operator=(rhs); py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>::operator=(rhs);
return *this;
}
numpy_image& operator= (
const numpy_image& rhs
)
{
py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>::operator=(rhs);
return *this; return *this;
} }
numpy_image (
matrix<pixel_type>&& rhs
)
{
*this = convert_to_numpy(std::move(rhs));
}
numpy_image& operator= ( numpy_image& operator= (
matrix<pixel_type>&& rhs matrix<pixel_type>&& rhs
) )
...@@ -170,13 +196,13 @@ namespace dlib ...@@ -170,13 +196,13 @@ namespace dlib
using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type; using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type;
constexpr size_t channels = pixel_traits<pixel_type>::num; constexpr size_t channels = pixel_traits<pixel_type>::num;
if (channels != 1) if (channels != 1)
*this = py::array_t<basic_pixel_type>({rows, cols, channels}); *this = py::array_t<basic_pixel_type, py::array::c_style>({rows, cols, channels});
else else
*this = py::array_t<basic_pixel_type>({rows, cols}); *this = py::array_t<basic_pixel_type, py::array::c_style>({rows, cols});
} }
private: private:
static py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type> convert_to_numpy(matrix<pixel_type>&& img) static py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style> convert_to_numpy(matrix<pixel_type>&& img)
{ {
using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type; using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type;
const size_t dtype_size = sizeof(basic_pixel_type); const size_t dtype_size = sizeof(basic_pixel_type);
...@@ -188,12 +214,24 @@ namespace dlib ...@@ -188,12 +214,24 @@ namespace dlib
std::unique_ptr<pixel_type[]> arr_ptr = img.steal_memory(); std::unique_ptr<pixel_type[]> arr_ptr = img.steal_memory();
basic_pixel_type* arr = (basic_pixel_type *) arr_ptr.release(); basic_pixel_type* arr = (basic_pixel_type *) arr_ptr.release();
return pybind11::template array_t<basic_pixel_type>( if (channels == 1)
{rows, cols, channels}, // shape {
{dtype_size * cols * channels, dtype_size * channels, dtype_size}, // strides return pybind11::template array_t<basic_pixel_type, py::array::c_style>(
arr, // pointer {rows, cols}, // shape
pybind11::capsule{ arr, [](void *arr_p) { delete[] reinterpret_cast<basic_pixel_type*>(arr_p); } } {dtype_size*cols, dtype_size}, // strides
); arr, // pointer
pybind11::capsule{ arr, [](void *arr_p) { delete[] reinterpret_cast<basic_pixel_type*>(arr_p); } }
);
}
else
{
return pybind11::template array_t<basic_pixel_type, py::array::c_style>(
{rows, cols, channels}, // shape
{dtype_size * cols * channels, dtype_size * channels, dtype_size}, // strides
arr, // pointer
pybind11::capsule{ arr, [](void *arr_p) { delete[] reinterpret_cast<basic_pixel_type*>(arr_p); } }
);
}
} }
}; };
...@@ -310,9 +348,15 @@ namespace pybind11 ...@@ -310,9 +348,15 @@ namespace pybind11
using type = dlib::numpy_image<pixel_type>; using type = dlib::numpy_image<pixel_type>;
bool load(handle src, bool convert) { bool load(handle src, bool convert) {
if (!convert && !type::check_(src)) if (!type::check_(src))
return false;
// stash the output of ensure into a temp variable since assigning it to
// value (the member variable created by the PYBIND11_TYPE_CASTER)
// apparently causes the return bool value to be ignored?
auto temp = type::ensure(src);
if (!dlib::is_image<pixel_type>(temp))
return false; return false;
value = type::ensure(src); value = temp;
return static_cast<bool>(value); return static_cast<bool>(value);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment