Commit 4ee1f664 authored by Davis King's avatar Davis King

Made thread_pool and parallel_for propagate exceptions from task threads to

calling code.
parent 5b361945
...@@ -369,6 +369,39 @@ namespace ...@@ -369,6 +369,39 @@ namespace
DLIB_TEST(d == 4); DLIB_TEST(d == 4);
} }
tp.wait_for_all_tasks();
// make sure exception propagation from tasks works correctly.
auto f_throws = []() { throw dlib::error("test exception");};
bool got_exception = false;
try
{
tp.add_task_by_value(f_throws);
tp.wait_for_all_tasks();
}
catch(dlib::error& e)
{
DLIB_TEST(e.info == "test exception");
got_exception = true;
}
DLIB_TEST(got_exception);
dlib::future<int> aa;
auto f_throws2 = [](int& a) { a = 1; throw dlib::error("test exception");};
got_exception = false;
try
{
tp.add_task(f_throws2, aa);
aa.get();
}
catch(dlib::error& e)
{
DLIB_TEST(e.info == "test exception");
got_exception = true;
}
DLIB_TEST(got_exception);
} }
} }
......
...@@ -24,7 +24,6 @@ namespace dlib ...@@ -24,7 +24,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This is a convenience function for submitting a block of jobs to a thread_pool. - This is a convenience function for submitting a block of jobs to a thread_pool.
In particular, given the half open range [begin, end), this function will In particular, given the half open range [begin, end), this function will
...@@ -61,7 +60,6 @@ namespace dlib ...@@ -61,7 +60,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is equivalent to the following block of code: - This function is equivalent to the following block of code:
thread_pool tp(num_threads); thread_pool tp(num_threads);
...@@ -82,7 +80,6 @@ namespace dlib ...@@ -82,7 +80,6 @@ namespace dlib
requires requires
- chunks_per_thread > 0 - chunks_per_thread > 0
- begin <= end - begin <= end
- funct does not throw any exceptions
ensures ensures
- This is a convenience function for submitting a block of jobs to a - This is a convenience function for submitting a block of jobs to a
thread_pool. In particular, given the range [begin, end), this function will thread_pool. In particular, given the range [begin, end), this function will
...@@ -117,7 +114,6 @@ namespace dlib ...@@ -117,7 +114,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is equivalent to the following block of code: - This function is equivalent to the following block of code:
thread_pool tp(num_threads); thread_pool tp(num_threads);
...@@ -137,7 +133,6 @@ namespace dlib ...@@ -137,7 +133,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is equivalent to the following block of code: - This function is equivalent to the following block of code:
parallel_for_blocked(default_thread_pool(), begin, end, funct, chunks_per_thread); parallel_for_blocked(default_thread_pool(), begin, end, funct, chunks_per_thread);
...@@ -159,7 +154,6 @@ namespace dlib ...@@ -159,7 +154,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is equivalent to the following function call: - This function is equivalent to the following function call:
parallel_for_blocked(tp, begin, end, [&](long begin_sub, long end_sub) parallel_for_blocked(tp, begin, end, [&](long begin_sub, long end_sub)
...@@ -189,7 +183,6 @@ namespace dlib ...@@ -189,7 +183,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is equivalent to the following block of code: - This function is equivalent to the following block of code:
thread_pool tp(num_threads); thread_pool tp(num_threads);
...@@ -210,7 +203,6 @@ namespace dlib ...@@ -210,7 +203,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is equivalent to the following function call: - This function is equivalent to the following function call:
parallel_for_blocked(tp, begin, end, [&](long begin_sub, long end_sub) parallel_for_blocked(tp, begin, end, [&](long begin_sub, long end_sub)
...@@ -238,7 +230,6 @@ namespace dlib ...@@ -238,7 +230,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is equivalent to the following block of code: - This function is equivalent to the following block of code:
thread_pool tp(num_threads); thread_pool tp(num_threads);
...@@ -258,7 +249,6 @@ namespace dlib ...@@ -258,7 +249,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is equivalent to the following block of code: - This function is equivalent to the following block of code:
parallel_for(default_thread_pool(), begin, end, funct, chunks_per_thread); parallel_for(default_thread_pool(), begin, end, funct, chunks_per_thread);
...@@ -280,7 +270,6 @@ namespace dlib ...@@ -280,7 +270,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is identical to the parallel_for() routine defined above except - This function is identical to the parallel_for() routine defined above except
that it will print messages to cout showing the progress in executing the that it will print messages to cout showing the progress in executing the
...@@ -302,7 +291,6 @@ namespace dlib ...@@ -302,7 +291,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is identical to the parallel_for() routine defined above except - This function is identical to the parallel_for() routine defined above except
that it will print messages to cout showing the progress in executing the that it will print messages to cout showing the progress in executing the
...@@ -323,7 +311,6 @@ namespace dlib ...@@ -323,7 +311,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is identical to the parallel_for() routine defined above except - This function is identical to the parallel_for() routine defined above except
that it will print messages to cout showing the progress in executing the that it will print messages to cout showing the progress in executing the
...@@ -344,7 +331,6 @@ namespace dlib ...@@ -344,7 +331,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is identical to the parallel_for() routine defined above except - This function is identical to the parallel_for() routine defined above except
that it will print messages to cout showing the progress in executing the that it will print messages to cout showing the progress in executing the
...@@ -364,7 +350,6 @@ namespace dlib ...@@ -364,7 +350,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is identical to the parallel_for() routine defined above except - This function is identical to the parallel_for() routine defined above except
that it will print messages to cout showing the progress in executing the that it will print messages to cout showing the progress in executing the
...@@ -388,7 +373,6 @@ namespace dlib ...@@ -388,7 +373,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is identical to the parallel_for_blocked() routine defined - This function is identical to the parallel_for_blocked() routine defined
above except that it will print messages to cout showing the progress in above except that it will print messages to cout showing the progress in
...@@ -410,7 +394,6 @@ namespace dlib ...@@ -410,7 +394,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is identical to the parallel_for_blocked() routine defined - This function is identical to the parallel_for_blocked() routine defined
above except that it will print messages to cout showing the progress in above except that it will print messages to cout showing the progress in
...@@ -431,7 +414,6 @@ namespace dlib ...@@ -431,7 +414,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is identical to the parallel_for_blocked() routine defined - This function is identical to the parallel_for_blocked() routine defined
above except that it will print messages to cout showing the progress in above except that it will print messages to cout showing the progress in
...@@ -452,7 +434,6 @@ namespace dlib ...@@ -452,7 +434,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is identical to the parallel_for_blocked() routine defined - This function is identical to the parallel_for_blocked() routine defined
above except that it will print messages to cout showing the progress in above except that it will print messages to cout showing the progress in
...@@ -472,7 +453,6 @@ namespace dlib ...@@ -472,7 +453,6 @@ namespace dlib
requires requires
- begin <= end - begin <= end
- chunks_per_thread > 0 - chunks_per_thread > 0
- funct does not throw any exceptions
ensures ensures
- This function is identical to the parallel_for_blocked() routine defined - This function is identical to the parallel_for_blocked() routine defined
above except that it will print messages to cout showing the progress in above except that it will print messages to cout showing the progress in
......
...@@ -61,6 +61,11 @@ namespace dlib ...@@ -61,6 +61,11 @@ namespace dlib
} }
wait(); wait();
// Throw any unhandled exceptions. Since shutdown_pool() is only called in the
// destructor this will kill the program.
for (auto&& task : tasks)
task.propagate_exception();
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -94,6 +99,9 @@ namespace dlib ...@@ -94,6 +99,9 @@ namespace dlib
const unsigned long idx = task_id_to_index(task_id); const unsigned long idx = task_id_to_index(task_id);
while (tasks[idx].task_id == task_id) while (tasks[idx].task_id == task_id)
task_done_signaler.wait(); task_done_signaler.wait();
for (auto&& task : tasks)
task.propagate_exception();
} }
} }
...@@ -124,6 +132,10 @@ namespace dlib ...@@ -124,6 +132,10 @@ namespace dlib
if (found_task) if (found_task)
task_done_signaler.wait(); task_done_signaler.wait();
} }
// throw any exceptions generated by the tasks
for (auto&& task : tasks)
task.propagate_exception();
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -177,6 +189,9 @@ namespace dlib ...@@ -177,6 +189,9 @@ namespace dlib
task = tasks[idx]; task = tasks[idx];
} }
std::exception_ptr eptr;
try
{
// now do the task // now do the task
if (task.bfp) if (task.bfp)
task.bfp(); task.bfp();
...@@ -186,6 +201,11 @@ namespace dlib ...@@ -186,6 +201,11 @@ namespace dlib
task.mfp1(task.arg1); task.mfp1(task.arg1);
else if (task.mfp2) else if (task.mfp2)
task.mfp2(task.arg1, task.arg2); task.mfp2(task.arg1, task.arg2);
}
catch(...)
{
eptr = std::current_exception();
}
// Now let others know that we finished the task. We do this // Now let others know that we finished the task. We do this
// by clearing out the state of this task // by clearing out the state of this task
...@@ -198,6 +218,7 @@ namespace dlib ...@@ -198,6 +218,7 @@ namespace dlib
tasks[idx].mfp2.clear(); tasks[idx].mfp2.clear();
tasks[idx].arg1 = 0; tasks[idx].arg1 = 0;
tasks[idx].arg2 = 0; tasks[idx].arg2 = 0;
tasks[idx].eptr = eptr;
task_done_signaler.broadcast(); task_done_signaler.broadcast();
} }
...@@ -210,6 +231,9 @@ namespace dlib ...@@ -210,6 +231,9 @@ namespace dlib
find_empty_task_slot ( find_empty_task_slot (
) const ) const
{ {
for (auto&& task : tasks)
task.propagate_exception();
for (unsigned long i = 0; i < tasks.size(); ++i) for (unsigned long i = 0; i < tasks.size(); ++i)
{ {
if (tasks[i].is_empty()) if (tasks[i].is_empty())
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "../array.h" #include "../array.h"
#include "../smart_pointers_thread_safe.h" #include "../smart_pointers_thread_safe.h"
#include "../smart_pointers.h" #include "../smart_pointers.h"
#include <exception>
namespace dlib namespace dlib
{ {
...@@ -451,6 +452,17 @@ namespace dlib ...@@ -451,6 +452,17 @@ namespace dlib
bfp_type bfp; bfp_type bfp;
shared_ptr<function_object_copy> function_copy; shared_ptr<function_object_copy> function_copy;
mutable std::exception_ptr eptr; // non-null if the task threw an exception
void propagate_exception() const
{
if (eptr)
{
auto tmp = eptr;
eptr = nullptr;
std::rethrow_exception(tmp);
}
}
}; };
......
...@@ -225,9 +225,11 @@ namespace dlib ...@@ -225,9 +225,11 @@ namespace dlib
such as mutex objects. such as mutex objects.
EXCEPTIONS EXCEPTIONS
Note that if an exception is thrown inside a task thread and Note that if an exception is thrown inside a task thread and is not caught
is not caught then the normal rule for uncaught exceptions in then the exception will be trapped inside the thread pool and rethrown at a
threads applies. That is, the application will be terminated. later time when someone calls one of the add task or wait member functions
of the thread pool. This allows exceptions to propagate out of task threads
and into the calling code where they can be handled.
!*/ !*/
public: public:
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment