diff --git a/dlib/python/numpy_image.h b/dlib/python/numpy_image.h index 9d73bdf7e4f27bc69af561d2f1e5981fbf5a758f..af2808ce1e6160146a9d1674c430207a5d54e4f6 100644 --- a/dlib/python/numpy_image.h +++ b/dlib/python/numpy_image.h @@ -13,6 +13,7 @@ #include <pybind11/pybind11.h> #include <dlib/image_transforms/assign_image.h> #include <stdint.h> +#include <type_traits> namespace py = pybind11; @@ -356,18 +357,28 @@ namespace pybind11 { using basic_pixel_type = typename dlib::pixel_traits<pixel_type>::basic_pixel_type; - static PYBIND11_DESCR name() { - constexpr size_t channels = dlib::pixel_traits<pixel_type>::num; - if (channels == 1) - return _("numpy.ndarray[(rows,cols),") + npy_format_descriptor<basic_pixel_type>::name() + _("]"); - else if (channels == 2) + template <size_t channels> + static PYBIND11_DESCR getname(typename std::enable_if<channels==1,int>::type) { + return _("numpy.ndarray[(rows,cols),") + npy_format_descriptor<basic_pixel_type>::name() + _("]"); + }; + template <size_t channels> + static PYBIND11_DESCR getname(typename std::enable_if<channels!=1,int>::type) { + if (channels == 2) return _("numpy.ndarray[(rows,cols,2),") + npy_format_descriptor<basic_pixel_type>::name() + _("]"); else if (channels == 3) return _("numpy.ndarray[(rows,cols,3),") + npy_format_descriptor<basic_pixel_type>::name() + _("]"); else if (channels == 4) return _("numpy.ndarray[(rows,cols,4),") + npy_format_descriptor<basic_pixel_type>::name() + _("]"); - else - DLIB_CASSERT(false,"unsupported pixel type"); + }; + + static PYBIND11_DESCR name() { + constexpr size_t channels = dlib::pixel_traits<pixel_type>::num; + // The reason we have to call getname() in this wonky way is because + // pybind11 uses a type that records the length of the returned string in + // the type. So we have to do this overloading to make the return type + // from name() consistent. In C++17 this would be a lot cleaner with + // constexpr if, but can't use C++17 yet because of lack of wide support :( + return getname<channels>(0); } };