Commit 2cca4ae7 authored by Davis King's avatar Davis King

Made the dnn_trainer propagate exceptions that happen during training (in its

training thread) out of the object into the calling code rather than
terminating the application.
parent 86fa427e
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include <cstdio> #include <cstdio>
#include <set> #include <set>
#include <future> #include <future>
#include <exception>
#include <mutex>
namespace dlib namespace dlib
{ {
...@@ -132,6 +134,7 @@ namespace dlib ...@@ -132,6 +134,7 @@ namespace dlib
) const ) const
{ {
wait_for_thread_to_pause(); wait_for_thread_to_pause();
propagate_exception();
return net; return net;
} }
...@@ -175,6 +178,7 @@ namespace dlib ...@@ -175,6 +178,7 @@ namespace dlib
) const ) const
{ {
wait_for_thread_to_pause(); wait_for_thread_to_pause();
propagate_exception();
return devices[0]->solvers; return devices[0]->solvers;
} }
...@@ -689,10 +693,12 @@ namespace dlib ...@@ -689,10 +693,12 @@ namespace dlib
} }
} }
} }
catch(std::exception& e) catch(...)
{ {
std::cerr << e.what() << std::endl; // If an exception happens then permanently disable the trainer object.
throw; job_pipe.disable();
std::lock_guard<std::mutex> lock(eptr_mutex);
eptr = std::current_exception();
} }
void wait_for_thread_to_pause() const void wait_for_thread_to_pause() const
...@@ -873,6 +879,7 @@ namespace dlib ...@@ -873,6 +879,7 @@ namespace dlib
label_iterator lbegin label_iterator lbegin
) )
{ {
propagate_exception();
size_t num = std::distance(dbegin, dend); size_t num = std::distance(dbegin, dend);
size_t devs = devices.size(); size_t devs = devices.size();
job.t.resize(devs); job.t.resize(devs);
...@@ -960,6 +967,14 @@ namespace dlib ...@@ -960,6 +967,14 @@ namespace dlib
long lr_schedule_pos; long lr_schedule_pos;
unsigned long gradient_check_budget; unsigned long gradient_check_budget;
std::exception_ptr eptr;
mutable std::mutex eptr_mutex;
void propagate_exception() const
{
std::lock_guard<std::mutex> lock(eptr_mutex);
if (eptr)
std::rethrow_exception(eptr);
}
}; };
......
...@@ -38,6 +38,11 @@ namespace dlib ...@@ -38,6 +38,11 @@ namespace dlib
currently selected (i.e. the one indicated by cudaGetDevice()) when currently selected (i.e. the one indicated by cudaGetDevice()) when
dnn_trainer is constructed. It will continue to use that device even if dnn_trainer is constructed. It will continue to use that device even if
you later change it by a call to cudaSetDevice(). you later change it by a call to cudaSetDevice().
EXCEPTIONS
If an exception is thrown by any part of the neural network during training
then the exception will be propagated out of the trainer to the user.
Moreover, the trainer instance will be unusable and should be destroyed.
!*/ !*/
public: public:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment