Commit 939e9670 authored by Davis King's avatar Davis King

Improved the type validation on numpy arrays passed to dlib. The old code

didn't have complete coverage and would let incorrectly typed numpy arrays
through.
parent 176e329d
......@@ -14,21 +14,59 @@ namespace py = pybind11;
// ----------------------------------------------------------------------------------------
template <typename T>
template <typename TT>
void validate_numpy_array_type (
const py::object& obj
)
{
const char ch = obj.attr("dtype").attr("char").cast<char>();
if (dlib::is_same_type<T,double>::value && ch != 'd')
throw dlib::error("Expected numpy.ndarray of float64");
if (dlib::is_same_type<T,float>::value && ch != 'f')
throw dlib::error("Expected numpy.ndarray of float32");
if (dlib::is_same_type<T,dlib::int32>::value && ch != 'i')
throw dlib::error("Expected numpy.ndarray of int32");
if (dlib::is_same_type<T,unsigned char>::value && ch != 'B')
throw dlib::error("Expected numpy.ndarray of uint8");
using T = typename dlib::pixel_traits<TT>::basic_pixel_type;
if (dlib::is_same_type<T,double>::value)
{
if (ch != 'd')
throw dlib::error("Expected numpy.ndarray of float64");
}
else if (dlib::is_same_type<T,float>::value)
{
if (ch != 'f')
throw dlib::error("Expected numpy.ndarray of float32");
}
else if (dlib::is_same_type<T,dlib::int16>::value)
{
if (ch != 'h')
throw dlib::error("Expected numpy.ndarray of int16");
}
else if (dlib::is_same_type<T,dlib::uint16>::value)
{
if (ch != 'H')
throw dlib::error("Expected numpy.ndarray of uint16");
}
else if (dlib::is_same_type<T,dlib::int32>::value)
{
if (ch != 'i')
throw dlib::error("Expected numpy.ndarray of int32");
}
else if (dlib::is_same_type<T,dlib::uint32>::value)
{
if (ch != 'I')
throw dlib::error("Expected numpy.ndarray of uint32");
}
else if (dlib::is_same_type<T,unsigned char>::value)
{
if (ch != 'B')
throw dlib::error("Expected numpy.ndarray of uint8");
}
else if (dlib::is_same_type<T,signed char>::value)
{
if (ch != 'b')
throw dlib::error("Expected numpy.ndarray of int8");
}
else
{
throw dlib::error("validate_numpy_array_type() called with unsupported type.");
}
}
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment