Commit c7692eb7 authored by Kaiming He's avatar Kaiming He Committed by Facebook Github Bot

add group norm

Summary: Add GroupNorm support to master Detectron.

Reviewed By: rbgirshick

Differential Revision: D7611892

fbshipit-source-id: dc4fb84a0e2167b05fd8a94ee0ff1ab1c21369b7
parent 0fda5f9a
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 180000
STEPS: [0, 120000, 160000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47592356/R-101-GN.pkl # Note: a GN pre-trained model
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 270000
STEPS: [0, 210000, 250000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47592356/R-101-GN.pkl # Note: a GN pre-trained model
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 180000
STEPS: [0, 120000, 160000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47261647/R-50-GN.pkl # Note: a GN pre-trained model
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 270000
STEPS: [0, 210000, 250000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47261647/R-50-GN.pkl # Note: a GN pre-trained model
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
# WARNING: this script uses **pre-computed** BN-based proposals, and is for quick debugging only.
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
NUM_CLASSES: 81
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 90000
STEPS: [0, 60000, 80000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/ImageNetPretrained/47261647/R-50-GN.pkl # Note: a GN pre-trained model
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
PROPOSAL_FILES: ('https://s3-us-west-2.amazonaws.com/detectron/35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179/output/test/coco_2014_train/generalized_rcnn/rpn_proposals.pkl', 'https://s3-us-west-2.amazonaws.com/detectron/35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179/output/test/coco_2014_valminusminival/generalized_rcnn/rpn_proposals.pkl')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
TEST:
DATASETS: ('coco_2014_minival',)
PROPOSAL_FILES: ('https://s3-us-west-2.amazonaws.com/detectron/35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179/output/test/coco_2014_minival/generalized_rcnn/rpn_proposals.pkl',)
PROPOSAL_LIMIT: 1000
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
OUTPUT_DIR: .
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 270000
STEPS: [0, 210000, 250000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
# WEIGHTS: N/A
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet50_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.02
GAMMA: 0.1
MAX_ITER: 270000
STEPS: [0, 210000, 250000]
FPN:
FPN_ON: True
MULTILEVEL_ROIS: True
MULTILEVEL_RPN: True
USE_GN: True # Note: use GN on the FPN-specific layers
RESNETS:
STRIDE_1X1: False # default True for MSRA; False for C2 or Torch models
TRANS_FUNC: bottleneck_gn_transformation # Note: this is a GN bottleneck transform
STEM_FUNC: basic_gn_stem # Note: this is a GN stem
SHORTCUT_FUNC: basic_gn_shortcut # Note: this is a GN shortcut
FAST_RCNN:
ROI_BOX_HEAD: fast_rcnn_heads.add_roi_Xconv1fc_gn_head # Note: this is a Conv GN head
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 7
ROI_XFORM_SAMPLING_RATIO: 2
MRCNN:
ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs_gn # Note: this is a GN mask head
RESOLUTION: 28 # (output mask resolution) default 14
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14 # default 7
ROI_XFORM_SAMPLING_RATIO: 2 # default 0
DILATION: 1 # default 2
CONV_INIT: MSRAFill # default GaussianFill
TRAIN:
# WEIGHTS: N/A
DATASETS: ('coco_2014_train', 'coco_2014_valminusminival')
SCALES: (800,)
MAX_SIZE: 1333
BATCH_SIZE_PER_IM: 512
RPN_PRE_NMS_TOP_N: 2000 # Per FPN level
TEST:
DATASETS: ('coco_2014_minival',)
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
OUTPUT_DIR: .
...@@ -592,6 +592,8 @@ __C.SOLVER.MOMENTUM = 0.9 ...@@ -592,6 +592,8 @@ __C.SOLVER.MOMENTUM = 0.9
# L2 regularization hyperparameter # L2 regularization hyperparameter
__C.SOLVER.WEIGHT_DECAY = 0.0005 __C.SOLVER.WEIGHT_DECAY = 0.0005
# L2 regularization hyperparameter for GroupNorm's parameters
__C.SOLVER.WEIGHT_DECAY_GN = 0.0
# Warm up to SOLVER.BASE_LR over this number of SGD iterations # Warm up to SOLVER.BASE_LR over this number of SGD iterations
__C.SOLVER.WARM_UP_ITERS = 500 __C.SOLVER.WARM_UP_ITERS = 500
...@@ -628,6 +630,11 @@ __C.FAST_RCNN.ROI_BOX_HEAD = b'' ...@@ -628,6 +630,11 @@ __C.FAST_RCNN.ROI_BOX_HEAD = b''
# Hidden layer dimension when using an MLP for the RoI box head # Hidden layer dimension when using an MLP for the RoI box head
__C.FAST_RCNN.MLP_HEAD_DIM = 1024 __C.FAST_RCNN.MLP_HEAD_DIM = 1024
# Hidden Conv layer dimension when using Convs for the RoI box head
__C.FAST_RCNN.CONV_HEAD_DIM = 256
# Number of stacked Conv layers in the RoI box head
__C.FAST_RCNN.NUM_STACKED_CONVS = 4
# RoI transformation function (e.g., RoIPool or RoIAlign) # RoI transformation function (e.g., RoIPool or RoIAlign)
# (RoIPoolF is the same as RoIPool; ignore the trailing 'F') # (RoIPoolF is the same as RoIPool; ignore the trailing 'F')
__C.FAST_RCNN.ROI_XFORM_METHOD = b'RoIPoolF' __C.FAST_RCNN.ROI_XFORM_METHOD = b'RoIPoolF'
...@@ -708,6 +715,8 @@ __C.FPN.RPN_ASPECT_RATIOS = (0.5, 1, 2) ...@@ -708,6 +715,8 @@ __C.FPN.RPN_ASPECT_RATIOS = (0.5, 1, 2)
__C.FPN.RPN_ANCHOR_START_SIZE = 32 __C.FPN.RPN_ANCHOR_START_SIZE = 32
# Use extra FPN levels, as done in the RetinaNet paper # Use extra FPN levels, as done in the RetinaNet paper
__C.FPN.EXTRA_CONV_LEVELS = False __C.FPN.EXTRA_CONV_LEVELS = False
# Use GroupNorm in the FPN-specific layers (lateral, etc.)
__C.FPN.USE_GN = False
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -863,11 +872,27 @@ __C.RESNETS.STRIDE_1X1 = True ...@@ -863,11 +872,27 @@ __C.RESNETS.STRIDE_1X1 = True
# Residual transformation function # Residual transformation function
__C.RESNETS.TRANS_FUNC = b'bottleneck_transformation' __C.RESNETS.TRANS_FUNC = b'bottleneck_transformation'
# ResNet's stem function (conv1 and pool1)
__C.RESNETS.STEM_FUNC = b'basic_bn_stem'
# ResNet's shortcut function
__C.RESNETS.SHORTCUT_FUNC = b'basic_bn_shortcut'
# Apply dilation in stage "res5" # Apply dilation in stage "res5"
__C.RESNETS.RES5_DILATION = 1 __C.RESNETS.RES5_DILATION = 1
# ---------------------------------------------------------------------------- #
# GroupNorm options
# ---------------------------------------------------------------------------- #
__C.GROUP_NORM = AttrDict()
# Number of dimensions per group in GroupNorm (-1 if using NUM_GROUPS)
__C.GROUP_NORM.DIM_PER_GP = -1
# Number of groups in GroupNorm (-1 if using DIM_PER_GP)
__C.GROUP_NORM.NUM_GROUPS = 32
# GroupNorm's small constant in the denominator
__C.GROUP_NORM.EPSILON = 1e-5
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# Misc options # Misc options
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
......
...@@ -27,6 +27,7 @@ from core.config import cfg ...@@ -27,6 +27,7 @@ from core.config import cfg
from modeling.generate_anchors import generate_anchors from modeling.generate_anchors import generate_anchors
from utils.c2 import const_fill from utils.c2 import const_fill
from utils.c2 import gauss_fill from utils.c2 import gauss_fill
from utils.net import get_group_gn
import modeling.ResNet as ResNet import modeling.ResNet as ResNet
import utils.blob as blob_utils import utils.blob as blob_utils
import utils.boxes as box_utils import utils.boxes as box_utils
...@@ -138,7 +139,23 @@ def add_fpn(model, fpn_level_info): ...@@ -138,7 +139,23 @@ def add_fpn(model, fpn_level_info):
fpn_dim_lateral = fpn_level_info.dims fpn_dim_lateral = fpn_level_info.dims
xavier_fill = ('XavierFill', {}) xavier_fill = ('XavierFill', {})
# For the coarest backbone level: 1x1 conv only seeds recursion # For the coarsest backbone level: 1x1 conv only seeds recursion
if cfg.FPN.USE_GN:
# use GroupNorm
c = model.ConvGN(
lateral_input_blobs[0],
output_blobs[0], # note: this is a prefix
dim_in=fpn_dim_lateral[0],
dim_out=fpn_dim,
group_gn=get_group_gn(fpn_dim),
kernel=1,
pad=0,
stride=1,
weight_init=xavier_fill,
bias_init=const_fill(0.0)
)
output_blobs[0] = c # rename it
else:
model.Conv( model.Conv(
lateral_input_blobs[0], lateral_input_blobs[0],
output_blobs[0], output_blobs[0],
...@@ -170,6 +187,21 @@ def add_fpn(model, fpn_level_info): ...@@ -170,6 +187,21 @@ def add_fpn(model, fpn_level_info):
blobs_fpn = [] blobs_fpn = []
spatial_scales = [] spatial_scales = []
for i in range(num_backbone_stages): for i in range(num_backbone_stages):
if cfg.FPN.USE_GN:
# use GroupNorm
fpn_blob = model.ConvGN(
output_blobs[i],
'fpn_{}'.format(fpn_level_info.blobs[i]),
dim_in=fpn_dim,
dim_out=fpn_dim,
group_gn=get_group_gn(fpn_dim),
kernel=3,
pad=1,
stride=1,
weight_init=xavier_fill,
bias_init=const_fill(0.0)
)
else:
fpn_blob = model.Conv( fpn_blob = model.Conv(
output_blobs[i], output_blobs[i],
'fpn_{}'.format(fpn_level_info.blobs[i]), 'fpn_{}'.format(fpn_level_info.blobs[i]),
...@@ -229,17 +261,34 @@ def add_topdown_lateral_module( ...@@ -229,17 +261,34 @@ def add_topdown_lateral_module(
): ):
"""Add a top-down lateral module.""" """Add a top-down lateral module."""
# Lateral 1x1 conv # Lateral 1x1 conv
lat = model.Conv( if cfg.FPN.USE_GN:
# use GroupNorm
lat = model.ConvGN(
fpn_lateral, fpn_lateral,
fpn_bottom + '_lateral', fpn_bottom + '_lateral',
dim_in=dim_lateral, dim_in=dim_lateral,
dim_out=dim_top, dim_out=dim_top,
group_gn=get_group_gn(dim_top),
kernel=1, kernel=1,
pad=0, pad=0,
stride=1, stride=1,
weight_init=( weight_init=(
const_fill(0.0) if cfg.FPN.ZERO_INIT_LATERAL const_fill(0.0) if cfg.FPN.ZERO_INIT_LATERAL
else ('XavierFill', {}) else ('XavierFill', {})),
bias_init=const_fill(0.0)
)
else:
lat = model.Conv(
fpn_lateral,
fpn_bottom + '_lateral',
dim_in=dim_lateral,
dim_out=dim_top,
kernel=1,
pad=0,
stride=1,
weight_init=(
const_fill(0.0)
if cfg.FPN.ZERO_INIT_LATERAL else ('XavierFill', {})
), ),
bias_init=const_fill(0.0) bias_init=const_fill(0.0)
) )
......
...@@ -24,6 +24,7 @@ from __future__ import print_function ...@@ -24,6 +24,7 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
from core.config import cfg from core.config import cfg
from utils.net import get_group_gn
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# Bits for specific architectures (ResNet50, ResNet101, ...) # Bits for specific architectures (ResNet50, ResNet101, ...)
...@@ -91,11 +92,10 @@ def add_ResNet_convX_body(model, block_counts, freeze_at=2): ...@@ -91,11 +92,10 @@ def add_ResNet_convX_body(model, block_counts, freeze_at=2):
The final res5/conv5 stage may be optionally excluded (hence convX, where The final res5/conv5 stage may be optionally excluded (hence convX, where
X = 4 or 5).""" X = 4 or 5)."""
assert freeze_at in [0, 2, 3, 4, 5] assert freeze_at in [0, 2, 3, 4, 5]
p = model.Conv('data', 'conv1', 3, 64, 7, pad=3, stride=2, no_bias=1)
p = model.AffineChannel(p, 'res_conv1_bn', dim=64, inplace=True) # add the stem (by default, conv1 and pool1 with bn; can support gn)
p = model.Relu(p, p) p, dim_in = globals()[cfg.RESNETS.STEM_FUNC](model, 'data')
p = model.MaxPool(p, 'pool1', kernel=3, pad=1, stride=2)
dim_in = 64
dim_bottleneck = cfg.RESNETS.NUM_GROUPS * cfg.RESNETS.WIDTH_PER_GROUP dim_bottleneck = cfg.RESNETS.NUM_GROUPS * cfg.RESNETS.WIDTH_PER_GROUP
(n1, n2, n3) = block_counts[:3] (n1, n2, n3) = block_counts[:3]
s, dim_in = add_stage(model, 'res2', p, n1, dim_in, 256, dim_bottleneck, 1) s, dim_in = add_stage(model, 'res2', p, n1, dim_in, 256, dim_bottleneck, 1)
...@@ -182,6 +182,8 @@ def add_residual_block( ...@@ -182,6 +182,8 @@ def add_residual_block(
) )
# sum -> ReLU # sum -> ReLU
# shortcut function: by default using bn; support gn
add_shortcut = globals()[cfg.RESNETS.SHORTCUT_FUNC]
sc = add_shortcut(model, prefix, blob_in, dim_in, dim_out, stride) sc = add_shortcut(model, prefix, blob_in, dim_in, dim_out, stride)
if inplace_sum: if inplace_sum:
s = model.net.Sum([tr, sc], tr) s = model.net.Sum([tr, sc], tr)
...@@ -191,7 +193,16 @@ def add_residual_block( ...@@ -191,7 +193,16 @@ def add_residual_block(
return model.Relu(s, s) return model.Relu(s, s)
def add_shortcut(model, prefix, blob_in, dim_in, dim_out, stride): # ------------------------------------------------------------------------------
# various shortcuts (may expand and may consider a new helper)
# ------------------------------------------------------------------------------
def basic_bn_shortcut(model, prefix, blob_in, dim_in, dim_out, stride):
""" For a pre-trained network that used BN. An AffineChannel op replaces BN
during fine-tuning.
"""
if dim_in == dim_out: if dim_in == dim_out:
return blob_in return blob_in
...@@ -207,6 +218,54 @@ def add_shortcut(model, prefix, blob_in, dim_in, dim_out, stride): ...@@ -207,6 +218,54 @@ def add_shortcut(model, prefix, blob_in, dim_in, dim_out, stride):
return model.AffineChannel(c, prefix + '_branch1_bn', dim=dim_out) return model.AffineChannel(c, prefix + '_branch1_bn', dim=dim_out)
def basic_gn_shortcut(model, prefix, blob_in, dim_in, dim_out, stride):
if dim_in == dim_out:
return blob_in
# output name is prefix + '_branch1_gn'
return model.ConvGN(
blob_in,
prefix + '_branch1',
dim_in,
dim_out,
kernel=1,
group_gn=get_group_gn(dim_out),
stride=stride,
pad=0,
group=1,
)
# ------------------------------------------------------------------------------
# various stems (may expand and may consider a new helper)
# ------------------------------------------------------------------------------
def basic_bn_stem(model, data, **kwargs):
"""Add a basic ResNet stem. For a pre-trained network that used BN.
An AffineChannel op replaces BN during fine-tuning.
"""
dim = 64
p = model.Conv(data, 'conv1', 3, dim, 7, pad=3, stride=2, no_bias=1)
p = model.AffineChannel(p, 'res_conv1_bn', dim=dim, inplace=True)
p = model.Relu(p, p)
p = model.MaxPool(p, 'pool1', kernel=3, pad=1, stride=2)
return p, dim
def basic_gn_stem(model, data, **kwargs):
"""Add a basic ResNet stem (using GN)"""
dim = 64
p = model.ConvGN(
data, 'conv1', 3, dim, 7, group_gn=get_group_gn(dim), pad=3, stride=2
)
p = model.Relu(p, p)
p = model.MaxPool(p, 'pool1', kernel=3, pad=1, stride=2)
return p, dim
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# various transformations (may expand and may consider a new helper) # various transformations (may expand and may consider a new helper)
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
...@@ -270,3 +329,61 @@ def bottleneck_transformation( ...@@ -270,3 +329,61 @@ def bottleneck_transformation(
inplace=False inplace=False
) )
return cur return cur
def bottleneck_gn_transformation(
model,
blob_in,
dim_in,
dim_out,
stride,
prefix,
dim_inner,
dilation=1,
group=1
):
"""Add a bottleneck transformation with GroupNorm to the model."""
# In original resnet, stride=2 is on 1x1.
# In fb.torch resnet, stride=2 is on 3x3.
(str1x1, str3x3) = (stride, 1) if cfg.RESNETS.STRIDE_1X1 else (1, stride)
# conv 1x1 -> GN -> ReLU
cur = model.ConvGN(
blob_in,
prefix + '_branch2a',
dim_in,
dim_inner,
kernel=1,
group_gn=get_group_gn(dim_inner),
stride=str1x1,
pad=0,
)
cur = model.Relu(cur, cur)
# conv 3x3 -> GN -> ReLU
cur = model.ConvGN(
cur,
prefix + '_branch2b',
dim_inner,
dim_inner,
kernel=3,
group_gn=get_group_gn(dim_inner),
stride=str3x3,
pad=1 * dilation,
dilation=dilation,
group=group,
)
cur = model.Relu(cur, cur)
# conv 1x1 -> GN (no ReLU)
cur = model.ConvGN(
cur,
prefix + '_branch2c',
dim_inner,
dim_out,
kernel=1,
group_gn=get_group_gn(dim_out),
stride=1,
pad=0,
)
return cur
...@@ -63,6 +63,7 @@ class DetectionModelHelper(cnn.CNNModelHelper): ...@@ -63,6 +63,7 @@ class DetectionModelHelper(cnn.CNNModelHelper):
self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE
self.net.Proto().num_workers = cfg.NUM_GPUS * 4 self.net.Proto().num_workers = cfg.NUM_GPUS * 4
self.prev_use_cudnn = self.use_cudnn self.prev_use_cudnn = self.use_cudnn
self.gn_params = [] # Param on this list are GroupNorm parameters
def TrainableParams(self, gpu_id=-1): def TrainableParams(self, gpu_id=-1):
"""Get the blob names for all trainable parameters, possibly filtered by """Get the blob names for all trainable parameters, possibly filtered by
...@@ -410,6 +411,48 @@ class DetectionModelHelper(cnn.CNNModelHelper): ...@@ -410,6 +411,48 @@ class DetectionModelHelper(cnn.CNNModelHelper):
) )
return blob_out return blob_out
def ConvGN( # args in the same order of Conv()
self, blob_in, prefix, dim_in, dim_out, kernel, stride, pad,
group_gn, # num of groups in gn
group=1, dilation=1,
weight_init=None,
bias_init=None,
suffix='_gn',
no_conv_bias=1,
):
"""ConvGN adds a Conv op followed by a GroupNorm op,
including learnable scale/bias (gamma/beta)
"""
conv_blob = self.Conv(
blob_in,
prefix,
dim_in,
dim_out,
kernel,
stride=stride,
pad=pad,
group=group,
dilation=dilation,
weight_init=weight_init,
bias_init=bias_init,
no_bias=no_conv_bias)
if group_gn < 1:
logger.warning(
'Layer: {} (dim {}): '
'group_gn < 1; reset to 1.'.format(prefix, dim_in)
)
group_gn = 1
blob_out = self.SpatialGN(
conv_blob, prefix + suffix,
dim_out, num_groups=group_gn,
epsilon=cfg.GROUP_NORM.EPSILON,)
self.gn_params.append(self.params[-1]) # add gn's bias to list
self.gn_params.append(self.params[-2]) # add gn's scale to list
return blob_out
def DisableCudnn(self): def DisableCudnn(self):
self.prev_use_cudnn = self.use_cudnn self.prev_use_cudnn = self.use_cudnn
self.use_cudnn = False self.use_cudnn = False
......
...@@ -35,6 +35,7 @@ from __future__ import unicode_literals ...@@ -35,6 +35,7 @@ from __future__ import unicode_literals
from core.config import cfg from core.config import cfg
from utils.c2 import const_fill from utils.c2 import const_fill
from utils.c2 import gauss_fill from utils.c2 import gauss_fill
from utils.net import get_group_gn
import utils.blob as blob_utils import utils.blob as blob_utils
...@@ -109,3 +110,64 @@ def add_roi_2mlp_head(model, blob_in, dim_in, spatial_scale): ...@@ -109,3 +110,64 @@ def add_roi_2mlp_head(model, blob_in, dim_in, spatial_scale):
model.FC('fc6', 'fc7', hidden_dim, hidden_dim) model.FC('fc6', 'fc7', hidden_dim, hidden_dim)
model.Relu('fc7', 'fc7') model.Relu('fc7', 'fc7')
return 'fc7', hidden_dim return 'fc7', hidden_dim
def add_roi_Xconv1fc_head(model, blob_in, dim_in, spatial_scale):
"""Add a X conv + 1fc head, as a reference if not using GroupNorm"""
hidden_dim = cfg.FAST_RCNN.CONV_HEAD_DIM
roi_size = cfg.FAST_RCNN.ROI_XFORM_RESOLUTION
roi_feat = model.RoIFeatureTransform(
blob_in,
'roi_feat',
blob_rois='rois',
method=cfg.FAST_RCNN.ROI_XFORM_METHOD,
resolution=roi_size,
sampling_ratio=cfg.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO,
spatial_scale=spatial_scale
)
current = roi_feat
for i in range(cfg.FAST_RCNN.NUM_STACKED_CONVS):
current = model.Conv(
current, 'head_conv' + str(i + 1), dim_in, hidden_dim, 3,
stride=1, pad=1,
weight_init=('MSRAFill', {}),
bias_init=('ConstantFill', {'value': 0.}),
no_bias=0)
current = model.Relu(current, current)
dim_in = hidden_dim
fc_dim = cfg.FAST_RCNN.MLP_HEAD_DIM
model.FC(current, 'fc6', dim_in * roi_size * roi_size, fc_dim)
model.Relu('fc6', 'fc6')
return 'fc6', fc_dim
def add_roi_Xconv1fc_gn_head(model, blob_in, dim_in, spatial_scale):
"""Add a X conv + 1fc head, with GroupNorm"""
hidden_dim = cfg.FAST_RCNN.CONV_HEAD_DIM
roi_size = cfg.FAST_RCNN.ROI_XFORM_RESOLUTION
roi_feat = model.RoIFeatureTransform(
blob_in, 'roi_feat',
blob_rois='rois',
method=cfg.FAST_RCNN.ROI_XFORM_METHOD,
resolution=roi_size,
sampling_ratio=cfg.FAST_RCNN.ROI_XFORM_SAMPLING_RATIO,
spatial_scale=spatial_scale
)
current = roi_feat
for i in range(cfg.FAST_RCNN.NUM_STACKED_CONVS):
current = model.ConvGN(
current, 'head_conv' + str(i + 1), dim_in, hidden_dim, 3,
group_gn=get_group_gn(hidden_dim),
stride=1, pad=1,
weight_init=('MSRAFill', {}),
bias_init=('ConstantFill', {'value': 0.}))
current = model.Relu(current, current)
dim_in = hidden_dim
fc_dim = cfg.FAST_RCNN.MLP_HEAD_DIM
model.FC(current, 'fc6', dim_in * roi_size * roi_size, fc_dim)
model.Relu('fc6', 'fc6')
return 'fc6', fc_dim
...@@ -35,6 +35,7 @@ from __future__ import unicode_literals ...@@ -35,6 +35,7 @@ from __future__ import unicode_literals
from core.config import cfg from core.config import cfg
from utils.c2 import const_fill from utils.c2 import const_fill
from utils.c2 import gauss_fill from utils.c2 import gauss_fill
from utils.net import get_group_gn
import modeling.ResNet as ResNet import modeling.ResNet as ResNet
import utils.blob as blob_utils import utils.blob as blob_utils
...@@ -114,6 +115,13 @@ def mask_rcnn_fcn_head_v1up4convs(model, blob_in, dim_in, spatial_scale): ...@@ -114,6 +115,13 @@ def mask_rcnn_fcn_head_v1up4convs(model, blob_in, dim_in, spatial_scale):
) )
def mask_rcnn_fcn_head_v1up4convs_gn(model, blob_in, dim_in, spatial_scale):
"""v1up design: 4 * (conv 3x3), convT 2x2, with GroupNorm"""
return mask_rcnn_fcn_head_v1upXconvs_gn(
model, blob_in, dim_in, spatial_scale, 4
)
def mask_rcnn_fcn_head_v1up(model, blob_in, dim_in, spatial_scale): def mask_rcnn_fcn_head_v1up(model, blob_in, dim_in, spatial_scale):
"""v1up design: 2 * (conv 3x3), convT 2x2.""" """v1up design: 2 * (conv 3x3), convT 2x2."""
return mask_rcnn_fcn_head_v1upXconvs( return mask_rcnn_fcn_head_v1upXconvs(
...@@ -170,6 +178,56 @@ def mask_rcnn_fcn_head_v1upXconvs( ...@@ -170,6 +178,56 @@ def mask_rcnn_fcn_head_v1upXconvs(
return blob_mask, dim_inner return blob_mask, dim_inner
def mask_rcnn_fcn_head_v1upXconvs_gn(
model, blob_in, dim_in, spatial_scale, num_convs
):
"""v1upXconvs design: X * (conv 3x3), convT 2x2, with GroupNorm"""
current = model.RoIFeatureTransform(
blob_in,
blob_out='_mask_roi_feat',
blob_rois='mask_rois',
method=cfg.MRCNN.ROI_XFORM_METHOD,
resolution=cfg.MRCNN.ROI_XFORM_RESOLUTION,
sampling_ratio=cfg.MRCNN.ROI_XFORM_SAMPLING_RATIO,
spatial_scale=spatial_scale
)
dilation = cfg.MRCNN.DILATION
dim_inner = cfg.MRCNN.DIM_REDUCED
for i in range(num_convs):
current = model.ConvGN(
current,
'_mask_fcn' + str(i + 1),
dim_in,
dim_inner,
group_gn=get_group_gn(dim_inner),
kernel=3,
pad=1 * dilation,
stride=1,
weight_init=(cfg.MRCNN.CONV_INIT, {'std': 0.001}),
bias_init=('ConstantFill', {'value': 0.})
)
current = model.Relu(current, current)
dim_in = dim_inner
# upsample layer
model.ConvTranspose(
current,
'conv5_mask',
dim_inner,
dim_inner,
kernel=2,
pad=0,
stride=2,
weight_init=(cfg.MRCNN.CONV_INIT, {'std': 0.001}),
bias_init=const_fill(0.0)
)
blob_mask = model.Relu('conv5_mask', 'conv5_mask')
return blob_mask, dim_inner
def mask_rcnn_fcn_head_v0upshare(model, blob_in, dim_in, spatial_scale): def mask_rcnn_fcn_head_v0upshare(model, blob_in, dim_in, spatial_scale):
"""Use a ResNet "conv5" / "stage5" head for mask prediction. Weights and """Use a ResNet "conv5" / "stage5" head for mask prediction. Weights and
computation are shared with the conv5 box head. Computation can only be computation are shared with the conv5 box head. Computation can only be
......
...@@ -99,6 +99,10 @@ def add_single_gpu_param_update_ops(model, gpu_id): ...@@ -99,6 +99,10 @@ def add_single_gpu_param_update_ops(model, gpu_id):
wd = model.param_init_net.ConstantFill( wd = model.param_init_net.ConstantFill(
[], 'wd', shape=[1], value=cfg.SOLVER.WEIGHT_DECAY [], 'wd', shape=[1], value=cfg.SOLVER.WEIGHT_DECAY
) )
# weight decay of GroupNorm's parameters
wd_gn = model.param_init_net.ConstantFill(
[], 'wd_gn', shape=[1], value=cfg.SOLVER.WEIGHT_DECAY_GN
)
for param in model.TrainableParams(gpu_id=gpu_id): for param in model.TrainableParams(gpu_id=gpu_id):
logger.debug('param ' + str(param) + ' will be updated') logger.debug('param ' + str(param) + ' will be updated')
param_grad = model.param_to_grad[param] param_grad = model.param_to_grad[param]
...@@ -112,6 +116,9 @@ def add_single_gpu_param_update_ops(model, gpu_id): ...@@ -112,6 +116,9 @@ def add_single_gpu_param_update_ops(model, gpu_id):
# (1) Do not apply weight decay # (1) Do not apply weight decay
# (2) Use a 2x higher learning rate # (2) Use a 2x higher learning rate
model.Scale(param_grad, param_grad, scale=2.0) model.Scale(param_grad, param_grad, scale=2.0)
elif param in model.gn_params:
# Special treatment for GroupNorm's parameters
model.WeightedSum([param_grad, one, param, wd_gn], param_grad)
elif cfg.SOLVER.WEIGHT_DECAY > 0: elif cfg.SOLVER.WEIGHT_DECAY > 0:
# Apply weight decay to non-bias weights # Apply weight decay to non-bias weights
model.WeightedSum([param_grad, one, param, wd], param_grad) model.WeightedSum([param_grad, one, param, wd], param_grad)
......
...@@ -274,3 +274,22 @@ def configure_bbox_reg_weights(model, saved_cfg): ...@@ -274,3 +274,22 @@ def configure_bbox_reg_weights(model, saved_cfg):
'longer be used for training. To upgrade it to a trainable model ' 'longer be used for training. To upgrade it to a trainable model '
'please use fb/compat/convert_bbox_reg_normalized_model.py.' 'please use fb/compat/convert_bbox_reg_normalized_model.py.'
) )
def get_group_gn(dim):
"""
get number of groups used by GroupNorm, based on number of channels
"""
dim_per_gp = cfg.GROUP_NORM.DIM_PER_GP
num_groups = cfg.GROUP_NORM.NUM_GROUPS
assert dim_per_gp == -1 or num_groups == -1, \
"GroupNorm: can only specify G or C/G."
if dim_per_gp > 0:
assert dim % dim_per_gp == 0
group_gn = dim // dim_per_gp
else:
assert dim % num_groups == 0
group_gn = num_groups
return group_gn
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment