Commit 1d6e9add authored by zimenglan, committed by Francisco Massa

add dcn from mmdetection (#693)

* make pixel indexes 0-based for bounding box in pascal voc dataset

* replacing all instances of torch.distributed.deprecated with torch.distributed

* add GroupNorm

* add GroupNorm -- sort out yaml files

* use torch.nn.GroupNorm instead, replace 'use_gn' with 'conv_block' and use 'BaseStem'&'Bottleneck' to simplify the code

* modification on 'group_norm' and 'conv_with_kaiming_uniform' function

* modification on yaml files in configs/gn_baselines/ and reduce the amount of indentation and code duplication

* use 'kaiming_uniform' to initialize resnet, disable gn after fc layer, and add dilation into ResNetHead

* agnostic-regression for bbox

* please set 'STRIDE_IN_1X1' to 'False' when the backbone uses GN

* add README.md for GN

* add dcn from mmdetection
parent 1714b7c2
@@ -28,3 +28,4 @@ dist/
# project dirs
/datasets
/models
/output
### Reference
1. [Deformable ConvNets v2: More Deformable, Better Results](https://arxiv.org/pdf/1811.11168.pdf)
2. Third-party: [mmdetection](https://github.com/open-mmlab/mmdetection/tree/master/configs/dcn)
### Performance
| case | bbox AP | mask AP |
|----------------------------:|--------:|:-------:|
| R-50-FPN-dcn (implement) | 39.8 | - |
| R-50-FPN-dcn (mmdetection) | 40.0 | - |
| R-50-FPN-mdcn (implement) | 40.0 | - |
| R-50-FPN-mdcn (mmdetection) | 40.3 | - |
| R-50-FPN-dcn (implement) | 40.8 | 36.8 |
| R-50-FPN-dcn (mmdetection) | 41.1 | 37.2 |
| R-50-FPN-dcn (implement) | 40.7 | 36.7 |
| R-50-FPN-dcn (mmdetection) | 41.4 | 37.4 |
### Note
See [dcn-v2](https://github.com/open-mmlab/mmdetection/blob/master/MODEL_ZOO.md#deformable-convolution-v2) in `mmdetection` for more details.
### Usage
Add these three keys under `MODEL.RESNETS` in the config file:
```
MODEL:
  RESNETS:
    # corresponding to C2,C3,C4,C5
    STAGE_WITH_DCN: (False, True, True, True)
    WITH_MODULATED_DCN: True
    DEFORMABLE_GROUPS: 1
```
\ No newline at end of file
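As a rough illustration of how the four `STAGE_WITH_DCN` flags line up with the backbone stages named in the README comment (C2 to C5), the sketch below pairs each flag with its stage. The helper `stages_with_dcn` and the `STAGE_NAMES` tuple are hypothetical names introduced only for this example; they are not part of maskrcnn-benchmark.

```python
# Hypothetical helper, not part of maskrcnn-benchmark: shows which ResNet
# stages a STAGE_WITH_DCN tuple would switch to deformable convolution.
STAGE_NAMES = ("C2", "C3", "C4", "C5")


def stages_with_dcn(stage_with_dcn):
    """Return the names of the stages whose flag is True."""
    if len(stage_with_dcn) != len(STAGE_NAMES):
        raise ValueError("expected one flag per stage (C2, C3, C4, C5)")
    return [name for name, flag in zip(STAGE_NAMES, stage_with_dcn) if flag]


if __name__ == "__main__":
    # The README setting: C2 stays a regular convolution stage.
    print(stages_with_dcn((False, True, True, True)))
```

With the default added in this commit, `(False, False, False, False)`, the helper returns an empty list, which corresponds to a backbone without any deformable stages.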
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
    STAGE_WITH_DCN: (False, True, True, True)
    WITH_MODULATED_DCN: False
    DEFORMABLE_GROUPS: 1
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 8 gpus
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 16
TEST:
  IMS_PER_BATCH: 8
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
    STAGE_WITH_DCN: (False, True, True, True)
    WITH_MODULATED_DCN: True
    DEFORMABLE_GROUPS: 1
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 8 gpus
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 16
TEST:
  IMS_PER_BATCH: 8
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
    STAGE_WITH_DCN: (False, True, True, True)
    WITH_MODULATED_DCN: False
    DEFORMABLE_GROUPS: 1
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  ROI_MASK_HEAD:
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
    PREDICTOR: "MaskRCNNC4Predictor"
    POOLER_RESOLUTION: 14
    POOLER_SAMPLING_RATIO: 2
    RESOLUTION: 28
    SHARE_BOX_FEATURE_EXTRACTOR: False
  MASK_ON: True
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 8 gpus
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 16
TEST:
  IMS_PER_BATCH: 8
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
  RESNETS:
    BACKBONE_OUT_CHANNELS: 256
    STAGE_WITH_DCN: (False, True, True, True)
    WITH_MODULATED_DCN: True
    DEFORMABLE_GROUPS: 1
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  ROI_MASK_HEAD:
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
    PREDICTOR: "MaskRCNNC4Predictor"
    POOLER_RESOLUTION: 14
    POOLER_SAMPLING_RATIO: 2
    RESOLUTION: 28
    SHARE_BOX_FEATURE_EXTRACTOR: False
  MASK_ON: True
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 8 gpus
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 16
TEST:
  IMS_PER_BATCH: 8
@@ -274,6 +274,10 @@ _C.MODEL.RESNETS.BACKBONE_OUT_CHANNELS = 256 * 4
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
_C.MODEL.RESNETS.STAGE_WITH_DCN = (False, False, False, False)
_C.MODEL.RESNETS.WITH_MODULATED_DCN = False
_C.MODEL.RESNETS.DEFORMABLE_GROUPS = 1
# ---------------------------------------------------------------------------- #
# RetinaNet Options (Follow the Detectron version)
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
// based on
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
#include <vector>
#include <iostream>
#include <cmath>
void DeformablePSROIPoolForward(
const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
at::Tensor out, at::Tensor top_count, const int batch, const int channels,
const int height, const int width, const int num_bbox,
const int channels_trans, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void DeformablePSROIPoolBackwardAcc(
const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
at::Tensor trans_grad, const int batch, const int channels,
const int height, const int width, const int num_bbox,
const int channels_trans, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void deform_psroi_pooling_cuda_forward(
at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
at::Tensor top_count, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
if (num_bbox != out.size(0))
// AT_ERROR concatenates its arguments; it does not expand printf-style %d.
AT_ERROR("Output shape and bbox number won't match: (",
out.size(0), " vs ", num_bbox, ").");
DeformablePSROIPoolForward(
input, bbox, trans, out, top_count, batch, channels, height, width,
num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
pooled_size, part_size, sample_per_part, trans_std);
}
void deform_psroi_pooling_cuda_backward(
at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
const int no_trans, const float spatial_scale, const int output_dim,
const int group_size, const int pooled_size, const int part_size,
const int sample_per_part, const float trans_std)
{
AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
if (num_bbox != out_grad.size(0))
// AT_ERROR concatenates its arguments; it does not expand printf-style %d.
AT_ERROR("Output shape and bbox number won't match: (",
out_grad.size(0), " vs ", num_bbox, ").");
DeformablePSROIPoolBackwardAcc(
out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
channels, height, width, num_bbox, channels_trans, no_trans,
spatial_scale, output_dim, group_size, pooled_size, part_size,
sample_per_part, trans_std);
}
@@ -58,6 +58,59 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step);
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step);
int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step);
void modulated_deform_conv_cuda_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias);
void modulated_deform_conv_cuda_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias);
void deform_psroi_pooling_cuda_forward(
at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
at::Tensor top_count, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void deform_psroi_pooling_cuda_backward(
at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
const int no_trans, const float spatial_scale, const int output_dim,
const int group_size, const int pooled_size, const int part_size,
const int sample_per_part, const float trans_std);
at::Tensor compute_flow_cuda(const at::Tensor& boxes,
const int height,
const int width);
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include "cpu/vision.h"
#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif
// Interface for Python
int deform_conv_forward(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor output,
at::Tensor columns,
at::Tensor ones,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
int im2col_step)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_forward_cuda(
input, weight, offset, output, columns, ones,
kW, kH, dW, dH, padW, padH, dilationW, dilationH,
group, deformable_group, im2col_step
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
int deform_conv_backward_input(
at::Tensor input,
at::Tensor offset,
at::Tensor gradOutput,
at::Tensor gradInput,
at::Tensor gradOffset,
at::Tensor weight,
at::Tensor columns,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
int im2col_step)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_backward_input_cuda(
input, offset, gradOutput, gradInput, gradOffset, weight, columns,
kW, kH, dW, dH, padW, padH, dilationW, dilationH,
group, deformable_group, im2col_step
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
int deform_conv_backward_parameters(
at::Tensor input,
at::Tensor offset,
at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns,
at::Tensor ones,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
float scale,
int im2col_step)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_backward_parameters_cuda(
input, offset, gradOutput, gradWeight, columns, ones,
kW, kH, dW, dH, padW, padH, dilationW, dilationH,
group, deformable_group, scale, im2col_step
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
void modulated_deform_conv_forward(
at::Tensor input,
at::Tensor weight,
at::Tensor bias,
at::Tensor ones,
at::Tensor offset,
at::Tensor mask,
at::Tensor output,
at::Tensor columns,
int kernel_h,
int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const int group,
const int deformable_group,
const bool with_bias)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return modulated_deform_conv_cuda_forward(
input, weight, bias, ones, offset, mask, output, columns,
kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w,
group, deformable_group, with_bias
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
void modulated_deform_conv_backward(
at::Tensor input,
at::Tensor weight,
at::Tensor bias,
at::Tensor ones,
at::Tensor offset,
at::Tensor mask,
at::Tensor columns,
at::Tensor grad_input,
at::Tensor grad_weight,
at::Tensor grad_bias,
at::Tensor grad_offset,
at::Tensor grad_mask,
at::Tensor grad_output,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dilation_h,
int dilation_w,
int group,
int deformable_group,
const bool with_bias)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return modulated_deform_conv_cuda_backward(
input, weight, bias, ones, offset, mask, columns,
grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
group, deformable_group, with_bias
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
\ No newline at end of file
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include "cpu/vision.h"
#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif
// Interface for Python
void deform_psroi_pooling_forward(
at::Tensor input,
at::Tensor bbox,
at::Tensor trans,
at::Tensor out,
at::Tensor top_count,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return deform_psroi_pooling_cuda_forward(
input, bbox, trans, out, top_count,
no_trans, spatial_scale, output_dim, group_size,
pooled_size, part_size, sample_per_part, trans_std
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
void deform_psroi_pooling_backward(
at::Tensor out_grad,
at::Tensor input,
at::Tensor bbox,
at::Tensor trans,
at::Tensor top_count,
at::Tensor input_grad,
at::Tensor trans_grad,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return deform_psroi_pooling_cuda_backward(
out_grad, input, bbox, trans, top_count, input_grad, trans_grad,
no_trans, spatial_scale, output_dim, group_size, pooled_size,
part_size, sample_per_part, trans_std
);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
@@ -3,6 +3,8 @@
#include "ROIAlign.h"
#include "ROIPool.h"
#include "SigmoidFocalLoss.h"
#include "deform_conv.h"
#include "deform_pool.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms", &nms, "non-maximum suppression");
@@ -12,4 +14,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
m.def("sigmoid_focalloss_forward", &SigmoidFocalLoss_forward, "SigmoidFocalLoss_forward");
m.def("sigmoid_focalloss_backward", &SigmoidFocalLoss_backward, "SigmoidFocalLoss_backward");
// dcn-v2
m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward");
m.def("deform_conv_backward_input", &deform_conv_backward_input, "deform_conv_backward_input");
m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters, "deform_conv_backward_parameters");
m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward, "modulated_deform_conv_forward");
m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward, "modulated_deform_conv_backward");
m.def("deform_psroi_pooling_forward", &deform_psroi_pooling_forward, "deform_psroi_pooling_forward");
m.def("deform_psroi_pooling_backward", &deform_psroi_pooling_backward, "deform_psroi_pooling_backward");
}
\ No newline at end of file
@@ -3,6 +3,7 @@ import torch
from .batch_norm import FrozenBatchNorm2d
from .misc import Conv2d
from .misc import DFConv2d
from .misc import ConvTranspose2d
from .misc import BatchNorm2d
from .misc import interpolate
@@ -13,9 +14,34 @@ from .roi_pool import ROIPool
from .roi_pool import roi_pool
from .smooth_l1_loss import smooth_l1_loss
from .sigmoid_focal_loss import SigmoidFocalLoss
from .dcn.deform_conv_func import deform_conv, modulated_deform_conv
from .dcn.deform_conv_module import DeformConv, ModulatedDeformConv, ModulatedDeformConvPack
from .dcn.deform_pool_func import deform_roi_pooling
from .dcn.deform_pool_module import DeformRoIPooling, DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack
__all__ = [
    "nms",
    "roi_align",
    "ROIAlign",
    "roi_pool",
    "ROIPool",
    "smooth_l1_loss",
    "Conv2d",
    "DFConv2d",
    "ConvTranspose2d",
    "interpolate",
    "BatchNorm2d",
    "FrozenBatchNorm2d",
    "SigmoidFocalLoss",
    'deform_conv',
    'modulated_deform_conv',
    'DeformConv',
    'ModulatedDeformConv',
    'ModulatedDeformConvPack',
    'deform_roi_pooling',
    'DeformRoIPooling',
    'DeformRoIPoolingPack',
    'ModulatedDeformRoIPoolingPack',
]
#
# Copied From [mmdetection](https://github.com/open-mmlab/mmdetection/tree/master/mmdet/ops/dcn)
#
\ No newline at end of file
import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from maskrcnn_benchmark import _C
class DeformConvFunction(Function):
@staticmethod
def forward(
ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=64
):
if input is not None and input.dim() != 4:
raise ValueError(
"Expected 4D tensor as input, got {}D tensor instead.".format(
input.dim()))
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(
DeformConvFunction._output_size(input, weight, ctx.padding,
ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
_C.deform_conv_forward(
input,
weight,
offset,
output,
ctx.bufs_[0],
ctx.bufs_[1],
weight.size(3),
weight.size(2),
ctx.stride[1],
ctx.stride[0],
ctx.padding[1],
ctx.padding[0],
ctx.dilation[1],
ctx.dilation[0],
ctx.groups,
ctx.deformable_groups,
cur_im2col_step
)
return output
@staticmethod
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
_C.deform_conv_backward_input(
input,
offset,
grad_output,
grad_input,
grad_offset,
weight,
ctx.bufs_[0],
weight.size(3),
weight.size(2),
ctx.stride[1],
ctx.stride[0],
ctx.padding[1],
ctx.padding[0],
ctx.dilation[1],
ctx.dilation[0],
ctx.groups,
ctx.deformable_groups,
cur_im2col_step
)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
_C.deform_conv_backward_parameters(
input,
offset,
grad_output,
grad_weight,
ctx.bufs_[0],
ctx.bufs_[1],
weight.size(3),
weight.size(2),
ctx.stride[1],
ctx.stride[0],
ctx.padding[1],
ctx.padding[0],
ctx.dilation[1],
ctx.dilation[0],
ctx.groups,
ctx.deformable_groups,
1,
cur_im2col_step
)
# one slot per forward() argument after ctx (9 in total, including im2col_step)
return (grad_input, grad_offset, grad_weight, None, None, None, None, None, None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(
"convolution input is too small (output would be {})".format(
'x'.join(map(str, output_size))))
return output_size
class ModulatedDeformConvFunction(Function):
@staticmethod
def forward(
ctx,
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1
):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
_C.modulated_deform_conv_forward(
input,
weight,
bias,
ctx._bufs[0],
offset,
mask,
output,
ctx._bufs[1],
weight.shape[2],
weight.shape[3],
ctx.stride,
ctx.stride,
ctx.padding,
ctx.padding,
ctx.dilation,
ctx.dilation,
ctx.groups,
ctx.deformable_groups,
ctx.with_bias
)
return output
@staticmethod
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
_C.modulated_deform_conv_backward(
input,
weight,
bias,
ctx._bufs[0],
offset,
mask,
ctx._bufs[1],
grad_input,
grad_weight,
grad_bias,
grad_offset,
grad_mask,
grad_output,
weight.shape[2],
weight.shape[3],
ctx.stride,
ctx.stride,
ctx.padding,
ctx.padding,
ctx.dilation,
ctx.dilation,
ctx.groups,
ctx.deformable_groups,
ctx.with_bias
)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding -
(ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding -
(ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
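The shape arithmetic in `_output_size` and the `im2col_step` check in `forward` can be tried out without a GPU. The sketch below restates both in plain Python for a single spatial dimension; `conv_output_size` and `effective_im2col_step` are hypothetical helper names used only for illustration and do not exist in the module above.

```python
# Plain-Python restatement of two checks from DeformConvFunction.
# Both helper names are hypothetical and only serve this example.


def conv_output_size(in_size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension, as in `_output_size`."""
    effective_kernel = dilation * (kernel - 1) + 1
    out = (in_size + 2 * padding - effective_kernel) // stride + 1
    if out <= 0:
        raise ValueError("convolution input is too small")
    return out


def effective_im2col_step(im2col_step, batch_size):
    """Clamp the step to the batch size and require it to divide the batch."""
    step = min(im2col_step, batch_size)
    if batch_size % step != 0:
        raise ValueError("im2col step must divide batchsize")
    return step


if __name__ == "__main__":
    # 3x3 kernel, stride 2, padding 1 on an 800-pixel side
    print(conv_output_size(800, 3, stride=2, padding=1))
    # default step of 64 with a batch of 2 images
    print(effective_im2col_step(64, 2))
```

A batch of 96 with the default step of 64 is rejected by this check, because 64 does not divide 96.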
import math
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
from .deform_conv_func import deform_conv, modulated_deform_conv
class DeformConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False
):
assert not bias
super(DeformConv, self).__init__()
self.with_bias = bias
assert in_channels % groups == 0, \
'in_channels {} is not divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} is not divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups,
*self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation, self.groups,
self.deformable_groups)
def __repr__(self):
return "".join([
"{}(".format(self.__class__.__name__),
"in_channels={}, ".format(self.in_channels),
"out_channels={}, ".format(self.out_channels),
"kernel_size={}, ".format(self.kernel_size),
"stride={}, ".format(self.stride),
"dilation={}, ".format(self.dilation),
"padding={}, ".format(self.padding),
"groups={}, ".format(self.groups),
"deformable_groups={}, ".format(self.deformable_groups),
"bias={})".format(self.with_bias),
])
class ModulatedDeformConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True
):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
self.weight = nn.Parameter(torch.Tensor(
out_channels,
in_channels // groups,
*self.kernel_size
))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, input, offset, mask):
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
def __repr__(self):
return "".join([
"{}(".format(self.__class__.__name__),
"in_channels={}, ".format(self.in_channels),
"out_channels={}, ".format(self.out_channels),
"kernel_size={}, ".format(self.kernel_size),
"stride={}, ".format(self.stride),
"dilation={}, ".format(self.dilation),
"padding={}, ".format(self.padding),
"groups={}, ".format(self.groups),
"deformable_groups={}, ".format(self.deformable_groups),
"bias={})".format(self.with_bias),
])
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConvPack, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, deformable_groups, bias)
self.conv_offset_mask = nn.Conv2d(
self.in_channels,  # conv_offset_mask is applied to the full input, not to one group
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, input):
out = self.conv_offset_mask(input)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
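
The channel bookkeeping used by `conv_offset_mask` and `torch.chunk` above can be checked without a GPU. The sketch below is plain Python and only mirrors the arithmetic; `split_offset_mask_channels` is a hypothetical helper name that does not exist in the repository.

```python
def split_offset_mask_channels(kernel_size, deformable_groups=1):
    """Return (offset_channels, mask_channels) for a modulated deformable conv.

    Hypothetical helper: mirrors chunking the conv_offset_mask output into
    three equal parts (o1, o2, mask) along the channel axis.
    """
    kh, kw = kernel_size
    total = deformable_groups * 3 * kh * kw  # channels produced by conv_offset_mask
    chunk = total // 3                       # size of each of o1, o2 and mask
    return 2 * chunk, chunk                  # offset = cat(o1, o2), mask = last chunk


print(split_offset_mask_channels((3, 3)))  # (18, 9)
```

For a 3x3 kernel with one deformable group this gives 18 offset channels (an x and a y displacement per kernel tap) and 9 mask channels (one scalar per tap).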
import torch
from torch.autograd import Function

from maskrcnn_benchmark import _C


class DeformRoIPoolingFunction(Function):

    @staticmethod
    def forward(
        ctx,
        data,
        rois,
        offset,
        spatial_scale,
        out_size,
        out_channels,
        no_trans,
        group_size=1,
        part_size=None,
        sample_per_part=4,
        trans_std=.0
    ):
        ctx.spatial_scale = spatial_scale
        ctx.out_size = out_size
        ctx.out_channels = out_channels
        ctx.no_trans = no_trans
        ctx.group_size = group_size
        ctx.part_size = out_size if part_size is None else part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std

        assert 0.0 <= ctx.trans_std <= 1.0
        if not data.is_cuda:
            raise NotImplementedError

        n = rois.shape[0]
        output = data.new_empty(n, out_channels, out_size, out_size)
        output_count = data.new_empty(n, out_channels, out_size, out_size)
        _C.deform_psroi_pooling_forward(
            data,
            rois,
            offset,
            output,
            output_count,
            ctx.no_trans,
            ctx.spatial_scale,
            ctx.out_channels,
            ctx.group_size,
            ctx.out_size,
            ctx.part_size,
            ctx.sample_per_part,
            ctx.trans_std
        )

        if data.requires_grad or rois.requires_grad or offset.requires_grad:
            ctx.save_for_backward(data, rois, offset)
        ctx.output_count = output_count

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError

        data, rois, offset = ctx.saved_tensors
        output_count = ctx.output_count
        grad_input = torch.zeros_like(data)
        grad_rois = None
        grad_offset = torch.zeros_like(offset)

        _C.deform_psroi_pooling_backward(
            grad_output,
            data,
            rois,
            offset,
            output_count,
            grad_input,
            grad_offset,
            ctx.no_trans,
            ctx.spatial_scale,
            ctx.out_channels,
            ctx.group_size,
            ctx.out_size,
            ctx.part_size,
            ctx.sample_per_part,
            ctx.trans_std
        )
        # One entry per forward() argument after ctx; non-tensor arguments
        # and rois receive None.
        return (grad_input, grad_rois, grad_offset, None, None, None, None, None, None, None, None)


deform_roi_pooling = DeformRoIPoolingFunction.apply
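
`backward` above returns eleven values because a custom autograd `Function` has to return one gradient (or `None`) per argument of `forward`, not counting `ctx`. The following sketch checks that count with the standard library only; `forward_signature`, `expected_num_grads` and `pad_grads` are hypothetical names used for illustration, and the stub merely copies the argument list of `DeformRoIPoolingFunction.forward`.

```python
import inspect


def forward_signature(ctx, data, rois, offset, spatial_scale, out_size,
                      out_channels, no_trans, group_size=1, part_size=None,
                      sample_per_part=4, trans_std=.0):
    """Stub with the same argument list as DeformRoIPoolingFunction.forward."""


def expected_num_grads(forward_fn):
    # Every forward argument except ctx needs a slot in backward's return value.
    return len(inspect.signature(forward_fn).parameters) - 1


def pad_grads(tensor_grads, forward_fn):
    # Fill the slots of non-differentiable arguments with None.
    missing = expected_num_grads(forward_fn) - len(tensor_grads)
    return tuple(tensor_grads) + (None,) * missing


grads = pad_grads(("grad_input", None, "grad_offset"), forward_signature)
print(len(grads))  # 11
```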
from torch import nn

from .deform_pool_func import deform_roi_pooling


class DeformRoIPooling(nn.Module):

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
        super(DeformRoIPooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.out_size = out_size
        self.out_channels = out_channels
        self.no_trans = no_trans
        self.group_size = group_size
        self.part_size = out_size if part_size is None else part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, data, rois, offset):
        if self.no_trans:
            offset = data.new_empty(0)
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.out_channels, self.no_trans, self.group_size, self.part_size,
            self.sample_per_part, self.trans_std)


class DeformRoIPoolingPack(DeformRoIPooling):

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 deform_fc_channels=1024):
        super(DeformRoIPoolingPack,
              self).__init__(spatial_scale, out_size, out_channels, no_trans,
                             group_size, part_size, sample_per_part, trans_std)

        self.deform_fc_channels = deform_fc_channels

        if not no_trans:
            self.offset_fc = nn.Sequential(
                nn.Linear(self.out_size * self.out_size * self.out_channels,
                          self.deform_fc_channels),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_channels,
                          self.out_size * self.out_size * 2))
            # Zero init so that pooling starts without any offset.
            self.offset_fc[-1].weight.data.zero_()
            self.offset_fc[-1].bias.data.zero_()

    def forward(self, data, rois):
        assert data.size(1) == self.out_channels
        if self.no_trans:
            offset = data.new_empty(0)
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.out_channels, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)
        else:
            n = rois.shape[0]
            # First pass without offsets (no_trans=True) to get the features
            # from which the offsets are predicted.
            offset = data.new_empty(0)
            x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
                                   self.out_size, self.out_channels, True,
                                   self.group_size, self.part_size,
                                   self.sample_per_part, self.trans_std)
            offset = self.offset_fc(x.view(n, -1))
            offset = offset.view(n, 2, self.out_size, self.out_size)
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.out_channels, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)

class ModulatedDeformRoIPoolingPack(DeformRoIPooling):

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 deform_fc_channels=1024):
        super(ModulatedDeformRoIPoolingPack, self).__init__(
            spatial_scale, out_size, out_channels, no_trans, group_size,
            part_size, sample_per_part, trans_std)

        self.deform_fc_channels = deform_fc_channels

        if not no_trans:
            self.offset_fc = nn.Sequential(
                nn.Linear(self.out_size * self.out_size * self.out_channels,
                          self.deform_fc_channels),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_channels,
                          self.out_size * self.out_size * 2))
            self.offset_fc[-1].weight.data.zero_()
            self.offset_fc[-1].bias.data.zero_()
            self.mask_fc = nn.Sequential(
                nn.Linear(self.out_size * self.out_size * self.out_channels,
                          self.deform_fc_channels),
                nn.ReLU(inplace=True),
                nn.Linear(self.deform_fc_channels,
                          self.out_size * self.out_size * 1),
                nn.Sigmoid())
            # Index 2 is the last Linear layer (index 3 is the Sigmoid).
            self.mask_fc[2].weight.data.zero_()
            self.mask_fc[2].bias.data.zero_()

    def forward(self, data, rois):
        assert data.size(1) == self.out_channels
        if self.no_trans:
            offset = data.new_empty(0)
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.out_channels, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)
        else:
            n = rois.shape[0]
            offset = data.new_empty(0)
            x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
                                   self.out_size, self.out_channels, True,
                                   self.group_size, self.part_size,
                                   self.sample_per_part, self.trans_std)
            offset = self.offset_fc(x.view(n, -1))
            offset = offset.view(n, 2, self.out_size, self.out_size)
            mask = self.mask_fc(x.view(n, -1))
            mask = mask.view(n, 1, self.out_size, self.out_size)
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.out_channels, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std) * mask
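
The layer sizes of `offset_fc` and the effect of zero-initializing the last `Linear` layers can be worked out by hand. The sketch below does so in plain Python; `offset_fc_sizes` and `sigmoid` are hypothetical helpers, and the values `out_size=7`, `out_channels=256` are assumed for illustration only, not taken from a config in this commit.

```python
import math


def offset_fc_sizes(out_size, out_channels, deform_fc_channels=1024):
    """(in_features, out_features) of the three Linear layers in offset_fc."""
    in_features = out_size * out_size * out_channels
    return [
        (in_features, deform_fc_channels),
        (deform_fc_channels, deform_fc_channels),
        (deform_fc_channels, out_size * out_size * 2),  # (x, y) per output bin
    ]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


# Assumed example values: a 7x7 pooled output with 256 channels.
print(offset_fc_sizes(7, 256))

# A zero-initialized last Linear layer outputs 0 for any input, so the
# predicted offsets start at 0 and the mask starts at sigmoid(0).
print(sigmoid(0.0))  # 0.5
```

Under these assumptions every pooled value is initially scaled by 0.5 and no bin is shifted, so training starts from a rescaled version of the plain pooling result.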
@@ -11,6 +11,7 @@ is implemented
import math
import torch
from torch import nn
from torch.nn.modules.utils import _ntuple
@@ -108,3 +109,86 @@ def interpolate(
    output_shape = tuple(_output_size(2))
    output_shape = input.shape[:-2] + output_shape
    return _NewEmptyTensorOp.apply(input, output_shape)


class DFConv2d(nn.Module):
    """Deformable convolutional layer"""
    def __init__(
        self,
        in_channels,
        out_channels,
        with_modulated_dcn=True,
        kernel_size=3,
        stride=1,
        groups=1,
        dilation=1,
        deformable_groups=1,
        bias=False
    ):
        super(DFConv2d, self).__init__()
        if isinstance(kernel_size, (list, tuple)):
            assert len(kernel_size) == 2
            offset_base_channels = kernel_size[0] * kernel_size[1]
        else:
            offset_base_channels = kernel_size * kernel_size
        if with_modulated_dcn:
            from maskrcnn_benchmark.layers import ModulatedDeformConv
            offset_channels = offset_base_channels * 3  # default: 27
            conv_block = ModulatedDeformConv
        else:
            from maskrcnn_benchmark.layers import DeformConv
            offset_channels = offset_base_channels * 2  # default: 18
            conv_block = DeformConv
        # Plain convolution that predicts the offsets (and mask, if modulated).
        self.offset = Conv2d(
            in_channels,
            deformable_groups * offset_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation,
            groups=1,
            dilation=dilation
        )
        for l in [self.offset,]:
            nn.init.kaiming_uniform_(l.weight, a=1)
            torch.nn.init.constant_(l.bias, 0.)
        self.conv = conv_block(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            groups=groups,
            deformable_groups=deformable_groups,
            bias=bias
        )
        self.with_modulated_dcn = with_modulated_dcn
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = dilation
        self.dilation = dilation

    def forward(self, x):
        if x.numel() > 0:
            if not self.with_modulated_dcn:
                offset = self.offset(x)
                x = self.conv(x, offset)
            else:
                offset_mask = self.offset(x)
                # The first two thirds of the channels are offsets and the
                # last third is the mask (18 + 9 for a 3x3 kernel with one
                # deformable group), so derive the split from the tensor
                # instead of hard-coding it.
                split = offset_mask.shape[1] * 2 // 3
                offset = offset_mask[:, :split, :, :]
                mask = offset_mask[:, split:, :, :].sigmoid()
                x = self.conv(x, offset, mask)
            return x
        # get output shape; the hyper-parameters may be stored as plain ints,
        # so expand them to (h, w) pairs before zipping
        to_pair = _ntuple(2)
        output_shape = [
            (i + 2 * p - (di * (k - 1) + 1)) // d + 1
            for i, p, di, k, d in zip(
                x.shape[-2:],
                to_pair(self.padding),
                to_pair(self.dilation),
                to_pair(self.kernel_size),
                to_pair(self.stride)
            )
        ]
        output_shape = [x.shape[0], self.conv.weight.shape[0]] + output_shape
        return _NewEmptyTensorOp.apply(x, output_shape)
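
The empty-input branch of `DFConv2d.forward` relies on the usual convolution output-size formula. A minimal standalone version for one spatial dimension is sketched below; `conv_output_size` is a hypothetical helper and is not part of `maskrcnn_benchmark`.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a convolution along one spatial dimension."""
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# 3x3 kernel with padding equal to dilation, as DFConv2d configures it:
print(conv_output_size(32, 3, stride=1, padding=1, dilation=1))  # 32
print(conv_output_size(32, 3, stride=2, padding=1, dilation=1))  # 16
print(conv_output_size(32, 3, stride=1, padding=2, dilation=2))  # 32
```

Because `DFConv2d` sets `padding=dilation`, a 3x3 kernel with stride 1 keeps the spatial size unchanged for any dilation.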
@@ -24,6 +24,7 @@ from torch import nn
from maskrcnn_benchmark.layers import FrozenBatchNorm2d
from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.layers import DFConv2d
from maskrcnn_benchmark.modeling.make_layers import group_norm
from maskrcnn_benchmark.utils.registry import Registry
@@ -106,6 +107,7 @@ class ResNet(nn.Module):
            stage2_relative_factor = 2 ** (stage_spec.index - 1)
            bottleneck_channels = stage2_bottleneck_channels * stage2_relative_factor
            out_channels = stage2_out_channels * stage2_relative_factor
            stage_with_dcn = cfg.MODEL.RESNETS.STAGE_WITH_DCN[stage_spec.index - 1]
            module = _make_stage(
                transformation_module,
                in_channels,
@@ -115,6 +117,11 @@ class ResNet(nn.Module):
                num_groups,
                cfg.MODEL.RESNETS.STRIDE_IN_1X1,
                first_stride=int(stage_spec.index > 1) + 1,
                dcn_config={
                    "stage_with_dcn": stage_with_dcn,
                    "with_modulated_dcn": cfg.MODEL.RESNETS.WITH_MODULATED_DCN,
                    "deformable_groups": cfg.MODEL.RESNETS.DEFORMABLE_GROUPS,
                }
            )
            in_channels = out_channels
            self.add_module(name, module)
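
How the `STAGE_WITH_DCN` tuple is turned into one `dcn_config` dict per stage can be sketched without the config system. In the sketch below, `build_dcn_configs` is a hypothetical helper, and the stage indices 1 to 4 are assumed to correspond to C2 to C5 as stated in the usage note of this commit.

```python
def build_dcn_configs(stage_with_dcn, with_modulated_dcn=False,
                      deformable_groups=1, stage_indices=(1, 2, 3, 4)):
    """Return one dcn_config dict per stage, keyed like the dict in ResNet."""
    return [
        {
            # stage indices are 1-based, the config tuple is 0-based
            "stage_with_dcn": stage_with_dcn[index - 1],
            "with_modulated_dcn": with_modulated_dcn,
            "deformable_groups": deformable_groups,
        }
        for index in stage_indices
    ]


configs = build_dcn_configs((False, True, True, True), with_modulated_dcn=True)
print([c["stage_with_dcn"] for c in configs])  # [False, True, True, True]
```

Each `Bottleneck` then reads its dict with `dcn_config.get(...)`, so a stage with `stage_with_dcn=False` keeps the ordinary `Conv2d` for `conv2`.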
@@ -155,7 +162,8 @@ class ResNetHead(nn.Module):
        stride_in_1x1=True,
        stride_init=None,
        res2_out_channels=256,
        dilation=1,
        dcn_config={}
    ):
        super(ResNetHead, self).__init__()
@@ -182,7 +190,8 @@ class ResNetHead(nn.Module):
                num_groups,
                stride_in_1x1,
                first_stride=stride,
                dilation=dilation,
                dcn_config=dcn_config
            )
            stride = None
            self.add_module(name, module)
@@ -204,7 +213,8 @@ def _make_stage(
    num_groups,
    stride_in_1x1,
    first_stride,
    dilation=1,
    dcn_config={}
):
    blocks = []
    stride = first_stride
@@ -217,7 +227,8 @@ def _make_stage(
                num_groups,
                stride_in_1x1,
                stride,
                dilation=dilation,
                dcn_config=dcn_config
            )
        )
        stride = 1
@@ -235,7 +246,8 @@ class Bottleneck(nn.Module):
        stride_in_1x1,
        stride,
        dilation,
        norm_func,
        dcn_config
    ):
        super(Bottleneck, self).__init__()
@@ -271,7 +283,22 @@ class Bottleneck(nn.Module):
        )
        self.bn1 = norm_func(bottleneck_channels)
        # TODO: specify init for the above
        with_dcn = dcn_config.get("stage_with_dcn", False)
        if with_dcn:
            deformable_groups = dcn_config.get("deformable_groups", 1)
            with_modulated_dcn = dcn_config.get("with_modulated_dcn", False)
            self.conv2 = DFConv2d(
                bottleneck_channels,
                bottleneck_channels,
                with_modulated_dcn=with_modulated_dcn,
                kernel_size=3,
                stride=stride_3x3,
                groups=num_groups,
                dilation=dilation,
                deformable_groups=deformable_groups,
                bias=False
            )
        else:
            self.conv2 = Conv2d(
                bottleneck_channels,
                bottleneck_channels,
@@ -282,6 +309,8 @@ class Bottleneck(nn.Module):
                groups=num_groups,
                dilation=dilation
            )
            nn.init.kaiming_uniform_(self.conv2.weight, a=1)
        self.bn2 = norm_func(bottleneck_channels)
        self.conv3 = Conv2d(
@@ -289,7 +318,7 @@ class Bottleneck(nn.Module):
        )
        self.bn3 = norm_func(out_channels)
        for l in [self.conv1, self.conv3,]:
            nn.init.kaiming_uniform_(l.weight, a=1)
    def forward(self, x):
@@ -346,7 +375,8 @@ class BottleneckWithFixedBatchNorm(Bottleneck):
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config={}
    ):
        super(BottleneckWithFixedBatchNorm, self).__init__(
            in_channels=in_channels,
@@ -356,7 +386,8 @@ class BottleneckWithFixedBatchNorm(Bottleneck):
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
            norm_func=FrozenBatchNorm2d,
            dcn_config=dcn_config
        )
@@ -376,7 +407,8 @@ class BottleneckWithGN(Bottleneck):
        num_groups=1,
        stride_in_1x1=True,
        stride=1,
        dilation=1,
        dcn_config={}
    ):
        super(BottleneckWithGN, self).__init__(
            in_channels=in_channels,
@@ -386,7 +418,8 @@ class BottleneckWithGN(Bottleneck):
            stride_in_1x1=stride_in_1x1,
            stride=stride,
            dilation=dilation,
            norm_func=group_norm,
            dcn_config=dcn_config
        )
......
@@ -143,6 +143,33 @@ def _load_c2_pickled_weights(file_path):
    return weights


def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
    import re
    logger = logging.getLogger(__name__)
    logger.info("Remapping conv weights for deformable conv weights")
    layer_keys = sorted(state_dict.keys())
    for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
        if not stage_with_dcn:
            continue
        for old_key in layer_keys:
            pattern = ".*layer{}.*conv2.*".format(ix)
            r = re.match(pattern, old_key)
            if r is None:
                continue
            for param in ["weight", "bias"]:
                # Compare by value: `is -1` tests identity, not equality.
                if old_key.find(param) == -1:
                    continue
                # DFConv2d keeps the actual convolution in its `conv` submodule.
                new_key = old_key.replace(
                    "conv2.{}".format(param), "conv2.conv.{}".format(param)
                )
                logger.info("pattern: {}, old_key: {}, new_key: {}".format(
                    pattern, old_key, new_key
                ))
                state_dict[new_key] = state_dict[old_key]
                del state_dict[old_key]
    return state_dict
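
The effect of this remapping can be tried on a toy dictionary without loading any weights. The sketch below reimplements the same key rewrite with a plain tuple in place of `cfg`; `rename_dcn_conv_keys` and the example key names are hypothetical and only chosen to match the `layer{N}...conv2` pattern used above.

```python
import re


def rename_dcn_conv_keys(state_dict, stage_with_dcn):
    """Move `conv2.{weight,bias}` to `conv2.conv.{weight,bias}` in DCN stages."""
    for ix, with_dcn in enumerate(stage_with_dcn, 1):
        if not with_dcn:
            continue
        pattern = re.compile(".*layer{}.*conv2.*".format(ix))
        # Iterate over a snapshot of the keys because the dict is modified.
        for old_key in sorted(state_dict.keys()):
            if pattern.match(old_key) is None:
                continue
            for param in ("weight", "bias"):
                if param not in old_key:
                    continue
                new_key = old_key.replace(
                    "conv2.{}".format(param), "conv2.conv.{}".format(param)
                )
                state_dict[new_key] = state_dict.pop(old_key)
    return state_dict


toy = {
    "layer1.0.conv2.weight": "w1",  # stage without DCN: key is kept
    "layer2.0.conv2.weight": "w2",  # stage with DCN: key is rewritten
    "layer2.0.conv1.weight": "w3",  # not conv2: key is kept
}
renamed = rename_dcn_conv_keys(toy, (False, True, True, True))
print(sorted(renamed))
```

Only the `conv2` entry of the DCN-enabled stage moves, which matches where `DFConv2d` stores its wrapped convolution (`self.conv`).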
_C2_STAGE_NAMES = {
    "R-50": ["1.2", "2.3", "3.5", "4.2"],
    "R-101": ["1.2", "2.3", "3.22", "4.2"],
@@ -168,6 +195,10 @@ def load_resnet_c2_format(cfg, f):
    arch = arch.replace("-RETINANET", "")
    stages = _C2_STAGE_NAMES[arch]
    state_dict = _rename_weights_for_resnet(state_dict, stages)
    # ***********************************
    # for deformable convolutional layer
    state_dict = _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg)
    # ***********************************
    return dict(model=state_dict)
......