Commit 9b41e84b, authored by Ross Girshick, committed by Facebook Github Bot

Do not mutate cfg.TRAIN.WEIGHTS

Reviewed By: ashwinb

Differential Revision: D7148419

fbshipit-source-id: c8990b7db3c48ea76b665cbda1b9c0985096920c
parent a026d775
......@@ -125,12 +125,12 @@ def main():
def train_model():
"""Model training loop."""
logger = logging.getLogger(__name__)
model, start_iter, checkpoints, output_dir = create_model()
model, weights_file, start_iter, checkpoints, output_dir = create_model()
if 'final' in checkpoints:
# The final model was found in the output directory, so nothing to do
return checkpoints
setup_model_for_training(model, output_dir)
setup_model_for_training(model, weights_file, output_dir)
training_stats = TrainingStats(model)
CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)
......@@ -176,6 +176,7 @@ def create_model():
start_iter = 0
checkpoints = {}
output_dir = get_output_dir(cfg.TRAIN.DATASETS, training=True)
weights_file = cfg.TRAIN.WEIGHTS
if cfg.TRAIN.AUTO_RESUME:
# Check for the final model (indicates training already finished)
final_path = os.path.join(output_dir, 'model_final.pkl')
......@@ -196,10 +197,10 @@ def create_model():
if start_iter > 0:
# Override the initialization weights with the found checkpoint
cfg.TRAIN.WEIGHTS = os.path.join(output_dir, resume_weights_file)
weights_file = os.path.join(output_dir, resume_weights_file)
logger.info(
'========> Resuming from checkpoint {} at start iter {}'.
format(cfg.TRAIN.WEIGHTS, start_iter)
format(weights_file, start_iter)
)
logger.info('Building model: {}'.format(cfg.MODEL.TYPE))
......@@ -208,7 +209,7 @@ def create_model():
optimize_memory(model)
# Performs random weight initialization as defined by the model
workspace.RunNetOnce(model.param_init_net)
return model, start_iter, checkpoints, output_dir
return model, weights_file, start_iter, checkpoints, output_dir
def optimize_memory(model):
......@@ -225,14 +226,14 @@ def optimize_memory(model):
)
def setup_model_for_training(model, output_dir):
def setup_model_for_training(model, weights_file, output_dir):
"""Loaded saved weights and create the network in the C2 workspace."""
logger = logging.getLogger(__name__)
add_model_training_inputs(model)
if cfg.TRAIN.WEIGHTS:
if weights_file:
# Override random weight initialization with weights from a saved model
nu.initialize_gpu_from_weights_file(model, cfg.TRAIN.WEIGHTS, gpu_id=0)
nu.initialize_gpu_from_weights_file(model, weights_file, gpu_id=0)
# Even if we're randomly initializing we still need to synchronize
# parameters across GPUs
nu.broadcast_parameters(model)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment