Commit 1ecd603b authored by Peizhao Zhang, committed by Facebook Github Bot

Potential fix for training getting stuck when the data loader fails. (#638)

Summary:
Pull Request resolved: https://github.com/facebookresearch/Detectron/pull/638

Potential fix for training getting stuck when the data loader fails.

Reviewed By: rbgirshick

Differential Revision: D9513621

fbshipit-source-id: 123974eac83f40ef2f582a90fedea790fdc442d1
parent c9ed587c
@@ -241,6 +241,9 @@ class RoIDataLoader(object):
                 self.shutdown()
                 break
 
+    def should_stop(self):
+        return self.coordinator.should_stop()
+
     def shutdown(self):
         self.coordinator.request_stop()
         self.coordinator.wait_for_stop()
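For context, the new `RoIDataLoader.should_stop()` only surfaces the loader coordinator's stop flag to callers. Below is a minimal sketch of that coordinator pattern, assuming an Event-backed stop flag; it is an illustration, not Detectron's actual `Coordinator` class:

```python
import threading


class MinimalCoordinator(object):
    """Illustrative stand-in for the loader coordinator: a shared stop
    flag that failing worker threads set and the trainer polls."""

    def __init__(self):
        self._stop_event = threading.Event()

    def request_stop(self):
        # Called by a failing worker (or during shutdown) to signal stop.
        self._stop_event.set()

    def should_stop(self):
        # Non-blocking check; this is what the new should_stop() exposes.
        return self._stop_event.is_set()

    def wait_for_stop(self):
        # Block until the stop flag has been set.
        self._stop_event.wait()
```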
@@ -50,7 +50,6 @@ import detectron.utils.net as nu
 
 def train_model():
     """Model training loop."""
-    logger = logging.getLogger(__name__)
     model, weights_file, start_iter, checkpoints, output_dir = create_model()
     if 'final' in checkpoints:
         # The final model was found in the output directory, so nothing to do
@@ -61,6 +60,8 @@ def train_model():
     CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)
 
     for cur_iter in range(start_iter, cfg.SOLVER.MAX_ITER):
+        if model.roi_data_loader.should_stop():
+            handle_critical_error(model, 'roi_data_loader failed')
         training_stats.IterTic()
         lr = model.UpdateWorkspaceLr(cur_iter, lr_policy.get_lr_at_iter(cur_iter))
         workspace.RunNet(model.net.Proto().name)
@@ -82,9 +83,7 @@ def train_model():
             training_stats.ResetIterTimer()
 
         if np.isnan(training_stats.iter_total_loss):
-            logger.critical('Loss is NaN, exiting...')
-            model.roi_data_loader.shutdown()
-            envu.exit_on_error()
+            handle_critical_error(model, 'Loss is NaN')
 
     # Save the final model
     checkpoints['final'] = os.path.join(output_dir, 'model_final.pkl')
@@ -94,6 +93,13 @@ def train_model():
     return checkpoints
 
 
+def handle_critical_error(model, msg):
+    logger = logging.getLogger(__name__)
+    logger.critical(msg)
+    model.roi_data_loader.shutdown()
+    raise Exception(msg)
+
+
 def create_model():
     """Build the model and look for saved model checkpoints in case we can
     resume from one.
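Taken together, the change makes training fail fast instead of hanging: every iteration first polls the loader's stop flag, and both failure paths (a dead loader, a NaN loss) funnel through one routine that logs, shuts the loader down, and raises. A hedged sketch of that control flow with generic, hypothetical names (`run_one_iter`, `train_loop`), not the actual Detectron training code:

```python
import logging
import math

logger = logging.getLogger(__name__)


def handle_critical_error(loader, msg):
    """Log, stop the data loader threads, then raise so training exits
    instead of blocking forever on a dead minibatch queue."""
    logger.critical(msg)
    loader.shutdown()
    raise RuntimeError(msg)


def train_loop(loader, run_one_iter, max_iter):
    """Illustrative loop: check the loader before every step."""
    for cur_iter in range(max_iter):
        if loader.should_stop():
            # A loader worker hit an exception and requested a stop;
            # bail out rather than waiting on a batch that never arrives.
            handle_critical_error(loader, 'data loader failed')
        loss = run_one_iter(cur_iter)
        if math.isnan(loss):
            handle_critical_error(loader, 'Loss is NaN')
```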