Commit 3b27142e authored by zimenglan's avatar zimenglan Committed by Francisco Massa

add GroupNorm (#346)

* make pixel indexes 0-based for bounding box in pascal voc dataset

* replacing all instances of torch.distributed.deprecated with torch.distributed

* replacing all instances of torch.distributed.deprecated with torch.distributed

* add GroupNorm

* add GroupNorm -- sort out yaml files

* use torch.nn.GroupNorm instead, replace 'use_gn' with 'conv_block' and use 'BaseStem' & 'Bottleneck' to simplify the code

* modification on 'group_norm' and 'conv_with_kaiming_uniform' function

* modification on yaml files in configs/gn_baselines/ and reduce the amount of indentation and code duplication
parent abf36b94
INPUT:
MIN_SIZE_TRAIN: 800
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50-GN"
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
RESNETS: # use GN for backbone
TRANS_FUNC: "BottleneckWithGN"
STEM_FUNC: "StemWithGN"
FPN:
USE_GN: True # use GN for FPN
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 512
POSITIVE_FRACTION: 0.25
ROI_BOX_HEAD:
USE_GN: True # use GN for bbox head
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 8 gpus
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
IMS_PER_BATCH: 16
TEST:
IMS_PER_BATCH: 8
INPUT:
MIN_SIZE_TRAIN: 800
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50-GN"
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
RESNETS: # use GN for backbone
TRANS_FUNC: "BottleneckWithGN"
STEM_FUNC: "StemWithGN"
FPN:
USE_GN: True # use GN for FPN
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 512
POSITIVE_FRACTION: 0.25
ROI_BOX_HEAD:
USE_GN: True # use GN for bbox head
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
CONV_HEAD_DIM: 256
NUM_STACKED_CONVS: 4
FEATURE_EXTRACTOR: "FPNXconv1fcFeatureExtractor"
PREDICTOR: "FPNPredictor"
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 8 gpus
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
IMS_PER_BATCH: 16
TEST:
IMS_PER_BATCH: 8
\ No newline at end of file
INPUT:
MIN_SIZE_TRAIN: 800
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50-GN"
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
RESNETS: # use GN for backbone
TRANS_FUNC: "BottleneckWithGN"
STEM_FUNC: "StemWithGN"
FPN:
USE_GN: True # use GN for FPN
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 512
POSITIVE_FRACTION: 0.25
ROI_BOX_HEAD:
USE_GN: True # use GN for bbox head
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
ROI_MASK_HEAD:
USE_GN: True # use GN for mask head
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
CONV_LAYERS: (256, 256, 256, 256)
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
PREDICTOR: "MaskRCNNC4Predictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 28
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 8 gpus
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
IMS_PER_BATCH: 16
TEST:
IMS_PER_BATCH: 8
\ No newline at end of file
INPUT:
MIN_SIZE_TRAIN: 800
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50-GN"
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
RESNETS: # use GN for backbone
TRANS_FUNC: "BottleneckWithGN"
STEM_FUNC: "StemWithGN"
FPN:
USE_GN: True # use GN for FPN
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 512
POSITIVE_FRACTION: 0.25
ROI_BOX_HEAD:
USE_GN: True # use GN for bbox head
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
CONV_HEAD_DIM: 256
NUM_STACKED_CONVS: 4
FEATURE_EXTRACTOR: "FPNXconv1fcFeatureExtractor"
PREDICTOR: "FPNPredictor"
ROI_MASK_HEAD:
USE_GN: True # use GN for mask head
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
CONV_LAYERS: (256, 256, 256, 256)
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
PREDICTOR: "MaskRCNNC4Predictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 28
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 8 gpus
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
IMS_PER_BATCH: 16
TEST:
IMS_PER_BATCH: 8
\ No newline at end of file
INPUT:
MIN_SIZE_TRAIN: 800
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "" # no pretrained model
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
FREEZE_CONV_BODY_AT: 0 # finetune all layers
RESNETS: # use GN for backbone
TRANS_FUNC: "BottleneckWithGN"
STEM_FUNC: "StemWithGN"
FPN:
USE_GN: True # use GN for FPN
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 512
POSITIVE_FRACTION: 0.25
ROI_BOX_HEAD:
USE_GN: True # use GN for bbox head
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 8 gpus
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (210000, 250000)
MAX_ITER: 270000
IMS_PER_BATCH: 16
TEST:
IMS_PER_BATCH: 8
\ No newline at end of file
INPUT:
MIN_SIZE_TRAIN: 800
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "" # no pretrained model
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
FREEZE_CONV_BODY_AT: 0 # finetune all layers
RESNETS: # use GN for backbone
TRANS_FUNC: "BottleneckWithGN"
STEM_FUNC: "StemWithGN"
FPN:
USE_GN: True # use GN for FPN
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 512
POSITIVE_FRACTION: 0.25
ROI_BOX_HEAD:
USE_GN: True # use GN for bbox head
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
CONV_HEAD_DIM: 256
NUM_STACKED_CONVS: 4
FEATURE_EXTRACTOR: "FPNXconv1fcFeatureExtractor"
PREDICTOR: "FPNPredictor"
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 8 gpus
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (210000, 250000)
MAX_ITER: 270000
IMS_PER_BATCH: 16
TEST:
IMS_PER_BATCH: 8
\ No newline at end of file
INPUT:
MIN_SIZE_TRAIN: 800
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "" # no pretrained model
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
FREEZE_CONV_BODY_AT: 0 # finetune all layers
RESNETS: # use GN for backbone
TRANS_FUNC: "BottleneckWithGN"
STEM_FUNC: "StemWithGN"
FPN:
USE_GN: True # use GN for FPN
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 512
POSITIVE_FRACTION: 0.25
ROI_BOX_HEAD:
USE_GN: True # use GN for bbox head
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
ROI_MASK_HEAD:
USE_GN: True # use GN for mask head
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
CONV_LAYERS: (256, 256, 256, 256)
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
PREDICTOR: "MaskRCNNC4Predictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 28
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 8 gpus
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (210000, 250000)
MAX_ITER: 270000
IMS_PER_BATCH: 16
TEST:
IMS_PER_BATCH: 8
INPUT:
MIN_SIZE_TRAIN: 800
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "" # no pretrained model
BACKBONE:
CONV_BODY: "R-50-FPN"
OUT_CHANNELS: 256
FREEZE_CONV_BODY_AT: 0 # finetune all layers
RESNETS: # use GN for backbone
TRANS_FUNC: "BottleneckWithGN"
STEM_FUNC: "StemWithGN"
FPN:
USE_GN: True # use GN for FPN
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 512
POSITIVE_FRACTION: 0.25
ROI_BOX_HEAD:
USE_GN: True # use GN for bbox head
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
CONV_HEAD_DIM: 256
NUM_STACKED_CONVS: 4
FEATURE_EXTRACTOR: "FPNXconv1fcFeatureExtractor"
PREDICTOR: "FPNPredictor"
ROI_MASK_HEAD:
USE_GN: True # use GN for mask head
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
CONV_LAYERS: (256, 256, 256, 256)
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
PREDICTOR: "MaskRCNNC4Predictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 28
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 8 gpus
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (210000, 250000)
MAX_ITER: 270000
IMS_PER_BATCH: 16
TEST:
IMS_PER_BATCH: 8
\ No newline at end of file
...@@ -75,6 +75,7 @@ _C.DATALOADER.SIZE_DIVISIBILITY = 0 ...@@ -75,6 +75,7 @@ _C.DATALOADER.SIZE_DIVISIBILITY = 0
# are not batched with portrait images. # are not batched with portrait images.
_C.DATALOADER.ASPECT_RATIO_GROUPING = True _C.DATALOADER.ASPECT_RATIO_GROUPING = True
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# Backbone options # Backbone options
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -89,6 +90,28 @@ _C.MODEL.BACKBONE.CONV_BODY = "R-50-C4" ...@@ -89,6 +90,28 @@ _C.MODEL.BACKBONE.CONV_BODY = "R-50-C4"
# Add StopGrad at a specified stage so the bottom layers are frozen # Add StopGrad at a specified stage so the bottom layers are frozen
_C.MODEL.BACKBONE.FREEZE_CONV_BODY_AT = 2 _C.MODEL.BACKBONE.FREEZE_CONV_BODY_AT = 2
_C.MODEL.BACKBONE.OUT_CHANNELS = 256 * 4 _C.MODEL.BACKBONE.OUT_CHANNELS = 256 * 4
# GN for backbone
_C.MODEL.BACKBONE.USE_GN = False
# ---------------------------------------------------------------------------- #
# FPN options
# ---------------------------------------------------------------------------- #
_C.MODEL.FPN = CN()
_C.MODEL.FPN.USE_GN = False
_C.MODEL.FPN.USE_RELU = False
# ---------------------------------------------------------------------------- #
# Group Norm options
# ---------------------------------------------------------------------------- #
_C.MODEL.GROUP_NORM = CN()
# Number of dimensions per group in GroupNorm (-1 if using NUM_GROUPS)
_C.MODEL.GROUP_NORM.DIM_PER_GP = -1
# Number of groups in GroupNorm (-1 if using DIM_PER_GP)
_C.MODEL.GROUP_NORM.NUM_GROUPS = 32
# GroupNorm's small constant in the denominator
_C.MODEL.GROUP_NORM.EPSILON = 1e-5
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -182,6 +205,12 @@ _C.MODEL.ROI_BOX_HEAD.POOLER_SCALES = (1.0 / 16,) ...@@ -182,6 +205,12 @@ _C.MODEL.ROI_BOX_HEAD.POOLER_SCALES = (1.0 / 16,)
_C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 81 _C.MODEL.ROI_BOX_HEAD.NUM_CLASSES = 81
# Hidden layer dimension when using an MLP for the RoI box head # Hidden layer dimension when using an MLP for the RoI box head
_C.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM = 1024 _C.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM = 1024
# GN
_C.MODEL.ROI_BOX_HEAD.USE_GN = False
# Dilation
_C.MODEL.ROI_BOX_HEAD.DILATION = 1
_C.MODEL.ROI_BOX_HEAD.CONV_HEAD_DIM = 256
_C.MODEL.ROI_BOX_HEAD.NUM_STACKED_CONVS = 4
_C.MODEL.ROI_MASK_HEAD = CN() _C.MODEL.ROI_MASK_HEAD = CN()
...@@ -197,6 +226,10 @@ _C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True ...@@ -197,6 +226,10 @@ _C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True
# Whether or not resize and translate masks to the input image. # Whether or not resize and translate masks to the input image.
_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS = False _C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS = False
_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD = 0.5 _C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD = 0.5
# Dilation
_C.MODEL.ROI_MASK_HEAD.DILATION = 1
# GN
_C.MODEL.ROI_MASK_HEAD.USE_GN = False
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# ResNe[X]t options (ResNets = {ResNet, ResNeXt} # ResNe[X]t options (ResNets = {ResNet, ResNeXt}
......
...@@ -113,7 +113,9 @@ class ModelCatalog(object): ...@@ -113,7 +113,9 @@ class ModelCatalog(object):
S3_C2_DETECTRON_URL = "https://dl.fbaipublicfiles.com/detectron" S3_C2_DETECTRON_URL = "https://dl.fbaipublicfiles.com/detectron"
C2_IMAGENET_MODELS = { C2_IMAGENET_MODELS = {
"MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl", "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
"MSRA/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
"MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl", "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
"MSRA/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
"FAIR/20171220/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl", "FAIR/20171220/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
} }
......
...@@ -12,4 +12,7 @@ from .roi_pool import ROIPool ...@@ -12,4 +12,7 @@ from .roi_pool import ROIPool
from .roi_pool import roi_pool from .roi_pool import roi_pool
from .smooth_l1_loss import smooth_l1_loss from .smooth_l1_loss import smooth_l1_loss
__all__ = ["nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool", "smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate", "FrozenBatchNorm2d"] __all__ = ["nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool",
"smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate",
"FrozenBatchNorm2d",
]
...@@ -4,12 +4,15 @@ from collections import OrderedDict ...@@ -4,12 +4,15 @@ from collections import OrderedDict
from torch import nn from torch import nn
from maskrcnn_benchmark.modeling import registry from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.make_layers import conv_with_kaiming_uniform
from . import fpn as fpn_module from . import fpn as fpn_module
from . import resnet from . import resnet
@registry.BACKBONES.register("R-50-C4") @registry.BACKBONES.register("R-50-C4")
@registry.BACKBONES.register("R-50-C5")
@registry.BACKBONES.register("R-101-C4")
@registry.BACKBONES.register("R-101-C5")
def build_resnet_backbone(cfg): def build_resnet_backbone(cfg):
body = resnet.ResNet(cfg) body = resnet.ResNet(cfg)
model = nn.Sequential(OrderedDict([("body", body)])) model = nn.Sequential(OrderedDict([("body", body)]))
...@@ -30,6 +33,9 @@ def build_resnet_fpn_backbone(cfg): ...@@ -30,6 +33,9 @@ def build_resnet_fpn_backbone(cfg):
in_channels_stage2 * 8, in_channels_stage2 * 8,
], ],
out_channels=out_channels, out_channels=out_channels,
conv_block=conv_with_kaiming_uniform(
cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
),
top_blocks=fpn_module.LastLevelMaxPool(), top_blocks=fpn_module.LastLevelMaxPool(),
) )
model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)])) model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
......
...@@ -11,7 +11,9 @@ class FPN(nn.Module): ...@@ -11,7 +11,9 @@ class FPN(nn.Module):
order, and must be consecutive order, and must be consecutive
""" """
def __init__(self, in_channels_list, out_channels, top_blocks=None): def __init__(
self, in_channels_list, out_channels, conv_block, top_blocks=None
):
""" """
Arguments: Arguments:
in_channels_list (list[int]): number of channels for each feature map that in_channels_list (list[int]): number of channels for each feature map that
...@@ -27,13 +29,8 @@ class FPN(nn.Module): ...@@ -27,13 +29,8 @@ class FPN(nn.Module):
for idx, in_channels in enumerate(in_channels_list, 1): for idx, in_channels in enumerate(in_channels_list, 1):
inner_block = "fpn_inner{}".format(idx) inner_block = "fpn_inner{}".format(idx)
layer_block = "fpn_layer{}".format(idx) layer_block = "fpn_layer{}".format(idx)
inner_block_module = nn.Conv2d(in_channels, out_channels, 1) inner_block_module = conv_block(in_channels, out_channels, 1)
layer_block_module = nn.Conv2d(out_channels, out_channels, 3, 1, 1) layer_block_module = conv_block(out_channels, out_channels, 3, 1)
for module in [inner_block_module, layer_block_module]:
# Caffe2 implementation uses XavierFill, which in fact
# corresponds to kaiming_uniform_ in PyTorch
nn.init.kaiming_uniform_(module.weight, a=1)
nn.init.constant_(module.bias, 0)
self.add_module(inner_block, inner_block_module) self.add_module(inner_block, inner_block_module)
self.add_module(layer_block, layer_block_module) self.add_module(layer_block, layer_block_module)
self.inner_blocks.append(inner_block) self.inner_blocks.append(inner_block)
......
...@@ -7,6 +7,12 @@ Example usage. Strings may be specified in the config file. ...@@ -7,6 +7,12 @@ Example usage. Strings may be specified in the config file.
"BottleneckWithFixedBatchNorm", "BottleneckWithFixedBatchNorm",
"ResNet50StagesTo4", "ResNet50StagesTo4",
) )
OR:
model = ResNet(
"StemWithGN",
"BottleneckWithGN",
"ResNet50StagesTo4",
)
Custom implementations may be written in user code and hooked in via the Custom implementations may be written in user code and hooked in via the
`register_*` functions. `register_*` functions.
""" """
...@@ -18,6 +24,7 @@ from torch import nn ...@@ -18,6 +24,7 @@ from torch import nn
from maskrcnn_benchmark.layers import FrozenBatchNorm2d from maskrcnn_benchmark.layers import FrozenBatchNorm2d
from maskrcnn_benchmark.layers import Conv2d from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.modeling.make_layers import group_norm
from maskrcnn_benchmark.utils.registry import Registry from maskrcnn_benchmark.utils.registry import Registry
...@@ -44,6 +51,16 @@ ResNet50StagesTo4 = tuple( ...@@ -44,6 +51,16 @@ ResNet50StagesTo4 = tuple(
StageSpec(index=i, block_count=c, return_features=r) StageSpec(index=i, block_count=c, return_features=r)
for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, True)) for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 6, True))
) )
# ResNet-101 (including all stages)
ResNet101StagesTo5 = tuple(
StageSpec(index=i, block_count=c, return_features=r)
for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 23, False), (4, 3, True))
)
# ResNet-101 up to stage 4 (excludes stage 5)
ResNet101StagesTo4 = tuple(
StageSpec(index=i, block_count=c, return_features=r)
for (i, c, r) in ((1, 3, False), (2, 4, False), (3, 23, True))
)
# ResNet-50-FPN (including all stages) # ResNet-50-FPN (including all stages)
ResNet50FPNStagesTo5 = tuple( ResNet50FPNStagesTo5 = tuple(
StageSpec(index=i, block_count=c, return_features=r) StageSpec(index=i, block_count=c, return_features=r)
...@@ -104,6 +121,8 @@ class ResNet(nn.Module): ...@@ -104,6 +121,8 @@ class ResNet(nn.Module):
self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT) self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_CONV_BODY_AT)
def _freeze_backbone(self, freeze_at): def _freeze_backbone(self, freeze_at):
if freeze_at < 0:
return
for stage_index in range(freeze_at): for stage_index in range(freeze_at):
if stage_index == 0: if stage_index == 0:
m = self.stem # stage 0 is the stem m = self.stem # stage 0 is the stem
...@@ -132,6 +151,7 @@ class ResNetHead(nn.Module): ...@@ -132,6 +151,7 @@ class ResNetHead(nn.Module):
stride_in_1x1=True, stride_in_1x1=True,
stride_init=None, stride_init=None,
res2_out_channels=256, res2_out_channels=256,
dilation=1
): ):
super(ResNetHead, self).__init__() super(ResNetHead, self).__init__()
...@@ -158,6 +178,7 @@ class ResNetHead(nn.Module): ...@@ -158,6 +178,7 @@ class ResNetHead(nn.Module):
num_groups, num_groups,
stride_in_1x1, stride_in_1x1,
first_stride=stride, first_stride=stride,
dilation=dilation
) )
stride = None stride = None
self.add_module(name, module) self.add_module(name, module)
...@@ -178,6 +199,7 @@ def _make_stage( ...@@ -178,6 +199,7 @@ def _make_stage(
num_groups, num_groups,
stride_in_1x1, stride_in_1x1,
first_stride, first_stride,
dilation=1
): ):
blocks = [] blocks = []
stride = first_stride stride = first_stride
...@@ -190,6 +212,7 @@ def _make_stage( ...@@ -190,6 +212,7 @@ def _make_stage(
num_groups, num_groups,
stride_in_1x1, stride_in_1x1,
stride, stride,
dilation=dilation
) )
) )
stride = 1 stride = 1
...@@ -197,27 +220,34 @@ def _make_stage( ...@@ -197,27 +220,34 @@ def _make_stage(
return nn.Sequential(*blocks) return nn.Sequential(*blocks)
class BottleneckWithFixedBatchNorm(nn.Module): class Bottleneck(nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels,
bottleneck_channels, bottleneck_channels,
out_channels, out_channels,
num_groups=1, num_groups,
stride_in_1x1=True, stride_in_1x1,
stride=1, stride,
dilation,
norm_func
): ):
super(BottleneckWithFixedBatchNorm, self).__init__() super(Bottleneck, self).__init__()
self.downsample = None self.downsample = None
if in_channels != out_channels: if in_channels != out_channels:
down_stride = stride if dilation == 1 else 1
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
Conv2d( Conv2d(
in_channels, out_channels, kernel_size=1, stride=stride, bias=False in_channels, out_channels,
kernel_size=1, stride=down_stride, bias=False
), ),
FrozenBatchNorm2d(out_channels), norm_func(out_channels),
) )
if dilation > 1:
stride = 1 # reset to be 1
# The original MSRA ResNet models have stride in the first 1x1 conv # The original MSRA ResNet models have stride in the first 1x1 conv
# The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
# stride in the 3x3 conv # stride in the 3x3 conv
...@@ -230,7 +260,7 @@ class BottleneckWithFixedBatchNorm(nn.Module): ...@@ -230,7 +260,7 @@ class BottleneckWithFixedBatchNorm(nn.Module):
stride=stride_1x1, stride=stride_1x1,
bias=False, bias=False,
) )
self.bn1 = FrozenBatchNorm2d(bottleneck_channels) self.bn1 = norm_func(bottleneck_channels)
# TODO: specify init for the above # TODO: specify init for the above
self.conv2 = Conv2d( self.conv2 = Conv2d(
...@@ -238,16 +268,17 @@ class BottleneckWithFixedBatchNorm(nn.Module): ...@@ -238,16 +268,17 @@ class BottleneckWithFixedBatchNorm(nn.Module):
bottleneck_channels, bottleneck_channels,
kernel_size=3, kernel_size=3,
stride=stride_3x3, stride=stride_3x3,
padding=1, padding=dilation,
bias=False, bias=False,
groups=num_groups, groups=num_groups,
dilation=dilation
) )
self.bn2 = FrozenBatchNorm2d(bottleneck_channels) self.bn2 = norm_func(bottleneck_channels)
self.conv3 = Conv2d( self.conv3 = Conv2d(
bottleneck_channels, out_channels, kernel_size=1, bias=False bottleneck_channels, out_channels, kernel_size=1, bias=False
) )
self.bn3 = FrozenBatchNorm2d(out_channels) self.bn3 = norm_func(out_channels)
def forward(self, x): def forward(self, x):
identity = x identity = x
...@@ -272,16 +303,16 @@ class BottleneckWithFixedBatchNorm(nn.Module): ...@@ -272,16 +303,16 @@ class BottleneckWithFixedBatchNorm(nn.Module):
return out return out
class StemWithFixedBatchNorm(nn.Module): class BaseStem(nn.Module):
def __init__(self, cfg): def __init__(self, cfg, norm_func):
super(StemWithFixedBatchNorm, self).__init__() super(BaseStem, self).__init__()
out_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS out_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
self.conv1 = Conv2d( self.conv1 = Conv2d(
3, out_channels, kernel_size=7, stride=2, padding=3, bias=False 3, out_channels, kernel_size=7, stride=2, padding=3, bias=False
) )
self.bn1 = FrozenBatchNorm2d(out_channels) self.bn1 = norm_func(out_channels)
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
...@@ -291,15 +322,79 @@ class StemWithFixedBatchNorm(nn.Module): ...@@ -291,15 +322,79 @@ class StemWithFixedBatchNorm(nn.Module):
return x return x
class BottleneckWithFixedBatchNorm(Bottleneck):
    """Bottleneck block that uses FrozenBatchNorm2d as its normalization."""

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1
    ):
        # Forward every argument unchanged; only the normalization
        # constructor is fixed by this subclass.
        block_kwargs = dict(
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
        )
        super(BottleneckWithFixedBatchNorm, self).__init__(
            norm_func=FrozenBatchNorm2d, **block_kwargs
        )
class StemWithFixedBatchNorm(BaseStem):
    """Stem that uses FrozenBatchNorm2d as its normalization."""

    def __init__(self, cfg):
        norm_layer = FrozenBatchNorm2d
        super(StemWithFixedBatchNorm, self).__init__(cfg, norm_func=norm_layer)
class BottleneckWithGN(Bottleneck):
    """Bottleneck block that uses `group_norm` as its normalization."""

    def __init__(
        self,
        in_channels,
        bottleneck_channels,
        out_channels,
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1
    ):
        # Forward every argument unchanged; only the normalization
        # constructor is fixed by this subclass.
        block_kwargs = dict(
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
        )
        super(BottleneckWithGN, self).__init__(
            norm_func=group_norm, **block_kwargs
        )
class StemWithGN(BaseStem):
    """Stem that uses `group_norm` as its normalization."""

    def __init__(self, cfg):
        norm_layer = group_norm
        super(StemWithGN, self).__init__(cfg, norm_func=norm_layer)
_TRANSFORMATION_MODULES = Registry({ _TRANSFORMATION_MODULES = Registry({
"BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm "BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm,
"BottleneckWithGN": BottleneckWithGN,
}) })
_STEM_MODULES = Registry({"StemWithFixedBatchNorm": StemWithFixedBatchNorm}) _STEM_MODULES = Registry({
"StemWithFixedBatchNorm": StemWithFixedBatchNorm,
"StemWithGN": StemWithGN,
})
_STAGE_SPECS = Registry({ _STAGE_SPECS = Registry({
"R-50-C4": ResNet50StagesTo4, "R-50-C4": ResNet50StagesTo4,
"R-50-C5": ResNet50StagesTo5, "R-50-C5": ResNet50StagesTo5,
"R-101-C4": ResNet101StagesTo4,
"R-101-C5": ResNet101StagesTo5,
"R-50-FPN": ResNet50FPNStagesTo5, "R-50-FPN": ResNet50FPNStagesTo5,
"R-101-FPN": ResNet101FPNStagesTo5, "R-101-FPN": ResNet101FPNStagesTo5,
}) })
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Miscellaneous utility functions
"""
import torch
from torch import nn
from torch.nn import functional as F
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.modeling.poolers import Pooler
def get_group_gn(dim, dim_per_gp, num_groups):
    """Get the number of groups used by GroupNorm for `dim` channels.

    Exactly one of `dim_per_gp` / `num_groups` selects the grouping; the
    other one must be left at the sentinel value -1.

    Arguments:
        dim (int): number of channels to be normalized.
        dim_per_gp (int): channels per group, or -1 when `num_groups` is used.
        num_groups (int): number of groups, or -1 when `dim_per_gp` is used.

    Returns:
        int: the number of groups.

    Raises:
        AssertionError: if both options are specified, if neither is
            positive, or if `dim` is not divisible by the selected value.
    """
    assert dim_per_gp == -1 or num_groups == -1, \
        "GroupNorm: can only specify G or C/G."

    if dim_per_gp > 0:
        assert dim % dim_per_gp == 0, \
            "dim: {}, dim_per_gp: {}".format(dim, dim_per_gp)
        return dim // dim_per_gp

    # Without this check a negative num_groups (e.g. both options left at -1)
    # can pass the divisibility test below, since dim % -1 == 0, and a
    # negative group count is returned; num_groups == 0 would raise a bare
    # ZeroDivisionError instead of a readable message.
    assert num_groups > 0, (
        "GroupNorm: one of dim_per_gp or num_groups must be positive, "
        "got dim_per_gp: {}, num_groups: {}".format(dim_per_gp, num_groups)
    )
    assert dim % num_groups == 0, \
        "dim: {}, num_groups: {}".format(dim, num_groups)
    return num_groups
def group_norm(out_channels, affine=True, divisor=1):
    """Build a torch.nn.GroupNorm layer configured from cfg.MODEL.GROUP_NORM.

    Arguments:
        out_channels (int): number of channels before applying `divisor`.
        affine (bool): forwarded to torch.nn.GroupNorm.
        divisor (int): floor-divides the channel count as well as the
            configured DIM_PER_GP and NUM_GROUPS values.

    Returns:
        torch.nn.GroupNorm: the normalization layer.
    """
    gn_cfg = cfg.MODEL.GROUP_NORM
    channels = out_channels // divisor
    per_group = gn_cfg.DIM_PER_GP // divisor
    groups = gn_cfg.NUM_GROUPS // divisor
    epsilon = gn_cfg.EPSILON  # default: 1e-5
    n_groups = get_group_gn(channels, per_group, groups)
    return torch.nn.GroupNorm(n_groups, channels, eps=epsilon, affine=affine)
def make_conv3x3(
    in_channels,
    out_channels,
    dilation=1,
    stride=1,
    use_gn=False,
    use_relu=False,
    kaiming_init=True
):
    """Build a 3x3 convolution, optionally followed by GroupNorm and ReLU.

    Arguments:
        in_channels (int): input channels of the convolution.
        out_channels (int): output channels of the convolution.
        dilation (int): dilation of the convolution; also used as padding.
        stride (int): stride of the convolution.
        use_gn (bool): append GroupNorm; the convolution then has no bias.
        use_relu (bool): append an in-place ReLU.
        kaiming_init (bool): initialize the weight with kaiming_normal_
            (fan_out, relu) instead of normal_(std=0.01).

    Returns:
        The bare convolution when nothing is appended, otherwise an
        nn.Sequential of the convolution and the appended layers.
    """
    has_bias = not use_gn
    conv = Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=has_bias
    )

    if kaiming_init:
        nn.init.kaiming_normal_(
            conv.weight, mode="fan_out", nonlinearity="relu"
        )
    else:
        torch.nn.init.normal_(conv.weight, std=0.01)
    if has_bias:
        nn.init.constant_(conv.bias, 0)

    layers = [conv]
    if use_gn:
        layers.append(group_norm(out_channels))
    if use_relu:
        layers.append(nn.ReLU(inplace=True))

    # Only wrap in a Sequential when something follows the convolution.
    if len(layers) == 1:
        return conv
    return nn.Sequential(*layers)
def make_fc(dim_in, hidden_dim, use_gn):
    '''
    Build a fully connected layer, optionally followed by GroupNorm.

    Caffe2 implementation uses XavierFill, which in fact
    corresponds to kaiming_uniform_ in PyTorch

    Arguments:
        dim_in (int): input feature size.
        hidden_dim (int): output feature size.
        use_gn (bool): when true, the linear layer has no bias and is
            followed by GroupNorm inside an nn.Sequential.

    Returns:
        nn.Linear when `use_gn` is false, otherwise nn.Sequential.
    '''
    linear = nn.Linear(dim_in, hidden_dim, bias=not use_gn)
    nn.init.kaiming_uniform_(linear.weight, a=1)
    if use_gn:
        return nn.Sequential(linear, group_norm(hidden_dim))
    nn.init.constant_(linear.bias, 0)
    return linear
def conv_with_kaiming_uniform(use_gn=False, use_relu=False):
    """Return a factory that builds kaiming_uniform_-initialized convolutions.

    Arguments:
        use_gn (bool): append GroupNorm after each convolution; the
            convolution then has no bias.
        use_relu (bool): append an in-place ReLU.

    Returns:
        A callable `make_conv(in_channels, out_channels, kernel_size,
        stride=1, dilation=1)` producing either the bare convolution or an
        nn.Sequential of the convolution and the appended layers.
    """
    has_bias = not use_gn

    def make_conv(
        in_channels, out_channels, kernel_size, stride=1, dilation=1
    ):
        pad = dilation * (kernel_size - 1) // 2
        conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            dilation=dilation,
            bias=has_bias
        )
        # Caffe2 implementation uses XavierFill, which in fact
        # corresponds to kaiming_uniform_ in PyTorch
        nn.init.kaiming_uniform_(conv.weight, a=1)
        if has_bias:
            nn.init.constant_(conv.bias, 0)

        layers = [conv]
        if use_gn:
            layers.append(group_norm(out_channels))
        if use_relu:
            layers.append(nn.ReLU(inplace=True))

        # Only wrap in a Sequential when something follows the convolution.
        if len(layers) == 1:
            return conv
        return nn.Sequential(*layers)

    return make_conv
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from maskrcnn_benchmark.modeling import registry from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.backbone import resnet from maskrcnn_benchmark.modeling.backbone import resnet
from maskrcnn_benchmark.modeling.poolers import Pooler from maskrcnn_benchmark.modeling.poolers import Pooler
from maskrcnn_benchmark.modeling.make_layers import group_norm
from maskrcnn_benchmark.modeling.make_layers import make_fc
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("ResNet50Conv5ROIFeatureExtractor") @registry.ROI_BOX_FEATURE_EXTRACTORS.register("ResNet50Conv5ROIFeatureExtractor")
...@@ -60,15 +63,10 @@ class FPN2MLPFeatureExtractor(nn.Module): ...@@ -60,15 +63,10 @@ class FPN2MLPFeatureExtractor(nn.Module):
) )
input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS * resolution ** 2 input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS * resolution ** 2
representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM representation_size = cfg.MODEL.ROI_BOX_HEAD.MLP_HEAD_DIM
use_gn = cfg.MODEL.ROI_BOX_HEAD.USE_GN
self.pooler = pooler self.pooler = pooler
self.fc6 = nn.Linear(input_size, representation_size) self.fc6 = make_fc(input_size, representation_size, use_gn)
self.fc7 = nn.Linear(representation_size, representation_size) self.fc7 = make_fc(representation_size, representation_size, use_gn)
for l in [self.fc6, self.fc7]:
# Caffe2 implementation uses XavierFill, which in fact
# corresponds to kaiming_uniform_ in PyTorch
nn.init.kaiming_uniform_(l.weight, a=1)
nn.init.constant_(l.bias, 0)
def forward(self, x, proposals): def forward(self, x, proposals):
x = self.pooler(x, proposals) x = self.pooler(x, proposals)
...@@ -80,6 +78,69 @@ class FPN2MLPFeatureExtractor(nn.Module): ...@@ -80,6 +78,69 @@ class FPN2MLPFeatureExtractor(nn.Module):
return x return x
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPNXconv1fcFeatureExtractor")
class FPNXconv1fcFeatureExtractor(nn.Module):
    """
    Box-head feature extractor for FPN.

    Pools the ROI features, runs them through a stack of 3x3 convolutions
    (each optionally followed by GroupNorm) with ReLU, flattens the result
    and applies a single fully connected layer followed by ReLU.
    """

    def __init__(self, cfg):
        super(FPNXconv1fcFeatureExtractor, self).__init__()

        box_head_cfg = cfg.MODEL.ROI_BOX_HEAD

        resolution = box_head_cfg.POOLER_RESOLUTION
        self.pooler = Pooler(
            output_size=(resolution, resolution),
            scales=box_head_cfg.POOLER_SCALES,
            sampling_ratio=box_head_cfg.POOLER_SAMPLING_RATIO,
        )

        use_gn = box_head_cfg.USE_GN
        conv_head_dim = box_head_cfg.CONV_HEAD_DIM
        num_stacked_convs = box_head_cfg.NUM_STACKED_CONVS
        dilation = box_head_cfg.DILATION

        # Only the first conv sees the backbone channels; later ones map
        # conv_head_dim -> conv_head_dim.
        channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        stack = []
        for _ in range(num_stacked_convs):
            conv = nn.Conv2d(
                channels,
                conv_head_dim,
                kernel_size=3,
                stride=1,
                padding=dilation,
                dilation=dilation,
                bias=not use_gn,
            )
            stack.append(conv)
            channels = conv_head_dim
            if use_gn:
                stack.append(group_norm(channels))
            stack.append(nn.ReLU(inplace=True))
        self.xconvs = nn.Sequential(*stack)

        # Re-initialise the convs only after the whole stack is built.
        for layer in self.xconvs.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            torch.nn.init.normal_(layer.weight, std=0.01)
            if not use_gn:
                torch.nn.init.constant_(layer.bias, 0)

        representation_size = box_head_cfg.MLP_HEAD_DIM
        flattened_size = conv_head_dim * resolution ** 2
        self.fc6 = make_fc(flattened_size, representation_size, use_gn)

    def forward(self, x, proposals):
        """Pool ``x`` over ``proposals`` and return the fc6 + ReLU features."""
        pooled = self.pooler(x, proposals)
        conv_out = self.xconvs(pooled)
        # Flatten everything but the leading (per-ROI) dimension.
        flat = conv_out.view(conv_out.size(0), -1)
        return F.relu(self.fc6(flat))
def make_roi_box_feature_extractor(cfg): def make_roi_box_feature_extractor(cfg):
func = registry.ROI_BOX_FEATURE_EXTRACTORS[ func = registry.ROI_BOX_FEATURE_EXTRACTORS[
cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR
......
...@@ -5,6 +5,8 @@ from torch.nn import functional as F ...@@ -5,6 +5,8 @@ from torch.nn import functional as F
from ..box_head.roi_box_feature_extractors import ResNet50Conv5ROIFeatureExtractor from ..box_head.roi_box_feature_extractors import ResNet50Conv5ROIFeatureExtractor
from maskrcnn_benchmark.modeling.poolers import Pooler from maskrcnn_benchmark.modeling.poolers import Pooler
from maskrcnn_benchmark.layers import Conv2d from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.modeling.make_layers import make_conv3x3
class MaskRCNNFPNFeatureExtractor(nn.Module): class MaskRCNNFPNFeatureExtractor(nn.Module):
...@@ -32,17 +34,17 @@ class MaskRCNNFPNFeatureExtractor(nn.Module): ...@@ -32,17 +34,17 @@ class MaskRCNNFPNFeatureExtractor(nn.Module):
input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS input_size = cfg.MODEL.BACKBONE.OUT_CHANNELS
self.pooler = pooler self.pooler = pooler
use_gn = cfg.MODEL.ROI_MASK_HEAD.USE_GN
layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS layers = cfg.MODEL.ROI_MASK_HEAD.CONV_LAYERS
dilation = cfg.MODEL.ROI_MASK_HEAD.DILATION
next_feature = input_size next_feature = input_size
self.blocks = [] self.blocks = []
for layer_idx, layer_features in enumerate(layers, 1): for layer_idx, layer_features in enumerate(layers, 1):
layer_name = "mask_fcn{}".format(layer_idx) layer_name = "mask_fcn{}".format(layer_idx)
module = Conv2d(next_feature, layer_features, 3, stride=1, padding=1) module = make_conv3x3(next_feature, layer_features,
# Caffe2 implementation uses MSRAFill, which in fact dilation=dilation, stride=1, use_gn=use_gn
# corresponds to kaiming_normal_ in PyTorch )
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
nn.init.constant_(module.bias, 0)
self.add_module(layer_name, module) self.add_module(layer_name, module)
next_feature = layer_features next_feature = layer_features
self.blocks.append(layer_name) self.blocks.append(layer_name)
......
...@@ -47,6 +47,18 @@ def _rename_basic_resnet_weights(layer_keys): ...@@ -47,6 +47,18 @@ def _rename_basic_resnet_weights(layer_keys):
layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys] layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys]
layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys] layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys]
# GroupNorm
layer_keys = [k.replace("conv1.gn.s", "bn1.weight") for k in layer_keys]
layer_keys = [k.replace("conv1.gn.bias", "bn1.bias") for k in layer_keys]
layer_keys = [k.replace("conv2.gn.s", "bn2.weight") for k in layer_keys]
layer_keys = [k.replace("conv2.gn.bias", "bn2.bias") for k in layer_keys]
layer_keys = [k.replace("conv3.gn.s", "bn3.weight") for k in layer_keys]
layer_keys = [k.replace("conv3.gn.bias", "bn3.bias") for k in layer_keys]
layer_keys = [k.replace("downsample.0.gn.s", "downsample.1.weight") \
for k in layer_keys]
layer_keys = [k.replace("downsample.0.gn.bias", "downsample.1.bias") \
for k in layer_keys]
return layer_keys return layer_keys
def _rename_fpn_weights(layer_keys, stage_names): def _rename_fpn_weights(layer_keys, stage_names):
...@@ -140,12 +152,15 @@ C2_FORMAT_LOADER = Registry() ...@@ -140,12 +152,15 @@ C2_FORMAT_LOADER = Registry()
@C2_FORMAT_LOADER.register("R-50-C4") @C2_FORMAT_LOADER.register("R-50-C4")
@C2_FORMAT_LOADER.register("R-50-C5")
@C2_FORMAT_LOADER.register("R-101-C4")
@C2_FORMAT_LOADER.register("R-101-C5")
@C2_FORMAT_LOADER.register("R-50-FPN") @C2_FORMAT_LOADER.register("R-50-FPN")
@C2_FORMAT_LOADER.register("R-101-FPN") @C2_FORMAT_LOADER.register("R-101-FPN")
def load_resnet_c2_format(cfg, f): def load_resnet_c2_format(cfg, f):
state_dict = _load_c2_pickled_weights(f) state_dict = _load_c2_pickled_weights(f)
conv_body = cfg.MODEL.BACKBONE.CONV_BODY conv_body = cfg.MODEL.BACKBONE.CONV_BODY
arch = conv_body.replace("-C4", "").replace("-FPN", "") arch = conv_body.replace("-C4", "").replace("-C5", "").replace("-FPN", "")
stages = _C2_STAGE_NAMES[arch] stages = _C2_STAGE_NAMES[arch]
state_dict = _rename_weights_for_resnet(state_dict, stages) state_dict = _rename_weights_for_resnet(state_dict, stages)
return dict(model=state_dict) return dict(model=state_dict)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment