Commit b318c3ec authored by Rodrigo Berriel's avatar Rodrigo Berriel Committed by Francisco Massa

Add a switch for POST_NMS per batch/image during training (#695)

parent 4466eb5a
...@@ -165,6 +165,9 @@ _C.MODEL.RPN.MIN_SIZE = 0 ...@@ -165,6 +165,9 @@ _C.MODEL.RPN.MIN_SIZE = 0
# all FPN levels # all FPN levels
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000 _C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000 _C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000
# Apply the post NMS per batch (default) or per image during training
# (default is True to be consistent with Detectron, see Issue #672)
_C.MODEL.RPN.FPN_POST_NMS_PER_BATCH = True
# Custom rpn head, empty to use default conv or separable conv # Custom rpn head, empty to use default conv or separable conv
_C.MODEL.RPN.RPN_HEAD = "SingleConvRPNHead" _C.MODEL.RPN.RPN_HEAD = "SingleConvRPNHead"
......
...@@ -24,6 +24,7 @@ class RPNPostProcessor(torch.nn.Module): ...@@ -24,6 +24,7 @@ class RPNPostProcessor(torch.nn.Module):
min_size, min_size,
box_coder=None, box_coder=None,
fpn_post_nms_top_n=None, fpn_post_nms_top_n=None,
fpn_post_nms_per_batch=True,
): ):
""" """
Arguments: Arguments:
...@@ -47,6 +48,7 @@ class RPNPostProcessor(torch.nn.Module): ...@@ -47,6 +48,7 @@ class RPNPostProcessor(torch.nn.Module):
if fpn_post_nms_top_n is None: if fpn_post_nms_top_n is None:
fpn_post_nms_top_n = post_nms_top_n fpn_post_nms_top_n = post_nms_top_n
self.fpn_post_nms_top_n = fpn_post_nms_top_n self.fpn_post_nms_top_n = fpn_post_nms_top_n
self.fpn_post_nms_per_batch = fpn_post_nms_per_batch
def add_gt_proposals(self, proposals, targets): def add_gt_proposals(self, proposals, targets):
""" """
...@@ -154,9 +156,9 @@ class RPNPostProcessor(torch.nn.Module): ...@@ -154,9 +156,9 @@ class RPNPostProcessor(torch.nn.Module):
# different behavior during training and during testing: # different behavior during training and during testing:
# during training, post_nms_top_n is over *all* the proposals combined, while # during training, post_nms_top_n is over *all* the proposals combined, while
# during testing, it is over the proposals for each image # during testing, it is over the proposals for each image
# TODO resolve this difference and make it consistent. It should be per image, # NOTE: it should be per image, and not per batch. However, to be consistent
# and not per batch # with Detectron, the default is per batch (see Issue #672)
if self.training: if self.training and self.fpn_post_nms_per_batch:
objectness = torch.cat( objectness = torch.cat(
[boxlist.get_field("objectness") for boxlist in boxlists], dim=0 [boxlist.get_field("objectness") for boxlist in boxlists], dim=0
) )
...@@ -189,6 +191,7 @@ def make_rpn_postprocessor(config, rpn_box_coder, is_train): ...@@ -189,6 +191,7 @@ def make_rpn_postprocessor(config, rpn_box_coder, is_train):
if not is_train: if not is_train:
pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TEST pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TEST
post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TEST post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TEST
fpn_post_nms_per_batch = config.MODEL.RPN.FPN_POST_NMS_PER_BATCH
nms_thresh = config.MODEL.RPN.NMS_THRESH nms_thresh = config.MODEL.RPN.NMS_THRESH
min_size = config.MODEL.RPN.MIN_SIZE min_size = config.MODEL.RPN.MIN_SIZE
box_selector = RPNPostProcessor( box_selector = RPNPostProcessor(
...@@ -198,5 +201,6 @@ def make_rpn_postprocessor(config, rpn_box_coder, is_train): ...@@ -198,5 +201,6 @@ def make_rpn_postprocessor(config, rpn_box_coder, is_train):
min_size=min_size, min_size=min_size,
box_coder=rpn_box_coder, box_coder=rpn_box_coder,
fpn_post_nms_top_n=fpn_post_nms_top_n, fpn_post_nms_top_n=fpn_post_nms_top_n,
fpn_post_nms_per_batch=fpn_post_nms_per_batch,
) )
return box_selector return box_selector
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment