Commit 228de4ea authored by Ashwin Bharambe, committed by Facebook Github Bot

Allow configuration to only build forward pass

Summary:
Generalizes model construction so that the forward pass and the param update
operations can be built separately.

Reviewed By: rbgirshick

Differential Revision: D6827027

fbshipit-source-id: 9d2ace349b4ebdf8993baa2f2c1529e2c5ce751d
parent c08e0703
@@ -104,9 +104,14 @@ def retinanet(model):
 # ---------------------------------------------------------------------------- #
 # Helper functions for building various re-usable network bits
 # ---------------------------------------------------------------------------- #
-def create(model_type_func, train=False):
+def create(model_type_func, train=False, gpu_id=0):
     """Generic model creation function that dispatches to specific model
     building functions.
+
+    By default, this function will generate a data parallel model configured to
+    run on cfg.NUM_GPUS devices. However, you can restrict it to build a model
+    targeted to a specific GPU by specifying gpu_id. This is used by
+    optimizer.build_data_parallel_model() during test time.
     """
     model = DetectionModelHelper(
         name=model_type_func,
@@ -114,6 +119,8 @@ def create(model_type_func, train=False):
         num_classes=cfg.MODEL.NUM_CLASSES,
         init_params=train
     )
+    model.only_build_forward_pass = False
+    model.target_gpu_id = gpu_id
     return get_func(model_type_func)(model)
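For context, a hedged sketch of the test-time pattern the new gpu_id argument supports: one inference process per GPU, each building its model on its own device. The worker harness below is illustrative and not part of this commit; it assumes Detectron's lib/ modules (core.config, modeling.model_builder) are on the path.

# Hypothetical per-GPU inference workers; only create(..., gpu_id=...) is
# from this commit, the rest is an assumed harness.
from multiprocessing import Process

from core.config import cfg
from modeling import model_builder

def _inference_worker(gpu_id):
    # train=False builds the test-time graph; gpu_id routes the build
    # through NamedCudaScope(model.target_gpu_id) in optimizer.py
    model = model_builder.create(cfg.MODEL.TYPE, train=False, gpu_id=gpu_id)
    # ... load weights and run inference with `model` on this GPU ...

if __name__ == '__main__':
    workers = [Process(target=_inference_worker, args=(i,))
               for i in range(cfg.NUM_GPUS)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()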
@@ -34,7 +34,9 @@ def build_data_parallel_model(model, single_gpu_build_func):
     """Build a data parallel model given a function that builds the model on a
     single GPU.
     """
-    if model.train:
+    if model.only_build_forward_pass:
+        single_gpu_build_func(model)
+    elif model.train:
         all_loss_gradients = _build_forward_graph(model, single_gpu_build_func)
         # Add backward pass on all GPUs
         model.AddGradientOperators(all_loss_gradients)
@@ -43,11 +45,12 @@ def build_data_parallel_model(model, single_gpu_build_func):
         for gpu_id in range(cfg.NUM_GPUS):
             # After allreduce, all GPUs perform SGD updates on their identical
             # params and gradients in parallel
-            _add_parameter_update_ops(model, gpu_id)
+            with c2_utils.NamedCudaScope(gpu_id):
+                add_single_gpu_param_update_ops(model, gpu_id)
     else:
         # Test-time network operates on single GPU
         # Test-time parallelism is implemented through multiprocessing
-        with c2_utils.NamedCudaScope(0):
+        with c2_utils.NamedCudaScope(model.target_gpu_id):
             single_gpu_build_func(model)
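To make the decoupling concrete, a hedged sketch of the two-phase build the new flag permits: forward graph first, gradients and per-GPU SGD ops afterwards. The driver function and module paths below are assumptions; only only_build_forward_pass, add_single_gpu_param_update_ops, and NamedCudaScope come from this diff.

# Hypothetical two-phase driver (a sketch, not code from this commit)
from core.config import cfg
from modeling import optimizer
import utils.c2 as c2_utils

def build_in_two_phases(model, single_gpu_build_func, loss_gradients):
    # Phase 1: construct forward ops only; no backward pass, no SGD ops
    model.only_build_forward_pass = True
    optimizer.build_data_parallel_model(model, single_gpu_build_func)

    # Phase 2: attach the backward pass, then per-GPU parameter updates
    model.AddGradientOperators(loss_gradients)
    for gpu_id in range(cfg.NUM_GPUS):
        with c2_utils.NamedCudaScope(gpu_id):
            optimizer.add_single_gpu_param_update_ops(model, gpu_id)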
@@ -84,9 +87,7 @@ def _add_allreduce_graph(model):
     muji.Allreduce(model.net, gradients, reduced_affix='')
 
 
-def _add_parameter_update_ops(model, gpu_id):
-    """Construct the optimizer update op graph."""
-    with c2_utils.NamedCudaScope(gpu_id):
+def add_single_gpu_param_update_ops(model, gpu_id):
     # Learning rate of 0 is a dummy value to be set properly at the
     # start of training
     lr = model.param_init_net.ConstantFill(
@@ -98,9 +99,8 @@ def _add_parameter_update_ops(model, gpu_id):
     wd = model.param_init_net.ConstantFill(
         [], 'wd', shape=[1], value=cfg.SOLVER.WEIGHT_DECAY
     )
     for param in model.TrainableParams(gpu_id=gpu_id):
-        logger.info('param ' + str(param) + ' will be updated')
+        logger.debug('param ' + str(param) + ' will be updated')
         param_grad = model.param_to_grad[param]
         # Initialize momentum vector
         param_momentum = model.param_init_net.ConstantFill(
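The hunk is cut off above. For orientation, a hedged sketch of how such a loop body typically finishes in Caffe2 (this continuation is assumed, not taken from the commit): weight decay is folded into the gradient, then MomentumSGDUpdate performs an in-place update of gradient, momentum, and parameter.

# Assumed continuation (not from this diff): 'one' would be created once,
# next to lr and wd, as the unit weight for the gradient term
one = model.param_init_net.ConstantFill([], 'one', shape=[1], value=1.0)

# Inside the TrainableParams loop:
# param_grad <- 1 * param_grad + wd * param  (L2 weight decay)
model.WeightedSum([param_grad, one, param, wd], param_grad)
# In-place momentum SGD step updating grad, momentum, and param together
model.net.MomentumSGDUpdate(
    [param_grad, param_momentum, lr, param],
    [param_grad, param_momentum, param],
    momentum=cfg.SOLVER.MOMENTUM
)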