Commit 9cb251d3 authored by zimenglan's avatar zimenglan Committed by Francisco Massa

agnostic-regression for bbox (#390)

* make pixel indexes 0-based for bounding box in pascal voc dataset

* replacing all instances of torch.distributed.deprecated with torch.distributed

* replacing all instances of torch.distributed.deprecated with torch.distributed

* add GroupNorm

* add GroupNorm -- sort out yaml files

* use torch.nn.GroupNorm instead, replace 'use_gn' with 'conv_block' and use 'BaseStem'&'Bottleneck' to simply codes

* modification on 'group_norm' and 'conv_with_kaiming_uniform' function

* modification on yaml files in configs/gn_baselines/ and reduce the amount of indentation and code duplication

* use 'kaiming_uniform' to initialize resnet, disable gn after fc layer, and add dilation into ResNetHead

* agnostic-regression for bbox
parent 519e8dd4
...@@ -25,6 +25,7 @@ _C.MODEL.RPN_ONLY = False ...@@ -25,6 +25,7 @@ _C.MODEL.RPN_ONLY = False
_C.MODEL.MASK_ON = False _C.MODEL.MASK_ON = False
_C.MODEL.DEVICE = "cuda" _C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN" _C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
_C.MODEL.CLS_AGNOSTIC_BBOX_REG = False
# If the WEIGHT starts with a catalog://, like :R-50, the code will look for # If the WEIGHT starts with a catalog://, like :R-50, the code will look for
# the path in paths_catalog. Else, it will use it as the specified absolute # the path in paths_catalog. Else, it will use it as the specified absolute
......
...@@ -17,7 +17,12 @@ class PostProcessor(nn.Module): ...@@ -17,7 +17,12 @@ class PostProcessor(nn.Module):
""" """
def __init__( def __init__(
self, score_thresh=0.05, nms=0.5, detections_per_img=100, box_coder=None self,
score_thresh=0.05,
nms=0.5,
detections_per_img=100,
box_coder=None,
cls_agnostic_bbox_reg=False
): ):
""" """
Arguments: Arguments:
...@@ -33,6 +38,7 @@ class PostProcessor(nn.Module): ...@@ -33,6 +38,7 @@ class PostProcessor(nn.Module):
if box_coder is None: if box_coder is None:
box_coder = BoxCoder(weights=(10., 10., 5., 5.)) box_coder = BoxCoder(weights=(10., 10., 5., 5.))
self.box_coder = box_coder self.box_coder = box_coder
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
def forward(self, x, boxes): def forward(self, x, boxes):
""" """
...@@ -54,9 +60,13 @@ class PostProcessor(nn.Module): ...@@ -54,9 +60,13 @@ class PostProcessor(nn.Module):
boxes_per_image = [len(box) for box in boxes] boxes_per_image = [len(box) for box in boxes]
concat_boxes = torch.cat([a.bbox for a in boxes], dim=0) concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)
if self.cls_agnostic_bbox_reg:
box_regression = box_regression[:, -4:]
proposals = self.box_coder.decode( proposals = self.box_coder.decode(
box_regression.view(sum(boxes_per_image), -1), concat_boxes box_regression.view(sum(boxes_per_image), -1), concat_boxes
) )
if self.cls_agnostic_bbox_reg:
proposals = proposals.repeat(1, class_prob.shape[1])
num_classes = class_prob.shape[1] num_classes = class_prob.shape[1]
...@@ -145,8 +155,13 @@ def make_roi_box_post_processor(cfg): ...@@ -145,8 +155,13 @@ def make_roi_box_post_processor(cfg):
score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH
nms_thresh = cfg.MODEL.ROI_HEADS.NMS nms_thresh = cfg.MODEL.ROI_HEADS.NMS
detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG
cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG
postprocessor = PostProcessor( postprocessor = PostProcessor(
score_thresh, nms_thresh, detections_per_img, box_coder score_thresh,
nms_thresh,
detections_per_img,
box_coder,
cls_agnostic_bbox_reg
) )
return postprocessor return postprocessor
...@@ -18,7 +18,13 @@ class FastRCNNLossComputation(object): ...@@ -18,7 +18,13 @@ class FastRCNNLossComputation(object):
Also supports FPN Also supports FPN
""" """
def __init__(self, proposal_matcher, fg_bg_sampler, box_coder): def __init__(
self,
proposal_matcher,
fg_bg_sampler,
box_coder,
cls_agnostic_bbox_reg=False
):
""" """
Arguments: Arguments:
proposal_matcher (Matcher) proposal_matcher (Matcher)
...@@ -28,6 +34,7 @@ class FastRCNNLossComputation(object): ...@@ -28,6 +34,7 @@ class FastRCNNLossComputation(object):
self.proposal_matcher = proposal_matcher self.proposal_matcher = proposal_matcher
self.fg_bg_sampler = fg_bg_sampler self.fg_bg_sampler = fg_bg_sampler
self.box_coder = box_coder self.box_coder = box_coder
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
def match_targets_to_proposals(self, proposal, target): def match_targets_to_proposals(self, proposal, target):
match_quality_matrix = boxlist_iou(target, proposal) match_quality_matrix = boxlist_iou(target, proposal)
...@@ -143,7 +150,11 @@ class FastRCNNLossComputation(object): ...@@ -143,7 +150,11 @@ class FastRCNNLossComputation(object):
# advanced indexing # advanced indexing
sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1) sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
labels_pos = labels[sampled_pos_inds_subset] labels_pos = labels[sampled_pos_inds_subset]
map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3], device=device) if self.cls_agnostic_bbox_reg:
map_inds = torch.tensor([4, 5, 6, 7], device=device)
else:
map_inds = 4 * labels_pos[:, None] + torch.tensor(
[0, 1, 2, 3], device=device)
box_loss = smooth_l1_loss( box_loss = smooth_l1_loss(
box_regression[sampled_pos_inds_subset[:, None], map_inds], box_regression[sampled_pos_inds_subset[:, None], map_inds],
...@@ -170,6 +181,13 @@ def make_roi_box_loss_evaluator(cfg): ...@@ -170,6 +181,13 @@ def make_roi_box_loss_evaluator(cfg):
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION
) )
loss_evaluator = FastRCNNLossComputation(matcher, fg_bg_sampler, box_coder) cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG
loss_evaluator = FastRCNNLossComputation(
matcher,
fg_bg_sampler,
box_coder,
cls_agnostic_bbox_reg
)
return loss_evaluator return loss_evaluator
...@@ -14,7 +14,8 @@ class FastRCNNPredictor(nn.Module): ...@@ -14,7 +14,8 @@ class FastRCNNPredictor(nn.Module):
num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES num_classes = config.MODEL.ROI_BOX_HEAD.NUM_CLASSES
self.avgpool = nn.AvgPool2d(kernel_size=7, stride=7) self.avgpool = nn.AvgPool2d(kernel_size=7, stride=7)
self.cls_score = nn.Linear(num_inputs, num_classes) self.cls_score = nn.Linear(num_inputs, num_classes)
self.bbox_pred = nn.Linear(num_inputs, num_classes * 4) num_bbox_reg_classes = 2 if cfg.MODEL.CLS_AGNOSTIC_BBOX_REG else num_classes
self.bbox_pred = nn.Linear(num_inputs, num_bbox_reg_classes * 4)
nn.init.normal_(self.cls_score.weight, mean=0, std=0.01) nn.init.normal_(self.cls_score.weight, mean=0, std=0.01)
nn.init.constant_(self.cls_score.bias, 0) nn.init.constant_(self.cls_score.bias, 0)
...@@ -37,7 +38,8 @@ class FPNPredictor(nn.Module): ...@@ -37,7 +38,8 @@ class FPNPredictor(nn.Module):
representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM
self.cls_score = nn.Linear(representation_size, num_classes) self.cls_score = nn.Linear(representation_size, num_classes)
self.bbox_pred = nn.Linear(representation_size, num_classes * 4) num_bbox_reg_classes = 2 if cfg.MODEL.CLS_AGNOSTIC_BBOX_REG else num_classes
self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4)
nn.init.normal_(self.cls_score.weight, std=0.01) nn.init.normal_(self.cls_score.weight, std=0.01)
nn.init.normal_(self.bbox_pred.weight, std=0.001) nn.init.normal_(self.bbox_pred.weight, std=0.001)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment