Commit 68112e35 authored by Davis King's avatar Davis King

Improved the error messages related to numpy_image usage and also

improved the way overload resolution works in pybind11 for these objects.
parent 06ce2078
......@@ -24,7 +24,7 @@ namespace dlib
typename pixel_type
>
bool is_image (
const py::array& obj
const py::array& img
)
/*!
ensures
......@@ -33,10 +33,14 @@ namespace dlib
!*/
{
using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type;
const size_t expected_channels = pixel_traits<pixel_type>::num;
return obj.dtype().kind() == py::dtype::of<basic_pixel_type>().kind() &&
obj.itemsize() == sizeof(basic_pixel_type) &&
((pixel_traits<pixel_type>::num==1) ? (obj.ndim()==2) : (obj.ndim()==3));
const bool has_correct_number_of_dims = (img.ndim()==2 && expected_channels==1) ||
(img.ndim()==3 && img.shape(2)==expected_channels);
return img.dtype().kind() == py::dtype::of<basic_pixel_type>().kind() &&
img.itemsize() == sizeof(basic_pixel_type) &&
has_correct_number_of_dims;
}
// ----------------------------------------------------------------------------------------
......@@ -51,7 +55,7 @@ namespace dlib
const size_t expected_channels = pixel_traits<pixel_type>::num;
if (expected_channels == 1)
{
if (img.ndim() != 2)
if (!(img.ndim() == 2 || (img.ndim()==3&&img.shape(2)==1)))
throw dlib::error("Expected a 2D numpy array, but instead got one with " + std::to_string(img.ndim()) + " dimensions.");
}
else
......@@ -113,7 +117,7 @@ namespace dlib
template <
typename pixel_type
>
class numpy_image : public py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type>
class numpy_image : public py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>
{
/*!
REQUIREMENTS ON pixel_type
......@@ -135,28 +139,50 @@ namespace dlib
numpy_image(
py::array& img
) : py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type>(img)
) : py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>(img)
{
assert_is_image<pixel_type>(img);
}
numpy_image(
const numpy_image& img
) : py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>(img)
{
}
numpy_image& operator= (
const py::object& rhs
)
{
assert_is_image<pixel_type>(rhs);
*this = rhs.cast<py::array>();
return *this;
}
numpy_image& operator= (
const py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type>& rhs
const py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>& rhs
)
{
assert_is_image<pixel_type>(rhs);
py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type>::operator=(rhs);
py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>::operator=(rhs);
return *this;
}
numpy_image& operator= (
const numpy_image& rhs
)
{
py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style>::operator=(rhs);
return *this;
}
numpy_image (
matrix<pixel_type>&& rhs
)
{
*this = convert_to_numpy(std::move(rhs));
}
numpy_image& operator= (
matrix<pixel_type>&& rhs
)
......@@ -170,13 +196,13 @@ namespace dlib
using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type;
constexpr size_t channels = pixel_traits<pixel_type>::num;
if (channels != 1)
*this = py::array_t<basic_pixel_type>({rows, cols, channels});
*this = py::array_t<basic_pixel_type, py::array::c_style>({rows, cols, channels});
else
*this = py::array_t<basic_pixel_type>({rows, cols});
*this = py::array_t<basic_pixel_type, py::array::c_style>({rows, cols});
}
private:
static py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type> convert_to_numpy(matrix<pixel_type>&& img)
static py::array_t<typename pixel_traits<pixel_type>::basic_pixel_type, py::array::c_style> convert_to_numpy(matrix<pixel_type>&& img)
{
using basic_pixel_type = typename pixel_traits<pixel_type>::basic_pixel_type;
const size_t dtype_size = sizeof(basic_pixel_type);
......@@ -188,12 +214,24 @@ namespace dlib
std::unique_ptr<pixel_type[]> arr_ptr = img.steal_memory();
basic_pixel_type* arr = (basic_pixel_type *) arr_ptr.release();
return pybind11::template array_t<basic_pixel_type>(
{rows, cols, channels}, // shape
{dtype_size * cols * channels, dtype_size * channels, dtype_size}, // strides
arr, // pointer
pybind11::capsule{ arr, [](void *arr_p) { delete[] reinterpret_cast<basic_pixel_type*>(arr_p); } }
);
if (channels == 1)
{
return pybind11::template array_t<basic_pixel_type, py::array::c_style>(
{rows, cols}, // shape
{dtype_size*cols, dtype_size}, // strides
arr, // pointer
pybind11::capsule{ arr, [](void *arr_p) { delete[] reinterpret_cast<basic_pixel_type*>(arr_p); } }
);
}
else
{
return pybind11::template array_t<basic_pixel_type, py::array::c_style>(
{rows, cols, channels}, // shape
{dtype_size * cols * channels, dtype_size * channels, dtype_size}, // strides
arr, // pointer
pybind11::capsule{ arr, [](void *arr_p) { delete[] reinterpret_cast<basic_pixel_type*>(arr_p); } }
);
}
}
};
......@@ -310,9 +348,15 @@ namespace pybind11
using type = dlib::numpy_image<pixel_type>;
bool load(handle src, bool convert) {
if (!convert && !type::check_(src))
if (!type::check_(src))
return false;
// stash the output of ensure into a temp variable since assigning it to
// value (the member variable created by the PYBIND11_TYPE_CASTER)
// apparently causes the return bool value to be ignored?
auto temp = type::ensure(src);
if (!dlib::is_image<pixel_type>(temp))
return false;
value = type::ensure(src);
value = temp;
return static_cast<bool>(value);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment