Commit 1624f810 authored by Moritz Kampelmuehler's avatar Moritz Kampelmuehler Committed by Facebook Github Bot

Fix Retinanet training bug with num_classes != 81

Summary:
Fixed bug preventing RetinaNet training with num_classes != 81 from converging.

Resolves #93
Closes https://github.com/facebookresearch/Detectron/pull/176

Reviewed By: ir413

Differential Revision: D7056134

Pulled By: rbgirshick

fbshipit-source-id: d3c62ce98a23a09191dd2720bc7189ee5490635d
parent a22302de
......@@ -286,7 +286,8 @@ def add_fpn_retinanet_losses(model):
['fl_{}'.format(suffix)],
gamma=cfg.RETINANET.LOSS_GAMMA,
alpha=cfg.RETINANET.LOSS_ALPHA,
scale=model.GetLossScale()
scale=model.GetLossScale(),
num_classes=model.num_classes - 1
)
gradients.append(cls_focal_loss)
losses.append('fl_{}'.format(suffix))
......@@ -300,6 +301,7 @@ def add_fpn_retinanet_losses(model):
gamma=cfg.RETINANET.LOSS_GAMMA,
alpha=cfg.RETINANET.LOSS_ALPHA,
scale=model.GetLossScale(),
num_classes=model.num_classes
)
gradients.append(cls_focal_loss)
losses.append('fl_{}'.format(suffix))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment