Commit 1d6e9add authored by zimenglan, committed by Francisco Massa

add dcn from mmdetection (#693)

* make pixel indexes 0-based for bounding box in pascal voc dataset

* replacing all instances of torch.distributed.deprecated with torch.distributed

* replacing all instances of torch.distributed.deprecated with torch.distributed

* add GroupNorm

* add GroupNorm -- sort out yaml files

* use torch.nn.GroupNorm instead, replace 'use_gn' with 'conv_block', and use 'BaseStem' & 'Bottleneck' to simplify the code

* modification on 'group_norm' and 'conv_with_kaiming_uniform' function

* modification on yaml files in configs/gn_baselines/ and reduce the amount of indentation and code duplication

* use 'kaiming_uniform' to initialize resnet, disable gn after fc layer, and add dilation into ResNetHead

* agnostic-regression for bbox

* please set 'STRIDE_IN_1X1' to 'False' when the backbone uses GN

* add README.md for GN

* add dcn from mmdetection
parent 1714b7c2
......@@ -28,3 +28,4 @@ dist/
# project dirs
/datasets
/models
/output
### Reference
1. [Deformable ConvNets v2: More Deformable, Better Results](https://arxiv.org/pdf/1811.11168.pdf)
2. Third-party: [mmdetection](https://github.com/open-mmlab/mmdetection/tree/master/configs/dcn)
### Performance
| case | bbox AP | mask AP |
|:----------------------------------------------------|--------:|:-------:|
| Faster R-CNN, R-50-FPN-dcn (this implementation)    | 39.8    |    -    |
| Faster R-CNN, R-50-FPN-dcn (mmdetection)            | 40.0    |    -    |
| Faster R-CNN, R-50-FPN-mdcn (this implementation)   | 40.0    |    -    |
| Faster R-CNN, R-50-FPN-mdcn (mmdetection)           | 40.3    |    -    |
| Mask R-CNN, R-50-FPN-dcn (this implementation)      | 40.8    |  36.8   |
| Mask R-CNN, R-50-FPN-dcn (mmdetection)              | 41.1    |  37.2   |
| Mask R-CNN, R-50-FPN-mdcn (this implementation)     | 40.7    |  36.7   |
| Mask R-CNN, R-50-FPN-mdcn (mmdetection)             | 41.4    |  37.4   |
### Note
See [dcn-v2](https://github.com/open-mmlab/mmdetection/blob/master/MODEL_ZOO.md#deformable-convolution-v2) in `mmdetection` for more details.
### Usage
Add the following options under `MODEL.RESNETS` in your config (the four yaml files below already include them):
```
MODEL:
  RESNETS:
    # corresponding to C2,C3,C4,C5
    STAGE_WITH_DCN: (False, True, True, True)
    WITH_MODULATED_DCN: True
    DEFORMABLE_GROUPS: 1
```
\ No newline at end of file
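These options can also be loaded or overridden programmatically through the project's yacs-based `cfg` object; a minimal sketch (the yaml path below is an assumption about where the new configs live):
```
# Minimal sketch: loading one of the DCN configs and overriding an option.
from maskrcnn_benchmark.config import cfg

cfg.merge_from_file("configs/dcn/e2e_faster_rcnn_R_50_FPN_1x_dcn.yaml")  # assumed path
# switch to the modulated (DCNv2) variant without editing the file
cfg.merge_from_list(["MODEL.RESNETS.WITH_MODULATED_DCN", True])
print(cfg.MODEL.RESNETS.STAGE_WITH_DCN, cfg.MODEL.RESNETS.DEFORMABLE_GROUPS)
```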
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
    STAGE_WITH_DCN: (False, True, True, True)
    WITH_MODULATED_DCN: False
    DEFORMABLE_GROUPS: 1
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 8 gpus
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 16
TEST:
  IMS_PER_BATCH: 8
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
    STAGE_WITH_DCN: (False, True, True, True)
    WITH_MODULATED_DCN: True
    DEFORMABLE_GROUPS: 1
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 8 gpus
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 16
TEST:
  IMS_PER_BATCH: 8
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
    STAGE_WITH_DCN: (False, True, True, True)
    WITH_MODULATED_DCN: False
    DEFORMABLE_GROUPS: 1
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  ROI_MASK_HEAD:
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
    PREDICTOR: "MaskRCNNC4Predictor"
    POOLER_RESOLUTION: 14
    POOLER_SAMPLING_RATIO: 2
    RESOLUTION: 28
    SHARE_BOX_FEATURE_EXTRACTOR: False
  MASK_ON: True
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 8 gpus
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 16
TEST:
  IMS_PER_BATCH: 8
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
    STAGE_WITH_DCN: (False, True, True, True)
    WITH_MODULATED_DCN: True
    DEFORMABLE_GROUPS: 1
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  ROI_MASK_HEAD:
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
    PREDICTOR: "MaskRCNNC4Predictor"
    POOLER_RESOLUTION: 14
    POOLER_SAMPLING_RATIO: 2
    RESOLUTION: 28
    SHARE_BOX_FEATURE_EXTRACTOR: False
  MASK_ON: True
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 8 gpus
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 16
TEST:
  IMS_PER_BATCH: 8
......@@ -274,6 +274,10 @@ _C.MODEL.RESNETS.BACKBONE_OUT_CHANNELS = 256 * 4
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
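# Deformable convolution (DCN) options:
#   STAGE_WITH_DCN     - enable deformable conv per residual stage (C2-C5)
#   WITH_MODULATED_DCN - use the modulated (DCNv2) variant instead of plain DCN
#   DEFORMABLE_GROUPS  - number of offset groups in each deformable conv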
_C.MODEL.RESNETS.STAGE_WITH_DCN = (False, False, False, False)
_C.MODEL.RESNETS.WITH_MODULATED_DCN = False
_C.MODEL.RESNETS.DEFORMABLE_GROUPS = 1
# ---------------------------------------------------------------------------- #
# RetinaNet Options (Follow the Detectron version)
......
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
// based on
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
#include <vector>
#include <iostream>
#include <cmath>
void DeformablePSROIPoolForward(
const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
at::Tensor out, at::Tensor top_count, const int batch, const int channels,
const int height, const int width, const int num_bbox,
const int channels_trans, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void DeformablePSROIPoolBackwardAcc(
const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
at::Tensor trans_grad, const int batch, const int channels,
const int height, const int width, const int num_bbox,
const int channels_trans, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void deform_psroi_pooling_cuda_forward(
at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
at::Tensor top_count, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
if (num_bbox != out.size(0))
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out.size(0), num_bbox);
DeformablePSROIPoolForward(
input, bbox, trans, out, top_count, batch, channels, height, width,
num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
pooled_size, part_size, sample_per_part, trans_std);
}
void deform_psroi_pooling_cuda_backward(
at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
const int no_trans, const float spatial_scale, const int output_dim,
const int group_size, const int pooled_size, const int part_size,
const int sample_per_part, const float trans_std)
{
AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
if (num_bbox != out_grad.size(0))
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out_grad.size(0), num_bbox);
DeformablePSROIPoolBackwardAcc(
out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
channels, height, width, num_bbox, channels_trans, no_trans,
spatial_scale, output_dim, group_size, pooled_size, part_size,
sample_per_part, trans_std);
}
......@@ -58,6 +58,59 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step);
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step);
int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step);
void modulated_deform_conv_cuda_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias);
void modulated_deform_conv_cuda_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias);
void deform_psroi_pooling_cuda_forward(
at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
at::Tensor top_count, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void deform_psroi_pooling_cuda_backward(
at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
const int no_trans, const float spatial_scale, const int output_dim,
const int group_size, const int pooled_size, const int part_size,
const int sample_per_part, const float trans_std);
at::Tensor compute_flow_cuda(const at::Tensor& boxes,
const int height,
const int width);
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include "cpu/vision.h"
#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif
// Interface for Python
int deform_conv_forward(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor output,
at::Tensor columns,
at::Tensor ones,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
int im2col_step)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_forward_cuda(
input, weight, offset, output, columns, ones,
kW, kH, dW, dH, padW, padH, dilationW, dilationH,
group, deformable_group, im2col_step
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
int deform_conv_backward_input(
at::Tensor input,
at::Tensor offset,
at::Tensor gradOutput,
at::Tensor gradInput,
at::Tensor gradOffset,
at::Tensor weight,
at::Tensor columns,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
int im2col_step)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_backward_input_cuda(
input, offset, gradOutput, gradInput, gradOffset, weight, columns,
kW, kH, dW, dH, padW, padH, dilationW, dilationH,
group, deformable_group, im2col_step
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
int deform_conv_backward_parameters(
at::Tensor input,
at::Tensor offset,
at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns,
at::Tensor ones,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
float scale,
int im2col_step)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_backward_parameters_cuda(
input, offset, gradOutput, gradWeight, columns, ones,
kW, kH, dW, dH, padW, padH, dilationW, dilationH,
group, deformable_group, scale, im2col_step
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
void modulated_deform_conv_forward(
at::Tensor input,
at::Tensor weight,
at::Tensor bias,
at::Tensor ones,
at::Tensor offset,
at::Tensor mask,
at::Tensor output,
at::Tensor columns,
int kernel_h,
int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const int group,
const int deformable_group,
const bool with_bias)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return modulated_deform_conv_cuda_forward(
input, weight, bias, ones, offset, mask, output, columns,
kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w,
group, deformable_group, with_bias
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
void modulated_deform_conv_backward(
at::Tensor input,
at::Tensor weight,
at::Tensor bias,
at::Tensor ones,
at::Tensor offset,
at::Tensor mask,
at::Tensor columns,
at::Tensor grad_input,
at::Tensor grad_weight,
at::Tensor grad_bias,
at::Tensor grad_offset,
at::Tensor grad_mask,
at::Tensor grad_output,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dilation_h,
int dilation_w,
int group,
int deformable_group,
const bool with_bias)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return modulated_deform_conv_cuda_backward(
input, weight, bias, ones, offset, mask, columns,
grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
group, deformable_group, with_bias
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
\ No newline at end of file
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include "cpu/vision.h"
#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif
// Interface for Python
void deform_psroi_pooling_forward(
at::Tensor input,
at::Tensor bbox,
at::Tensor trans,
at::Tensor out,
at::Tensor top_count,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return deform_psroi_pooling_cuda_forward(
input, bbox, trans, out, top_count,
no_trans, spatial_scale, output_dim, group_size,
pooled_size, part_size, sample_per_part, trans_std
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
void deform_psroi_pooling_backward(
at::Tensor out_grad,
at::Tensor input,
at::Tensor bbox,
at::Tensor trans,
at::Tensor top_count,
at::Tensor input_grad,
at::Tensor trans_grad,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return deform_psroi_pooling_cuda_backward(
out_grad, input, bbox, trans, top_count, input_grad, trans_grad,
no_trans, spatial_scale, output_dim, group_size, pooled_size,
part_size, sample_per_part, trans_std
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
......@@ -3,6 +3,8 @@
#include "ROIAlign.h"
#include "ROIPool.h"
#include "SigmoidFocalLoss.h"
#include "deform_conv.h"
#include "deform_pool.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms", &nms, "non-maximum suppression");
......@@ -12,4 +14,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
m.def("sigmoid_focalloss_forward", &SigmoidFocalLoss_forward, "SigmoidFocalLoss_forward");
m.def("sigmoid_focalloss_backward", &SigmoidFocalLoss_backward, "SigmoidFocalLoss_backward");
}
// dcn-v2
m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward");
m.def("deform_conv_backward_input", &deform_conv_backward_input, "deform_conv_backward_input");
m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, "deform_conv_backward_parameters");
m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, "modulated_deform_conv_forward");
m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, "modulated_deform_conv_backward");
m.def("deform_psroi_pooling_forward", &deform_psroi_pooling_forward, "deform_psroi_pooling_forward");
m.def("deform_psroi_pooling_backward", &deform_psroi_pooling_backward, "deform_psroi_pooling_backward");
}
\ No newline at end of file
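After rebuilding the C extension (for example with `python setup.py build develop`), the ops registered above become attributes of the compiled module; a quick sanity check, assuming a CUDA-enabled build:
```
# Sketch: confirm the newly bound DCN ops are exposed on the compiled extension.
from maskrcnn_benchmark import _C

for name in ("deform_conv_forward", "deform_conv_backward_input",
             "deform_conv_backward_parameters", "modulated_deform_conv_forward",
             "modulated_deform_conv_backward", "deform_psroi_pooling_forward",
             "deform_psroi_pooling_backward"):
    assert hasattr(_C, name), "missing binding: " + name
```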
......@@ -3,6 +3,7 @@ import torch
from .batch_norm import FrozenBatchNorm2d
from .misc import Conv2d
from .misc import DFConv2d
from .misc import ConvTranspose2d
from .misc import BatchNorm2d
from .misc import interpolate
......@@ -13,9 +14,34 @@ from .roi_pool import ROIPool
from .roi_pool import roi_pool
from .smooth_l1_loss import smooth_l1_loss
from .sigmoid_focal_loss import SigmoidFocalLoss
from .dcn.deform_conv_func import deform_conv, modulated_deform_conv
from .dcn.deform_conv_module import DeformConv, ModulatedDeformConv, ModulatedDeformConvPack
from .dcn.deform_pool_func import deform_roi_pooling
from .dcn.deform_pool_module import DeformRoIPooling, DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack
__all__ = ["nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool",
"smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate",
"BatchNorm2d", "FrozenBatchNorm2d", "SigmoidFocalLoss"
]
__all__ = [
"nms",
"roi_align",
"ROIAlign",
"roi_pool",
"ROIPool",
"smooth_l1_loss",
"Conv2d",
"DFConv2d",
"ConvTranspose2d",
"interpolate",
"BatchNorm2d",
"FrozenBatchNorm2d",
"SigmoidFocalLoss",
'deform_conv',
'modulated_deform_conv',
'DeformConv',
'ModulatedDeformConv',
'ModulatedDeformConvPack',
'deform_roi_pooling',
'DeformRoIPooling',
'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack',
]
#
# Copied From [mmdetection](https://github.com/open-mmlab/mmdetection/tree/master/mmdet/ops/dcn)
#
\ No newline at end of file
import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from maskrcnn_benchmark import _C
class DeformConvFunction(Function):
@staticmethod
def forward(
ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=64
):
if input is not None and input.dim() != 4:
raise ValueError(
"Expected 4D tensor as input, got {}D tensor instead.".format(
input.dim()))
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(
DeformConvFunction._output_size(input, weight, ctx.padding,
ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
_C.deform_conv_forward(
input,
weight,
offset,
output,
ctx.bufs_[0],
ctx.bufs_[1],
weight.size(3),
weight.size(2),
ctx.stride[1],
ctx.stride[0],
ctx.padding[1],
ctx.padding[0],
ctx.dilation[1],
ctx.dilation[0],
ctx.groups,
ctx.deformable_groups,
cur_im2col_step
)
return output
@staticmethod
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
_C.deform_conv_backward_input(
input,
offset,
grad_output,
grad_input,
grad_offset,
weight,
ctx.bufs_[0],
weight.size(3),
weight.size(2),
ctx.stride[1],
ctx.stride[0],
ctx.padding[1],
ctx.padding[0],
ctx.dilation[1],
ctx.dilation[0],
ctx.groups,
ctx.deformable_groups,
cur_im2col_step
)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
_C.deform_conv_backward_parameters(
input,
offset,
grad_output,
grad_weight,
ctx.bufs_[0],
ctx.bufs_[1],
weight.size(3),
weight.size(2),
ctx.stride[1],
ctx.stride[0],
ctx.padding[1],
ctx.padding[0],
ctx.dilation[1],
ctx.dilation[0],
ctx.groups,
ctx.deformable_groups,
1,
cur_im2col_step
)
return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(
"convolution input is too small (output would be {})".format(
'x'.join(map(str, output_size))))
return output_size
class ModulatedDeformConvFunction(Function):
@staticmethod
def forward(
ctx,
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1
):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
_C.modulated_deform_conv_forward(
input,
weight,
bias,
ctx._bufs[0],
offset,
mask,
output,
ctx._bufs[1],
weight.shape[2],
weight.shape[3],
ctx.stride,
ctx.stride,
ctx.padding,
ctx.padding,
ctx.dilation,
ctx.dilation,
ctx.groups,
ctx.deformable_groups,
ctx.with_bias
)
return output
@staticmethod
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
_C.modulated_deform_conv_backward(
input,
weight,
bias,
ctx._bufs[0],
offset,
mask,
ctx._bufs[1],
grad_input,
grad_weight,
grad_bias,
grad_offset,
grad_mask,
grad_output,
weight.shape[2],
weight.shape[3],
ctx.stride,
ctx.stride,
ctx.padding,
ctx.padding,
ctx.dilation,
ctx.dilation,
ctx.groups,
ctx.deformable_groups,
ctx.with_bias
)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding -
(ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding -
(ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
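A minimal sketch of calling the raw `deform_conv` function with an explicitly supplied offset field (assumes a CUDA build; the op raises `NotImplementedError` on CPU):
```
import torch
from maskrcnn_benchmark.layers import deform_conv

x = torch.randn(2, 16, 32, 32, device="cuda")
# for a 3x3 kernel and deformable_groups=1 the offset field has 2*3*3 = 18
# channels (a (dy, dx) pair per kernel location) at the output resolution
offset = torch.zeros(2, 18, 32, 32, device="cuda")
weight = torch.randn(8, 16, 3, 3, device="cuda")
# trailing args: stride, padding, dilation, groups, deformable_groups
out = deform_conv(x, offset, weight, 1, 1, 1, 1, 1)
print(out.shape)  # torch.Size([2, 8, 32, 32]); with zero offsets this equals a plain conv
```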
import math
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
from .deform_conv_func import deform_conv, modulated_deform_conv
class DeformConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False
):
assert not bias
super(DeformConv, self).__init__()
self.with_bias = bias
assert in_channels % groups == 0, \
'in_channels {} is not divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} is not divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups,
*self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation, self.groups,
self.deformable_groups)
def __repr__(self):
return "".join([
"{}(".format(self.__class__.__name__),
"in_channels={}, ".format(self.in_channels),
"out_channels={}, ".format(self.out_channels),
"kernel_size={}, ".format(self.kernel_size),
"stride={}, ".format(self.stride),
"dilation={}, ".format(self.dilation),
"padding={}, ".format(self.padding),
"groups={}, ".format(self.groups),
"deformable_groups={}, ".format(self.deformable_groups),
"bias={})".format(self.with_bias),
])
class ModulatedDeformConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True
):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
self.weight = nn.Parameter(torch.Tensor(
out_channels,
in_channels // groups,
*self.kernel_size
))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, input, offset, mask):
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
def __repr__(self):
return "".join([
"{}(".format(self.__class__.__name__),
"in_channels={}, ".format(self.in_channels),
"out_channels={}, ".format(self.out_channels),
"kernel_size={}, ".format(self.kernel_size),
"stride={}, ".format(self.stride),
"dilation={}, ".format(self.dilation),
"padding={}, ".format(self.padding),
"groups={}, ".format(self.groups),
"deformable_groups={}, ".format(self.deformable_groups),
"bias={})".format(self.with_bias),
])
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConvPack, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, deformable_groups, bias)
self.conv_offset_mask = nn.Conv2d(
self.in_channels // self.groups,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, input):
out = self.conv_offset_mask(input)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
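A short usage sketch for `ModulatedDeformConvPack`, which predicts its own offsets and modulation mask and can therefore stand in for a regular 3x3 convolution (CUDA build required):
```
import torch
from maskrcnn_benchmark.layers import ModulatedDeformConvPack

conv = ModulatedDeformConvPack(16, 32, kernel_size=3, stride=1, padding=1).cuda()
x = torch.randn(2, 16, 32, 32, device="cuda")
y = conv(x)
# offsets start at zero and the mask at sigmoid(0) = 0.5, thanks to init_offset()
print(y.shape)  # torch.Size([2, 32, 32, 32])
```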
import torch
from torch.autograd import Function
from maskrcnn_benchmark import _C
class DeformRoIPoolingFunction(Function):
@staticmethod
def forward(
ctx,
data,
rois,
offset,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0
):
ctx.spatial_scale = spatial_scale
ctx.out_size = out_size
ctx.out_channels = out_channels
ctx.no_trans = no_trans
ctx.group_size = group_size
ctx.part_size = out_size if part_size is None else part_size
ctx.sample_per_part = sample_per_part
ctx.trans_std = trans_std
assert 0.0 <= ctx.trans_std <= 1.0
if not data.is_cuda:
raise NotImplementedError
n = rois.shape[0]
output = data.new_empty(n, out_channels, out_size, out_size)
output_count = data.new_empty(n, out_channels, out_size, out_size)
_C.deform_psroi_pooling_forward(
data,
rois,
offset,
output,
output_count,
ctx.no_trans,
ctx.spatial_scale,
ctx.out_channels,
ctx.group_size,
ctx.out_size,
ctx.part_size,
ctx.sample_per_part,
ctx.trans_std
)
if data.requires_grad or rois.requires_grad or offset.requires_grad:
ctx.save_for_backward(data, rois, offset)
ctx.output_count = output_count
return output
@staticmethod
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
data, rois, offset = ctx.saved_tensors
output_count = ctx.output_count
grad_input = torch.zeros_like(data)
grad_rois = None
grad_offset = torch.zeros_like(offset)
_C.deform_psroi_pooling_backward(
grad_output,
data,
rois,
offset,
output_count,
grad_input,
grad_offset,
ctx.no_trans,
ctx.spatial_scale,
ctx.out_channels,
ctx.group_size,
ctx.out_size,
ctx.part_size,
ctx.sample_per_part,
ctx.trans_std
)
return (grad_input, grad_rois, grad_offset, None, None, None, None, None, None, None, None)
deform_roi_pooling = DeformRoIPoolingFunction.apply
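A minimal sketch of the raw `deform_roi_pooling` function with `no_trans=True`, i.e. pooling without learned offsets (CUDA build required; the roi layout `(batch_idx, x1, y1, x2, y2)` is an assumption, matching the repo's other pooling ops):
```
import torch
from maskrcnn_benchmark.layers import deform_roi_pooling

feats = torch.randn(1, 256, 50, 68, device="cuda")
rois = torch.tensor([[0.0, 32.0, 32.0, 160.0, 160.0]], device="cuda")
offset = feats.new_empty(0)  # unused when no_trans=True
# trailing args: spatial_scale, out_size, out_channels, no_trans
out = deform_roi_pooling(feats, rois, offset, 1.0 / 16, 7, 256, True)
print(out.shape)  # torch.Size([1, 256, 7, 7])
```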
from torch import nn
from .deform_pool_func import deform_roi_pooling
class DeformRoIPooling(nn.Module):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
super(DeformRoIPooling, self).__init__()
self.spatial_scale = spatial_scale
self.out_size = out_size
self.out_channels = out_channels
self.no_trans = no_trans
self.group_size = group_size
self.part_size = out_size if part_size is None else part_size
self.sample_per_part = sample_per_part
self.trans_std = trans_std
def forward(self, data, rois, offset):
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
class DeformRoIPoolingPack(DeformRoIPooling):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0,
deform_fc_channels=1024):
super(DeformRoIPoolingPack,
self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std)
self.deform_fc_channels = deform_fc_channels
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 2))
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
def forward(self, data, rois):
assert data.size(1) == self.out_channels
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std)
else:
n = rois.shape[0]
offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std)
class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0,
deform_fc_channels=1024):
super(ModulatedDeformRoIPoolingPack, self).__init__(
spatial_scale, out_size, out_channels, no_trans, group_size,
part_size, sample_per_part, trans_std)
self.deform_fc_channels = deform_fc_channels
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 2))
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
self.mask_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 1),
nn.Sigmoid())
self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_()
def forward(self, data, rois):
assert data.size(1) == self.out_channels
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std)
else:
n = rois.shape[0]
offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size)
mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.out_size, self.out_size)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) * mask
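And a corresponding sketch for `ModulatedDeformRoIPoolingPack`, which learns per-bin offsets and a mask from an initial pooling pass (CUDA build required; roi layout assumed as above):
```
import torch
from maskrcnn_benchmark.layers import ModulatedDeformRoIPoolingPack

pool = ModulatedDeformRoIPoolingPack(
    spatial_scale=1.0 / 16,
    out_size=7,
    out_channels=256,   # must match the channel count of the input features
    no_trans=False,
    trans_std=0.1,
).cuda()
feats = torch.randn(1, 256, 50, 68, device="cuda")
rois = torch.tensor([[0.0, 64.0, 64.0, 256.0, 256.0]], device="cuda")
pooled = pool(feats, rois)
print(pooled.shape)  # torch.Size([1, 256, 7, 7])
```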
......@@ -11,6 +11,7 @@ is implemented
import math
import torch
from torch import nn
from torch.nn.modules.utils import _ntuple
......@@ -108,3 +109,86 @@ def interpolate(
output_shape = tuple(_output_size(2))
output_shape = input.shape[:-2] + output_shape
return _NewEmptyTensorOp.apply(input, output_shape)
class DFConv2d(nn.Module):
"""Deformable convolutional layer"""
def __init__(
self,
in_channels,
out_channels,
with_modulated_dcn=True,
kernel_size=3,
stride=1,
groups=1,
dilation=1,
deformable_groups=1,
bias=False
):
super(DFConv2d, self).__init__()
if isinstance(kernel_size, (list, tuple)):
assert len(kernel_size) == 2
offset_base_channels = kernel_size[0] * kernel_size[1]
else:
offset_base_channels = kernel_size * kernel_size
if with_modulated_dcn:
from maskrcnn_benchmark.layers import ModulatedDeformConv
offset_channels = offset_base_channels * 3 #default: 27
conv_block = ModulatedDeformConv
else:
from maskrcnn_benchmark.layers import DeformConv
offset_channels = offset_base_channels * 2 #default: 18
conv_block = DeformConv
self.offset = Conv2d(
in_channels,
deformable_groups * offset_channels,
kernel_size=kernel_size,
stride= stride,
padding= dilation,
groups=1,
dilation=dilation
)
for l in [self.offset,]:
nn.init.kaiming_uniform_(l.weight, a=1)
torch.nn.init.constant_(l.bias, 0.)
self.conv = conv_block(
in_channels,
out_channels,
kernel_size=kernel_size,
stride= stride,
padding=dilation,
dilation=dilation,
groups=groups,
deformable_groups=deformable_groups,
bias=bias
)
self.with_modulated_dcn = with_modulated_dcn
self.kernel_size = kernel_size
self.stride = stride
self.padding = dilation
self.dilation = dilation
def forward(self, x):
if x.numel() > 0:
if not self.with_modulated_dcn:
offset = self.offset(x)
x = self.conv(x, offset)
else:
offset_mask = self.offset(x)
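# note: the hard-coded 18/9 split below assumes a 3x3 kernel with
# deformable_groups=1 (2*3*3 offset channels followed by 3*3 mask channels)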
offset = offset_mask[:, :18, :, :]
mask = offset_mask[:, -9:, :, :].sigmoid()
x = self.conv(x, offset, mask)
return x
# get output shape
output_shape = [
(i + 2 * p - (di * (k - 1) + 1)) // d + 1
for i, p, di, k, d in zip(
x.shape[-2:],
self.padding,
self.dilation,
self.kernel_size,
self.stride
)
]
output_shape = [x.shape[0], self.conv.weight.shape[0]] + output_shape
return _NewEmptyTensorOp.apply(x, output_shape)
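A usage sketch for `DFConv2d`, which wires the offset/mask-predicting convolution and the deformable convolution together and keeps the spatial size (padding equals dilation for the default 3x3 kernel); CUDA build required:
```
import torch
from maskrcnn_benchmark.layers import DFConv2d

dconv = DFConv2d(64, 64, with_modulated_dcn=True, kernel_size=3, stride=1).cuda()
x = torch.randn(2, 64, 28, 28, device="cuda")
y = dconv(x)
print(y.shape)  # torch.Size([2, 64, 28, 28])
```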
......@@ -24,6 +24,7 @@ from torch import nn
from maskrcnn_benchmark.layers import FrozenBatchNorm2d
from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.layers import DFConv2d
from maskrcnn_benchmark.modeling.make_layers import group_norm
from maskrcnn_benchmark.utils.registry import Registry
......@@ -106,6 +107,7 @@ class ResNet(nn.Module):
stage2_relative_factor = 2 ** (stage_spec.index - 1)
bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor
out_channels = stage2_out_channels * stage2_relative_factor
stage_with_dcn = cfg.MODEL.RESNETS.STAGE_WITH_DCN[stage_spec.index -1]
module = _make_stage(
transformation_module,
in_channels,
......@@ -115,6 +117,11 @@ class ResNet(nn.Module):
num_groups,
cfg.MODEL.RESNETS.STRIDE_IN_1X1,
first_stride=int(stage_spec.index > 1) + 1,
dcn_config={
"stage_with_dcn": stage_with_dcn,
"with_modulated_dcn": cfg.MODEL.RESNETS.WITH_MODULATED_DCN,
"deformable_groups": cfg.MODEL.RESNETS.DEFORMABLE_GROUPS,
}
)
in_channels = out_channels
self.add_module(name, module)
......@@ -155,7 +162,8 @@ class ResNetHead(nn.Module):
stride_in_1x1=True,
stride_init=None,
res2_out_channels=256,
dilation=1
dilation=1,
dcn_config={}
):
super(ResNetHead, self).__init__()
......@@ -182,7 +190,8 @@ class ResNetHead(nn.Module):
num_groups,
stride_in_1x1,
first_stride=stride,
dilation=dilation
dilation=dilation,
dcn_config=dcn_config
)
stride = None
self.add_module(name, module)
......@@ -204,7 +213,8 @@ def _make_stage(
num_groups,
stride_in_1x1,
first_stride,
dilation=1
dilation=1,
dcn_config={}
):
blocks = []
stride = first_stride
......@@ -217,7 +227,8 @@ def _make_stage(
num_groups,
stride_in_1x1,
stride,
dilation=dilation
dilation=dilation,
dcn_config=dcn_config
)
)
stride = 1
......@@ -235,7 +246,8 @@ class Bottleneck(nn.Module):
stride_in_1x1,
stride,
dilation,
norm_func
norm_func,
dcn_config
):
super(Bottleneck, self).__init__()
......@@ -271,17 +283,34 @@ class Bottleneck(nn.Module):
)
self.bn1 = norm_func(bottleneck_channels)
# TODO: specify init for the above
with_dcn = dcn_config.get("stage_with_dcn", False)
if with_dcn:
deformable_groups = dcn_config.get("deformable_groups", 1)
with_modulated_dcn = dcn_config.get("with_modulated_dcn", False)
self.conv2 = DFConv2d(
bottleneck_channels,
bottleneck_channels,
with_modulated_dcn=with_modulated_dcn,
kernel_size=3,
stride=stride_3x3,
groups=num_groups,
dilation=dilation,
deformable_groups=deformable_groups,
bias=False
)
else:
self.conv2 = Conv2d(
bottleneck_channels,
bottleneck_channels,
kernel_size=3,
stride=stride_3x3,
padding=dilation,
bias=False,
groups=num_groups,
dilation=dilation
)
nn.init.kaiming_uniform_(self.conv2.weight, a=1)
self.conv2 = Conv2d(
bottleneck_channels,
bottleneck_channels,
kernel_size=3,
stride=stride_3x3,
padding=dilation,
bias=False,
groups=num_groups,
dilation=dilation
)
self.bn2 = norm_func(bottleneck_channels)
self.conv3 = Conv2d(
......@@ -289,7 +318,7 @@ class Bottleneck(nn.Module):
)
self.bn3 = norm_func(out_channels)
for l in [self.conv1, self.conv2, self.conv3,]:
for l in [self.conv1, self.conv3,]:
nn.init.kaiming_uniform_(l.weight, a=1)
def forward(self, x):
......@@ -346,7 +375,8 @@ class BottleneckWithFixedBatchNorm(Bottleneck):
num_groups=1,
stride_in_1x1=True,
stride=1,
dilation=1
dilation=1,
dcn_config={}
):
super(BottleneckWithFixedBatchNorm, self).__init__(
in_channels=in_channels,
......@@ -356,7 +386,8 @@ class BottleneckWithFixedBatchNorm(Bottleneck):
stride_in_1x1=stride_in_1x1,
stride=stride,
dilation=dilation,
norm_func=FrozenBatchNorm2d
norm_func=FrozenBatchNorm2d,
dcn_config=dcn_config
)
......@@ -376,7 +407,8 @@ class BottleneckWithGN(Bottleneck):
num_groups=1,
stride_in_1x1=True,
stride=1,
dilation=1
dilation=1,
dcn_config={}
):
super(BottleneckWithGN, self).__init__(
in_channels=in_channels,
......@@ -386,7 +418,8 @@ class BottleneckWithGN(Bottleneck):
stride_in_1x1=stride_in_1x1,
stride=stride,
dilation=dilation,
norm_func=group_norm
norm_func=group_norm,
dcn_config=dcn_config
)
......
......@@ -143,6 +143,33 @@ def _load_c2_pickled_weights(file_path):
return weights
def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
import re
logger = logging.getLogger(__name__)
logger.info("Remapping conv weights for deformable conv weights")
layer_keys = sorted(state_dict.keys())
for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
if not stage_with_dcn:
continue
for old_key in layer_keys:
pattern = ".*layer{}.*conv2.*".format(ix)
r = re.match(pattern, old_key)
if r is None:
continue
for param in ["weight", "bias"]:
if old_key.find(param) == -1:
continue
new_key = old_key.replace(
"conv2.{}".format(param), "conv2.conv.{}".format(param)
)
logger.info("pattern: {}, old_key: {}, new_key: {}".format(
pattern, old_key, new_key
))
state_dict[new_key] = state_dict[old_key]
del state_dict[old_key]
return state_dict
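# Example with a hypothetical key: when STAGE_WITH_DCN = (False, True, True, True),
# a converted weight such as "layer2.0.conv2.weight" is remapped to
# "layer2.0.conv2.conv.weight", matching the inner conv of the DFConv2d wrapper.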
_C2_STAGE_NAMES = {
"R-50": ["1.2", "2.3", "3.5", "4.2"],
"R-101": ["1.2", "2.3", "3.22", "4.2"],
......@@ -168,6 +195,10 @@ def load_resnet_c2_format(cfg, f):
arch = arch.replace("-RETINANET", "")
stages = _C2_STAGE_NAMES[arch]
state_dict = _rename_weights_for_resnet(state_dict, stages)
# ***********************************
# for deformable convolutional layer
state_dict = _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg)
# ***********************************
return dict(model=state_dict)
......