Commit 9691c194 authored by Morten Hustveit's avatar Morten Hustveit Committed by Davis E. King

Added support for variadic Python functions in find_max_global. (#1141)

* Added support for variadic Python functions in find_max_global.

* Add test for find_{min,max}_global on variadic functions.
parent 09a9ad6d
...@@ -45,17 +45,18 @@ py::list mat_to_list ( ...@@ -45,17 +45,18 @@ py::list mat_to_list (
return l; return l;
} }
size_t num_function_arguments(py::object f) size_t num_function_arguments(py::object f, size_t expected_num)
{ {
if (hasattr(f,"func_code")) const auto code_object = f.attr(hasattr(f,"func_code") ? "func_code" : "__code__");
return f.attr("func_code").attr("co_argcount").cast<std::size_t>(); const auto num = code_object.attr("co_argcount").cast<std::size_t>();
else if (num < expected_num && (code_object.attr("co_flags").cast<int>() & CO_VARARGS))
return f.attr("__code__").attr("co_argcount").cast<std::size_t>(); return expected_num;
return num;
} }
double call_func(py::object f, const matrix<double,0,1>& args) double call_func(py::object f, const matrix<double,0,1>& args)
{ {
const auto num = num_function_arguments(f); const auto num = num_function_arguments(f, args.size());
DLIB_CASSERT(num == args.size(), DLIB_CASSERT(num == args.size(),
"The function being optimized takes a number of arguments that doesn't agree with the size of the bounds lists you provided to find_max_global()"); "The function being optimized takes a number of arguments that doesn't agree with the size of the bounds lists you provided to find_max_global()");
DLIB_CASSERT(0 < num && num < 15, "Functions being optimized must take between 1 and 15 scalar arguments."); DLIB_CASSERT(0 < num && num < 15, "Functions being optimized must take between 1 and 15 scalar arguments.");
......
from dlib import find_max_global, find_min_global
from pytest import raises
def test_global_optimization_nargs():
w0 = find_max_global(lambda *args: sum(args), [0, 0, 0], [1, 1, 1], 10)
w1 = find_min_global(lambda *args: sum(args), [0, 0, 0], [1, 1, 1], 10)
assert w0 == ([1, 1, 1], 3)
assert w1 == ([0, 0, 0], 0)
w2 = find_max_global(lambda a, b, c, *args: a + b + c - sum(args), [0, 0, 0], [1, 1, 1], 10)
w3 = find_min_global(lambda a, b, c, *args: a + b + c - sum(args), [0, 0, 0], [1, 1, 1], 10)
assert w2 == ([1, 1, 1], 3)
assert w3 == ([0, 0, 0], 0)
with raises(Exception):
find_max_global(lambda a, b: 0, [0, 0, 0], [1, 1, 1], 10)
with raises(Exception):
find_min_global(lambda a, b: 0, [0, 0, 0], [1, 1, 1], 10)
with raises(Exception):
find_max_global(lambda a, b, c, d, *args: 0, [0, 0, 0], [1, 1, 1], 10)
with raises(Exception):
find_min_global(lambda a, b, c, d, *args: 0, [0, 0, 0], [1, 1, 1], 10)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment