Commit c7692eb7 authored by Kaiming He, committed by Facebook Github Bot

add group norm

Summary: Add GroupNorm support to master Detectron.

Reviewed By: rbgirshick

Differential Revision: D7611892

fbshipit-source-id: dc4fb84a0e2167b05fd8a94ee0ff1ab1c21369b7
parent 0fda5f9a
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 180000
STEPS: [0, 120000, 160000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47592356/R-101-GN.pkl # Note: a GN pre-trained model
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 270000
STEPS: [0, 210000, 250000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47592356/R-101-GN.pkl # Note: a GN pre-trained model
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 180000
STEPS: [0, 120000, 160000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47261647/R-50-GN.pkl # Note: a GN pre-trained model
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 270000
STEPS: [0, 210000, 250000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47261647/R-50-GN.pkl # Note: a GN pre-trained model
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
# WARNING: this script uses **pre-computed** BN-based proposals, and is for quick debugging only.
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
NUM_CLASSES: 81
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 90000
STEPS: [0, 60000, 80000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47261647/R-50-GN.pkl # Note: a GN pre-trained model
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
PROPOSAL_FILES: ('https://s3-us-west-2.amazonaws.com/detectron/35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179/output/test/coco_2014_train/generalized_rcnn/rpn_proposals.pkl', 'https://s3-us-west-2.amazonaws.com/detectron/35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179/output/test/coco_2014_valminusminival/generalized_rcnn/rpn_proposals.pkl')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
TEST:
DATASETS: ('coco_2014_minival',)
PROPOSAL_FILES: ('https://s3-us-west-2.amazonaws.com/detectron/35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179/output/test/coco_2014_minival/generalized_rcnn/rpn_proposals.pkl',)
PROPOSAL_LIMIT: 1000
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
OUTPUT_DIR: .
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 270000
STEPS: [0, 210000, 250000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
# WEIGHTS: N/A
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 270000
STEPS: [0, 210000, 250000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
# WEIGHTS: N/A
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
......@@ -592,6 +592,8 @@ __C.SOLVER.MOMENTUM = 0.9
# L2 regularization hyperparameter
__C.SOLVER.WEIGHT_DECAY = 0.0005
# L2 regularization hyperparameter for GroupNorm's parameters
__C.SOLVER.WEIGHT_DECAY_GN = 0.0
# Warm up to SOLVER.BASE_LR over this number of SGD iterations
__C.SOLVER.WARM_UP_ITERS = 500
......@@ -628,6 +630,11 @@ __C.FAST_RCNN.ROI_BOX_HEAD = b''
# Hidden layer dimension when using an MLP for the RoI box head
__C.FAST_RCNN.MLP_HEAD_DIM = 1024
# Hidden Conv layer dimension when using Convs for the RoI box head
__C.FAST_RCNN.CONV_HEAD_DIM = 256
# Number of stacked Conv layers in the RoI box head
__C.FAST_RCNN.NUM_STACKED_CONVS = 4
# RoI transformation function (e.g., RoIPool or RoIAlign)
# (RoIPoolF is the same as RoIPool; ignore the trailing 'F')
__C.FAST_RCNN.ROI_XFORM_METHOD = b'RoIPoolF'
......@@ -708,6 +715,8 @@ __C.FPN.RPN_ASPECT_RATIOS = (0.5, 1, 2)
__C.FPN.RPN_ANCHOR_START_SIZE = 32
# Use extra FPN levels, as done in the RetinaNet paper
__C.FPN.EXTRA_CONV_LEVELS = False
# Use GroupNorm in the FPN-specific layers (lateral, etc.)
__C.FPN.USE_GN = False
# ---------------------------------------------------------------------------- #
......@@ -863,11 +872,27 @@ __C.RESNETS.STRIDE_1X1 = True
# Residual transformation function
__C.RESNETS.TRANS_FUNC = b'bottleneck_transformation'
# ResNet's stem function (conv1 and pool1)
__C.RESNETS.STEM_FUNC = b'basic_bn_stem'
# ResNet's shortcut function
__C.RESNETS.SHORTCUT_FUNC = b'basic_bn_shortcut'
# Apply dilation in stage "res5"
__C.RESNETS.RES5_DILATION = 1
# ---------------------------------------------------------------------------- #
# GroupNorm options
# ---------------------------------------------------------------------------- #
__C.GROUP_NORM = AttrDict()
# Number of dimensions per group in GroupNorm (-1 if using NUM_GROUPS)
__C.GROUP_NORM.DIM_PER_GP = -1
# Number of groups in GroupNorm (-1 if using DIM_PER_GP)
__C.GROUP_NORM.NUM_GROUPS = 32
# GroupNorm's small constant in the denominator
__C.GROUP_NORM.EPSILON = 1e-5
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
......
......@@ -27,6 +27,7 @@ from core.config import cfg
from modeling.generate_anchors import generate_anchors
from utils.c2 import const_fill
from utils.c2 import gauss_fill
from utils.net import get_group_gn
import modeling.ResNet as ResNet
import utils.blob as blob_utils
import utils.boxes as box_utils
......@@ -138,18 +139,34 @@ def add_fpn(model, fpn_level_info):
fpn_dim_lateral = fpn_level_info.dims
xavier_fill = ('XavierFill', {})
# For the coarest backbone level: 1x1 conv only seeds recursion
model.Conv(
lateral_input_blobs[0],
output_blobs[0],
dim_in=fpn_dim_lateral[0],
dim_out=fpn_dim,
kernel=1,
pad=0,
stride=1,
weight_init=xavier_fill,
bias_init=const_fill(0.0)
)
# For the coarsest backbone level: 1x1 conv only seeds recursion
if cfg.FPN.USE_GN:
# use GroupNorm
c = model.ConvGN(
lateral_input_blobs[0],
output_blobs[0], # note: this is a prefix
dim_in=fpn_dim_lateral[0],
dim_out=fpn_dim,
group_gn=get_group_gn(fpn_dim),
kernel=1,
pad=0,
stride=1,
weight_init=xavier_fill,
bias_init=const_fill(0.0)
)
output_blobs[0] = c # rename it
else:
model.Conv(
lateral_input_blobs[0],
output_blobs[0],
dim_in=fpn_dim_lateral[0],
dim_out=fpn_dim,
kernel=1,
pad=0,
stride=1,
weight_init=xavier_fill,
bias_init=const_fill(0.0)
)
#
# Step 1: recursively build down starting from the coarsest backbone level
......@@ -170,17 +187,32 @@ def add_fpn(model, fpn_level_info):
blobs_fpn = []
spatial_scales = []
for i in range(num_backbone_stages):
fpn_blob = model.Conv(
output_blobs[i],
'fpn_{}'.format(fpn_level_info.blobs[i]),
dim_in=fpn_dim,
dim_out=fpn_dim,
kernel=3,
pad=1,
stride=1,
weight_init=xavier_fill,
bias_init=const_fill(0.0)
)
if cfg.FPN.USE_GN:
# use GroupNorm
fpn_blob = model.ConvGN(
output_blobs[i],
'fpn_{}'.format(fpn_level_info.blobs[i]),
dim_in=fpn_dim,
dim_out=fpn_dim,
group_gn=get_group_gn(fpn_dim),
kernel=3,
pad=1,
stride=1,
weight_init=xavier_fill,
bias_init=const_fill(0.0)
)
else:
fpn_blob = model.Conv(
output_blobs[i],
'fpn_{}'.format(fpn_level_info.blobs[i]),
dim_in=fpn_dim,
dim_out=fpn_dim,
kernel=3,
pad=1,
stride=1,
weight_init=xavier_fill,
bias_init=const_fill(0.0)
)
blobs_fpn += [fpn_blob]
spatial_scales += [fpn_level_info.spatial_scales[i]]
......@@ -229,20 +261,37 @@ def add_topdown_lateral_module(
):
"""Add a top-down lateral module."""
# Lateral 1x1 conv
lat = model.Conv(
fpn_lateral,
fpn_bottom + '_lateral',
dim_in=dim_lateral,
dim_out=dim_top,
kernel=1,
pad=0,
stride=1,
weight_init=(
const_fill(0.0) if cfg.FPN.ZERO_INIT_LATERAL
else ('XavierFill', {})
),
bias_init=const_fill(0.0)
)
if cfg.FPN.USE_GN:
# use GroupNorm
lat = model.ConvGN(
fpn_lateral,
fpn_bottom + '_lateral',
dim_in=dim_lateral,
dim_out=dim_top,
group_gn=get_group_gn(dim_top),
kernel=1,
pad=0,
stride=1,
weight_init=(
const_fill(0.0) if cfg.FPN.ZERO_INIT_LATERAL
else ('XavierFill', {})),
bias_init=const_fill(0.0)
)
else:
lat = model.Conv(
fpn_lateral,
fpn_bottom + '_lateral',
dim_in=dim_lateral,
dim_out=dim_top,
kernel=1,
pad=0,
stride=1,
weight_init=(
const_fill(0.0)
if cfg.FPN.ZERO_INIT_LATERAL else ('XavierFill', {})
),
bias_init=const_fill(0.0)
)
# Top-down 2x upsampling
td = model.net.UpsampleNearest(fpn_top, fpn_bottom + '_topdown', scale=2)
# Sum lateral and top-down
......
......@@ -24,6 +24,7 @@ from __future__ import print_function
from __future__ import unicode_literals
from core.config import cfg
from utils.net import get_group_gn
# ---------------------------------------------------------------------------- #
# Bits for specific architectures (ResNet50, ResNet101, ...)
......@@ -91,11 +92,10 @@ def add_ResNet_convX_body(model, block_counts, freeze_at=2):
The final res5/conv5 stage may be optionally excluded (hence convX, where
X = 4 or 5)."""
assert freeze_at in [0, 2, 3, 4, 5]
p = model.Conv('data', 'conv1', 3, 64, 7, pad=3, stride=2, no_bias=1)
p = model.AffineChannel(p, 'res_conv1_bn', dim=64, inplace=True)
p = model.Relu(p, p)
p = model.MaxPool(p, 'pool1', kernel=3, pad=1, stride=2)
dim_in = 64
# add the stem (by default, conv1 and pool1 with bn; can support gn)
p, dim_in = globals()[cfg.RESNETS.STEM_FUNC](model, 'data')
dim_bottleneck = cfg.RESNETS.NUM_GROUPS * cfg.RESNETS.WIDTH_PER_GROUP
(n1, n2, n3) = block_counts[:3]
s, dim_in = add_stage(model, 'res2', p, n1, dim_in, 256, dim_bottleneck, 1)
......@@ -182,6 +182,8 @@ def add_residual_block(
)
# sum -> ReLU
# shortcut function: by default using bn; support gn
add_shortcut = globals()[cfg.RESNETS.SHORTCUT_FUNC]
sc = add_shortcut(model, prefix, blob_in, dim_in, dim_out, stride)
if inplace_sum:
s = model.net.Sum([tr, sc], tr)
......@@ -191,7 +193,16 @@ def add_residual_block(
return model.Relu(s, s)
def add_shortcut(model, prefix, blob_in, dim_in, dim_out, stride):
# ------------------------------------------------------------------------------
# various shortcuts (may expand and may consider a new helper)
# ------------------------------------------------------------------------------
def basic_bn_shortcut(model, prefix, blob_in, dim_in, dim_out, stride):
""" For a pre-trained network that used BN. An AffineChannel op replaces BN
during fine-tuning.
"""
if dim_in == dim_out:
return blob_in
......@@ -207,6 +218,54 @@ def add_shortcut(model, prefix, blob_in, dim_in, dim_out, stride):
return model.AffineChannel(c, prefix + '_branch1_bn', dim=dim_out)
def basic_gn_shortcut(model, prefix, blob_in, dim_in, dim_out, stride):
    """Build the shortcut branch of a residual block, GroupNorm flavor.

    If the number of channels is unchanged the input blob is returned as-is
    (identity shortcut). Otherwise a 1x1 conv followed by GroupNorm projects
    blob_in from dim_in to dim_out channels using the given stride.
    """
    if dim_in == dim_out:
        # Identity shortcut: nothing to add to the model.
        return blob_in

    # Projection shortcut. ConvGN appends its '_gn' suffix to the prefix, so
    # the resulting blob is named prefix + '_branch1_gn'.
    conv_gn = model.ConvGN
    return conv_gn(
        blob_in,
        prefix + '_branch1',
        dim_in,
        dim_out,
        kernel=1,
        group_gn=get_group_gn(dim_out),
        stride=stride,
        pad=0,
        group=1,
    )
# ------------------------------------------------------------------------------
# various stems (may expand and may consider a new helper)
# ------------------------------------------------------------------------------
def basic_bn_stem(model, data, **kwargs):
    """Build the basic ResNet stem (conv1 + pool1) for BN-pretrained nets.

    An AffineChannel op takes the place of BN during fine-tuning.

    Returns a tuple of (output blob, number of output channels).
    """
    stem_dim = 64
    # conv1: 3 -> 64 channels, 7x7 kernel, stride 2, no bias
    blob = model.Conv(
        data, 'conv1', 3, stem_dim, 7, pad=3, stride=2, no_bias=1
    )
    blob = model.AffineChannel(
        blob, 'res_conv1_bn', dim=stem_dim, inplace=True
    )
    blob = model.Relu(blob, blob)
    # pool1: 3x3 max pooling, stride 2
    blob = model.MaxPool(blob, 'pool1', kernel=3, pad=1, stride=2)
    return blob, stem_dim
def basic_gn_stem(model, data, **kwargs):
    """Build the basic ResNet stem (conv1 + pool1) using GroupNorm.

    Returns a tuple of (output blob, number of output channels).
    """
    stem_dim = 64
    # conv1: 3 -> 64 channels, 7x7 kernel, stride 2, followed by GroupNorm
    blob = model.ConvGN(
        data,
        'conv1',
        3,
        stem_dim,
        7,
        group_gn=get_group_gn(stem_dim),
        pad=3,
        stride=2,
    )
    blob = model.Relu(blob, blob)
    # pool1: 3x3 max pooling, stride 2
    blob = model.MaxPool(blob, 'pool1', kernel=3, pad=1, stride=2)
    return blob, stem_dim
# ------------------------------------------------------------------------------
# various transformations (may expand and may consider a new helper)
# ------------------------------------------------------------------------------
......@@ -270,3 +329,61 @@ def bottleneck_transformation(
inplace=False
)
return cur
def bottleneck_gn_transformation(
    model,
    blob_in,
    dim_in,
    dim_out,
    stride,
    prefix,
    dim_inner,
    dilation=1,
    group=1
):
    """Add a GroupNorm bottleneck transformation (1x1 -> 3x3 -> 1x1) to model.

    Each conv is followed by GroupNorm; the first two are also followed by a
    ReLU, the last one is not. Returns the output blob of the final ConvGN.
    """
    # The original ResNet puts the stride on the first 1x1 conv, whereas the
    # fb.torch ResNet puts it on the 3x3 conv.
    if cfg.RESNETS.STRIDE_1X1:
        stride_1x1, stride_3x3 = stride, 1
    else:
        stride_1x1, stride_3x3 = 1, stride

    # branch2a: 1x1 conv -> GN -> ReLU
    blob = model.ConvGN(
        blob_in,
        prefix + '_branch2a',
        dim_in,
        dim_inner,
        kernel=1,
        group_gn=get_group_gn(dim_inner),
        stride=stride_1x1,
        pad=0,
    )
    blob = model.Relu(blob, blob)

    # branch2b: 3x3 conv -> GN -> ReLU (padding follows the dilation)
    blob = model.ConvGN(
        blob,
        prefix + '_branch2b',
        dim_inner,
        dim_inner,
        kernel=3,
        group_gn=get_group_gn(dim_inner),
        stride=stride_3x3,
        pad=1 * dilation,
        dilation=dilation,
        group=group,
    )
    blob = model.Relu(blob, blob)

    # branch2c: 1x1 conv -> GN, intentionally without a ReLU
    blob = model.ConvGN(
        blob,
        prefix + '_branch2c',
        dim_inner,
        dim_out,
        kernel=1,
        group_gn=get_group_gn(dim_out),
        stride=1,
        pad=0,
    )
    return blob
......@@ -63,6 +63,7 @@ class DetectionModelHelper(cnn.CNNModelHelper):
self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE
self.net.Proto().num_workers = cfg.NUM_GPUS * 4
self.prev_use_cudnn = self.use_cudnn
self.gn_params = [] # Param on this list are GroupNorm parameters
def TrainableParams(self, gpu_id=-1):
"""Get the blob names for all trainable parameters, possibly filtered by
......@@ -410,6 +411,48 @@ class DetectionModelHelper(cnn.CNNModelHelper):
)
return blob_out
def ConvGN(  # args in the same order of Conv()
    self, blob_in, prefix, dim_in, dim_out, kernel, stride, pad,
    group_gn,  # num of groups in gn
    group=1, dilation=1,
    weight_init=None,
    bias_init=None,
    suffix='_gn',
    no_conv_bias=1,
):
    """Add a Conv op followed by a GroupNorm op with learnable scale/bias
    (gamma/beta).

    Args:
        blob_in: input blob of the conv.
        prefix: name of the conv op; the GroupNorm output is named
            prefix + suffix.
        dim_in: number of input channels of the conv.
        dim_out: number of output channels of the conv, which is also the
            number of channels normalized by GroupNorm.
        kernel, stride, pad, group, dilation: forwarded to Conv().
        group_gn: number of groups used by GroupNorm; values below 1 are
            reset to 1 with a warning.
        weight_init, bias_init: initializers forwarded to Conv().
        suffix: appended to prefix to name the GroupNorm output blob.
        no_conv_bias: forwarded to Conv() as no_bias; the conv has no bias
            by default.

    Returns:
        The output blob of the GroupNorm op.
    """
    conv_blob = self.Conv(
        blob_in,
        prefix,
        dim_in,
        dim_out,
        kernel,
        stride=stride,
        pad=pad,
        group=group,
        dilation=dilation,
        weight_init=weight_init,
        bias_init=bias_init,
        no_bias=no_conv_bias)

    if group_gn < 1:
        # GroupNorm runs on the conv output, so report dim_out (the number of
        # channels actually being grouped), not the conv's input dimension.
        logger.warning(
            'Layer: {} (dim {}): '
            'group_gn < 1; reset to 1.'.format(prefix, dim_out)
        )
        group_gn = 1

    blob_out = self.SpatialGN(
        conv_blob, prefix + suffix,
        dim_out, num_groups=group_gn,
        epsilon=cfg.GROUP_NORM.EPSILON,)

    # Track GroupNorm's parameters separately from the other params. This
    # relies on SpatialGN having just appended its scale and then its bias
    # to self.params.
    self.gn_params.append(self.params[-1])  # add gn's bias to list
    self.gn_params.append(self.params[-2])  # add gn's scale to list
    return blob_out
def DisableCudnn(self):
self.prev_use_cudnn = self.use_cudnn
self.use_cudnn = False
......
......@@ -35,6 +35,7 @@ from __future__ import unicode_literals
from core.config import cfg
from utils.c2 import const_fill
from utils.c2 import gauss_fill
from utils.net import get_group_gn
import utils.blob as blob_utils
......@@ -109,3 +110,64 @@ def add_roi_2mlp_head(model, blob_in, dim_in, spatial_scale):
model.FC('fc6', 'fc7', hidden_dim, hidden_dim)
model.Relu('fc7', 'fc7')
return 'fc7', hidden_dim
def add_roi_Xconv1fc_head(model, blob_in, dim_in, spatial_scale):
    """Box head made of X stacked 3x3 convs and one FC layer, without
    GroupNorm (kept as a reference for the GN variant).

    Returns the name of the output blob ('fc6') and its dimension.
    """
    conv_dim = cfg.FAST_RCNN.CONV_HEAD_DIM
    resolution = cfg.FAST_RCNN.ROI_XFORM_RESOLUTION

    # Extract per-RoI features
    head_blob = model.RoIFeatureTransform(
        blob_in,
        'roi_feat',
        blob_rois='rois',
        method=cfg.FAST_RCNN.ROI_XFORM_METHOD,
        resolution=resolution,
        sampling_ratio=cfg.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO,
        spatial_scale=spatial_scale
    )

    # Stacked conv 3x3 (with bias) -> ReLU
    in_channels = dim_in
    for i in range(cfg.FAST_RCNN.NUM_STACKED_CONVS):
        head_blob = model.Conv(
            head_blob,
            'head_conv' + str(i + 1),
            in_channels,
            conv_dim,
            3,
            stride=1,
            pad=1,
            weight_init=('MSRAFill', {}),
            bias_init=('ConstantFill', {'value': 0.}),
            no_bias=0
        )
        head_blob = model.Relu(head_blob, head_blob)
        # Every conv after the first one maps conv_dim -> conv_dim
        in_channels = conv_dim

    # Single fully connected layer on the flattened RoI feature
    fc_dim = cfg.FAST_RCNN.MLP_HEAD_DIM
    model.FC(head_blob, 'fc6', in_channels * resolution * resolution, fc_dim)
    model.Relu('fc6', 'fc6')
    return 'fc6', fc_dim
def add_roi_Xconv1fc_gn_head(model, blob_in, dim_in, spatial_scale):
    """Box head made of X stacked 3x3 ConvGN layers and one FC layer.

    Returns the name of the output blob ('fc6') and its dimension.
    """
    conv_dim = cfg.FAST_RCNN.CONV_HEAD_DIM
    resolution = cfg.FAST_RCNN.ROI_XFORM_RESOLUTION

    # Extract per-RoI features
    head_blob = model.RoIFeatureTransform(
        blob_in,
        'roi_feat',
        blob_rois='rois',
        method=cfg.FAST_RCNN.ROI_XFORM_METHOD,
        resolution=resolution,
        sampling_ratio=cfg.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO,
        spatial_scale=spatial_scale
    )

    # Stacked conv 3x3 -> GN -> ReLU
    in_channels = dim_in
    for i in range(cfg.FAST_RCNN.NUM_STACKED_CONVS):
        head_blob = model.ConvGN(
            head_blob,
            'head_conv' + str(i + 1),
            in_channels,
            conv_dim,
            3,
            group_gn=get_group_gn(conv_dim),
            stride=1,
            pad=1,
            weight_init=('MSRAFill', {}),
            bias_init=('ConstantFill', {'value': 0.})
        )
        head_blob = model.Relu(head_blob, head_blob)
        # Every conv after the first one maps conv_dim -> conv_dim
        in_channels = conv_dim

    # Single fully connected layer on the flattened RoI feature
    fc_dim = cfg.FAST_RCNN.MLP_HEAD_DIM
    model.FC(head_blob, 'fc6', in_channels * resolution * resolution, fc_dim)
    model.Relu('fc6', 'fc6')
    return 'fc6', fc_dim
......@@ -35,6 +35,7 @@ from __future__ import unicode_literals
from core.config import cfg
from utils.c2 import const_fill
from utils.c2 import gauss_fill
from utils.net import get_group_gn
import modeling.ResNet as ResNet
import utils.blob as blob_utils
......@@ -114,6 +115,13 @@ def mask_rcnn_fcn_head_v1up4convs(model, blob_in, dim_in, spatial_scale):
)
def mask_rcnn_fcn_head_v1up4convs_gn(model, blob_in, dim_in, spatial_scale):
    """v1up design with GroupNorm: four 3x3 convs followed by a 2x2 convT."""
    num_convs = 4
    return mask_rcnn_fcn_head_v1upXconvs_gn(
        model, blob_in, dim_in, spatial_scale, num_convs
    )
def mask_rcnn_fcn_head_v1up(model, blob_in, dim_in, spatial_scale):
"""v1up design: 2 * (conv 3x3), convT 2x2."""
return mask_rcnn_fcn_head_v1upXconvs(
......@@ -170,6 +178,56 @@ def mask_rcnn_fcn_head_v1upXconvs(
return blob_mask, dim_inner
def mask_rcnn_fcn_head_v1upXconvs_gn(
    model, blob_in, dim_in, spatial_scale, num_convs
):
    """v1upXconvs design: X * (conv 3x3), convT 2x2, with GroupNorm.

    Args:
        model: model helper the ops are added to.
        blob_in: input feature blob(s) for the RoI feature transform.
        dim_in: number of channels of blob_in.
        spatial_scale: spatial scale(s) of blob_in, forwarded to the RoI
            feature transform.
        num_convs: number of stacked 3x3 ConvGN + ReLU layers.

    Returns:
        A tuple of (mask feature blob, number of channels of that blob).
    """
    current = model.RoIFeatureTransform(
        blob_in,
        blob_out='_mask_roi_feat',
        blob_rois='mask_rois',
        method=cfg.MRCNN.ROI_XFORM_METHOD,
        resolution=cfg.MRCNN.ROI_XFORM_RESOLUTION,
        sampling_ratio=cfg.MRCNN.ROI_XFORM_SAMPLING_RATIO,
        spatial_scale=spatial_scale
    )

    dilation = cfg.MRCNN.DILATION
    dim_inner = cfg.MRCNN.DIM_REDUCED

    for i in range(num_convs):
        current = model.ConvGN(
            current,
            '_mask_fcn' + str(i + 1),
            dim_in,
            dim_inner,
            group_gn=get_group_gn(dim_inner),
            kernel=3,
            # The padding is scaled by the dilation, so the conv itself must
            # be dilated as well; otherwise any DILATION != 1 would change
            # the spatial size of the feature map at every layer.
            pad=1 * dilation,
            dilation=dilation,
            stride=1,
            weight_init=(cfg.MRCNN.CONV_INIT, {'std': 0.001}),
            bias_init=('ConstantFill', {'value': 0.})
        )
        current = model.Relu(current, current)
        # Every conv after the first one maps dim_inner -> dim_inner
        dim_in = dim_inner

    # upsample layer
    model.ConvTranspose(
        current,
        'conv5_mask',
        dim_inner,
        dim_inner,
        kernel=2,
        pad=0,
        stride=2,
        weight_init=(cfg.MRCNN.CONV_INIT, {'std': 0.001}),
        bias_init=const_fill(0.0)
    )
    blob_mask = model.Relu('conv5_mask', 'conv5_mask')

    return blob_mask, dim_inner
def mask_rcnn_fcn_head_v0upshare(model, blob_in, dim_in, spatial_scale):
"""Use a ResNet "conv5" / "stage5" head for mask prediction. Weights and
computation are shared with the conv5 box head. Computation can only be
......
......@@ -99,6 +99,10 @@ def add_single_gpu_param_update_ops(model, gpu_id):
wd = model.param_init_net.ConstantFill(
[], 'wd', shape=[1], value=cfg.SOLVER.WEIGHT_DECAY
)
# weight decay of GroupNorm's parameters
wd_gn = model.param_init_net.ConstantFill(
[], 'wd_gn', shape=[1], value=cfg.SOLVER.WEIGHT_DECAY_GN
)
for param in model.TrainableParams(gpu_id=gpu_id):
logger.debug('param ' + str(param) + ' will be updated')
param_grad = model.param_to_grad[param]
......@@ -112,6 +116,9 @@ def add_single_gpu_param_update_ops(model, gpu_id):
# (1) Do not apply weight decay
# (2) Use a 2x higher learning rate
model.Scale(param_grad, param_grad, scale=2.0)
elif param in model.gn_params:
# Special treatment for GroupNorm's parameters
model.WeightedSum([param_grad, one, param, wd_gn], param_grad)
elif cfg.SOLVER.WEIGHT_DECAY > 0:
# Apply weight decay to non-bias weights
model.WeightedSum([param_grad, one, param, wd], param_grad)
......
......@@ -274,3 +274,22 @@ def configure_bbox_reg_weights(model, saved_cfg):
'longer be used for training. To upgrade it to a trainable model '
'please use fb/compat/convert_bbox_reg_normalized_model.py.'
)
def get_group_gn(dim):
    """Return the number of groups GroupNorm should use for `dim` channels.

    The value is derived from cfg.GROUP_NORM: when DIM_PER_GP is positive the
    group count is dim // DIM_PER_GP, otherwise NUM_GROUPS is returned
    directly. At least one of the two options must be -1, and `dim` must be
    divisible by whichever option is used.
    """
    channels_per_group = cfg.GROUP_NORM.DIM_PER_GP
    fixed_group_count = cfg.GROUP_NORM.NUM_GROUPS

    # Only one of the two options may be in effect.
    assert channels_per_group == -1 or fixed_group_count == -1, \
        "GroupNorm: can only specify G or C/G."

    if channels_per_group > 0:
        # Fixed number of channels per group (C/G)
        assert dim % channels_per_group == 0
        return dim // channels_per_group

    # Fixed number of groups (G)
    assert dim % fixed_group_count == 0
    return fixed_group_count
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment