Commit 3e99b2f5 authored by Ashwin Bharambe's avatar Ashwin Bharambe Committed by Ashwin Bharambe

Introduce DetectionModelHelper.GetLossScale()

Summary:
Under a distributed setting (multiple nodes), the loss scale may need
to account for the number of nodes participating in the algorithm. This method
provides a reasonable way for providing this config.

Reviewed By: rbgirshick

Differential Revision: D6826976

fbshipit-source-id: d224200710b0f3c3245851874c9f4f3ab871710f
parent e59c30bb
......@@ -412,7 +412,7 @@ def add_fpn_rpn_losses(model):
'loss_rpn_cls_fpn' + slvl,
normalize=0,
scale=(
1. / cfg.NUM_GPUS / cfg.TRAIN.RPN_BATCH_SIZE_PER_IM /
model.GetLossScale() / cfg.TRAIN.RPN_BATCH_SIZE_PER_IM /
cfg.TRAIN.IMS_PER_BATCH
)
)
......@@ -427,7 +427,7 @@ def add_fpn_rpn_losses(model):
],
'loss_rpn_bbox_fpn' + slvl,
beta=1. / 9.,
scale=1. / cfg.NUM_GPUS
scale=model.GetLossScale(),
)
loss_gradients.update(
blob_utils.
......
......@@ -471,6 +471,13 @@ class DetectionModelHelper(cnn.CNNModelHelper):
scale=correction)
workspace.RunOperatorOnce(op)
def GetLossScale(self):
"""Allow a way to configure the loss scale dynamically.
This may be used in a distributed data parallel setting.
"""
return 1.0 / cfg.NUM_GPUS
def AddLosses(self, losses):
if not isinstance(losses, list):
losses = [losses]
......
......@@ -70,7 +70,7 @@ def add_fast_rcnn_losses(model):
"""Add losses for RoI classification and bounding box regression."""
cls_prob, loss_cls = model.net.SoftmaxWithLoss(
['cls_score', 'labels_int32'], ['cls_prob', 'loss_cls'],
scale=1. / cfg.NUM_GPUS
scale=model.GetLossScale()
)
loss_bbox = model.net.SmoothL1Loss(
[
......@@ -78,7 +78,7 @@ def add_fast_rcnn_losses(model):
'bbox_outside_weights'
],
'loss_bbox',
scale=1. / cfg.NUM_GPUS
scale=model.GetLossScale()
)
loss_gradients = blob_utils.get_loss_gradients(model, [loss_cls, loss_bbox])
model.Accuracy(['cls_prob', 'labels_int32'], 'accuracy_cls')
......
......@@ -96,7 +96,7 @@ def add_mask_rcnn_losses(model, blob_mask):
loss_mask = model.net.SigmoidCrossEntropyLoss(
[blob_mask, 'masks_int32'],
'loss_mask',
scale=1. / cfg.NUM_GPUS * cfg.MRCNN.WEIGHT_LOSS_MASK
scale=model.GetLossScale() * cfg.MRCNN.WEIGHT_LOSS_MASK
)
loss_gradients = blob_utils.get_loss_gradients(model, [loss_mask])
model.AddLosses('loss_mask')
......
......@@ -266,7 +266,7 @@ def add_fpn_retinanet_losses(model):
],
'retnet_loss_bbox_' + suffix,
beta=cfg.RETINANET.BBOX_REG_BETA,
scale=1. / cfg.NUM_GPUS * cfg.RETINANET.BBOX_REG_WEIGHT
scale=model.GetLossScale() * cfg.RETINANET.BBOX_REG_WEIGHT
)
gradients.append(bbox_loss)
losses.append('retnet_loss_bbox_' + suffix)
......@@ -286,7 +286,7 @@ def add_fpn_retinanet_losses(model):
['fl_{}'.format(suffix)],
gamma=cfg.RETINANET.LOSS_GAMMA,
alpha=cfg.RETINANET.LOSS_ALPHA,
scale=(1. / cfg.NUM_GPUS)
scale=model.GetLossScale()
)
gradients.append(cls_focal_loss)
losses.append('fl_{}'.format(suffix))
......@@ -299,7 +299,7 @@ def add_fpn_retinanet_losses(model):
['fl_{}'.format(suffix), 'retnet_prob_{}'.format(suffix)],
gamma=cfg.RETINANET.LOSS_GAMMA,
alpha=cfg.RETINANET.LOSS_ALPHA,
scale=(1. / cfg.NUM_GPUS),
scale=model.GetLossScale(),
)
gradients.append(cls_focal_loss)
losses.append('fl_{}'.format(suffix))
......
......@@ -136,7 +136,7 @@ def add_single_scale_rpn_losses(model):
loss_rpn_cls = model.net.SigmoidCrossEntropyLoss(
['rpn_cls_logits', 'rpn_labels_int32'],
'loss_rpn_cls',
scale=1. / cfg.NUM_GPUS
scale=model.GetLossScale()
)
loss_rpn_bbox = model.net.SmoothL1Loss(
[
......@@ -145,7 +145,7 @@ def add_single_scale_rpn_losses(model):
],
'loss_rpn_bbox',
beta=1. / 9.,
scale=1. / cfg.NUM_GPUS
scale=model.GetLossScale()
)
loss_gradients = blob_utils.get_loss_gradients(
model, [loss_rpn_cls, loss_rpn_bbox]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment