Commit 9b41e84b authored by Ross Girshick's avatar Ross Girshick Committed by Facebook Github Bot

Do not mutate cfg.TRAIN.WEIGHTS

Reviewed By: ashwinb

Differential Revision: D7148419

fbshipit-source-id: c8990b7db3c48ea76b665cbda1b9c0985096920c
parent a026d775
...@@ -125,12 +125,12 @@ def main(): ...@@ -125,12 +125,12 @@ def main():
def train_model(): def train_model():
"""Model training loop.""" """Model training loop."""
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
model, start_iter, checkpoints, output_dir = create_model() model, weights_file, start_iter, checkpoints, output_dir = create_model()
if 'final' in checkpoints: if 'final' in checkpoints:
# The final model was found in the output directory, so nothing to do # The final model was found in the output directory, so nothing to do
return checkpoints return checkpoints
setup_model_for_training(model, output_dir) setup_model_for_training(model, weights_file, output_dir)
training_stats = TrainingStats(model) training_stats = TrainingStats(model)
CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS) CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)
...@@ -176,6 +176,7 @@ def create_model(): ...@@ -176,6 +176,7 @@ def create_model():
start_iter = 0 start_iter = 0
checkpoints = {} checkpoints = {}
output_dir = get_output_dir(cfg.TRAIN.DATASETS, training=True) output_dir = get_output_dir(cfg.TRAIN.DATASETS, training=True)
weights_file = cfg.TRAIN.WEIGHTS
if cfg.TRAIN.AUTO_RESUME: if cfg.TRAIN.AUTO_RESUME:
# Check for the final model (indicates training already finished) # Check for the final model (indicates training already finished)
final_path = os.path.join(output_dir, 'model_final.pkl') final_path = os.path.join(output_dir, 'model_final.pkl')
...@@ -196,10 +197,10 @@ def create_model(): ...@@ -196,10 +197,10 @@ def create_model():
if start_iter > 0: if start_iter > 0:
# Override the initialization weights with the found checkpoint # Override the initialization weights with the found checkpoint
cfg.TRAIN.WEIGHTS = os.path.join(output_dir, resume_weights_file) weights_file = os.path.join(output_dir, resume_weights_file)
logger.info( logger.info(
'========> Resuming from checkpoint {} at start iter {}'. '========> Resuming from checkpoint {} at start iter {}'.
format(cfg.TRAIN.WEIGHTS, start_iter) format(weights_file, start_iter)
) )
logger.info('Building model: {}'.format(cfg.MODEL.TYPE)) logger.info('Building model: {}'.format(cfg.MODEL.TYPE))
...@@ -208,7 +209,7 @@ def create_model(): ...@@ -208,7 +209,7 @@ def create_model():
optimize_memory(model) optimize_memory(model)
# Performs random weight initialization as defined by the model # Performs random weight initialization as defined by the model
workspace.RunNetOnce(model.param_init_net) workspace.RunNetOnce(model.param_init_net)
return model, start_iter, checkpoints, output_dir return model, weights_file, start_iter, checkpoints, output_dir
def optimize_memory(model): def optimize_memory(model):
...@@ -225,14 +226,14 @@ def optimize_memory(model): ...@@ -225,14 +226,14 @@ def optimize_memory(model):
) )
def setup_model_for_training(model, output_dir): def setup_model_for_training(model, weights_file, output_dir):
"""Loaded saved weights and create the network in the C2 workspace.""" """Loaded saved weights and create the network in the C2 workspace."""
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
add_model_training_inputs(model) add_model_training_inputs(model)
if cfg.TRAIN.WEIGHTS: if weights_file:
# Override random weight initialization with weights from a saved model # Override random weight initialization with weights from a saved model
nu.initialize_gpu_from_weights_file(model, cfg.TRAIN.WEIGHTS, gpu_id=0) nu.initialize_gpu_from_weights_file(model, weights_file, gpu_id=0)
# Even if we're randomly initializing we still need to synchronize # Even if we're randomly initializing we still need to synchronize
# parameters across GPUs # parameters across GPUs
nu.broadcast_parameters(model) nu.broadcast_parameters(model)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment