Commit 6b1ab017 authored by Cheng-Yang Fu, committed by Francisco Massa

Add RetinaNet Implementation (#102)

* Add RetinaNet parameters in cfg.

* Hot fix.

* Add the RetinaNet head module.

* Add the function to generate the anchors for RetinaNet.

* Add the SigmoidFocalLoss cuda operator.

* Fix the bug in the extra layers.

* Change the normalizer for SigmoidFocalLoss

* Support multiscale in training.

* Add RetinaNet training script.

* Add the inference part of RetinaNet.

* Fix the bug when building the extra layers in retinanet.
Update the matching part in retinanet_loss.

* Add the first version of the inference of RetinaNet.
Need to check it again to see whether there is any room for speed
improvement.

* Remove the retinanet_R-50-FPN_2x.yaml first.

* Optimize the retinanet postprocessing.

* Quick fix.

* Add script for training RetinaNet with ResNet101 backbone.

* Move cfg.RETINANET to cfg.MODEL.RETINANET

* Remove unused variables.

* Revert boxlist_ops.
Generate empty BoxLists instead of [] in retinanet_infer.

* Remove the unused commented lines.
Add NUM_DETECTIONS_PER_IMAGE.

* Remove unused code.

* Move RetinaNet-related files under modeling/rpn/retinanet.

* Add retinanet_X_101_32x8d_FPN_1x.yaml script.
This model is not fully validated. I only trained it for around 5000
iterations and everything looks fine.

* Set RETINANET.PRE_NMS_TOP_N to 0 in level 5 (P7), because the previous setting may generate zero detections and could cause
the program to break.
This follows the original Detectron setting.

* Fix the RPN-only bug when training ends.

* Minor improvements

* Comments and add Python-only implementation

* Bugfix and remove commented code

* Keep generalized_rcnn the same.
Move build_retinanet inside build_rpn.

* Add USE_C5 in the MODEL.RETINANET

* Add two configs using P5 to generate P6.

* Fix the bug when loading the Caffe2 ImageNet pretrained model.

* Reduce code duplication between the RPN loss and the RetinaNet loss.

* Remove the unused comment.

* Remove the hard coded number of classes.

* Share the forward part of RPN inference.

* Fix the bug in RPN inference.

* Remove the conditional part in the inference.

* Bug fix: add the utils file for permute and flatten of the box
prediction layers.

* Update the comment.

* Quick fix: add the import of cat.

* Quick fix: forgot to include an import.

* Adjust the normalization part according to Detectron's setting.

* Use the bbox reg normalization term.

* Clean the code according to recent review.

* Use the CUDA version for training now, and the Python version for training
on CPU.

* Rename the directory to retinanet.

* Make the train and val datasets consistent with the Mask R-CNN setting.

* Add comment.
parent 595694cb
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-101"
RPN_ONLY: True
RETINANET_ON: True
BACKBONE:
CONV_BODY: "R-101-FPN-RETINANET"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
RETINANET:
SCALES_PER_OCTAVE: 3
STRADDLE_THRESH: -1
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: (800, )
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 4 gpus
BASE_LR: 0.005
WEIGHT_DECAY: 0.0001
STEPS: (120000, 160000)
MAX_ITER: 180000
IMS_PER_BATCH: 8
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-101"
RPN_ONLY: True
RETINANET_ON: True
BACKBONE:
CONV_BODY: "R-101-FPN-RETINANET"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
RETINANET:
SCALES_PER_OCTAVE: 3
STRADDLE_THRESH: -1
USE_C5: False
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: (800, )
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 4 gpus
BASE_LR: 0.005
WEIGHT_DECAY: 0.0001
STEPS: (120000, 160000)
MAX_ITER: 180000
IMS_PER_BATCH: 8
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
RPN_ONLY: True
RETINANET_ON: True
BACKBONE:
CONV_BODY: "R-50-FPN-RETINANET"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
RETINANET:
SCALES_PER_OCTAVE: 3
STRADDLE_THRESH: -1
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: (800,)
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 4 gpus
BASE_LR: 0.005
WEIGHT_DECAY: 0.0001
STEPS: (120000, 160000)
MAX_ITER: 180000
IMS_PER_BATCH: 8
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
RPN_ONLY: True
RETINANET_ON: True
BACKBONE:
CONV_BODY: "R-50-FPN-RETINANET"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
RETINANET:
SCALES_PER_OCTAVE: 3
STRADDLE_THRESH: -1
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
DATASETS:
TRAIN: ("coco_2014_minival",)
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: (600,)
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1000
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
BASE_LR: 0.005
WEIGHT_DECAY: 0.0001
STEPS: (3500,)
MAX_ITER: 4000
IMS_PER_BATCH: 4
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
RPN_ONLY: True
RETINANET_ON: True
BACKBONE:
CONV_BODY: "R-50-FPN-RETINANET"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
RETINANET:
SCALES_PER_OCTAVE: 3
STRADDLE_THRESH: -1
USE_C5: False
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: (800,)
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 4 gpus
BASE_LR: 0.005
WEIGHT_DECAY: 0.0001
STEPS: (120000, 160000)
MAX_ITER: 180000
IMS_PER_BATCH: 8
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/FAIR/20171220/X-101-32x8d"
RPN_ONLY: True
RETINANET_ON: True
BACKBONE:
CONV_BODY: "R-101-FPN-RETINANET"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
RESNETS:
STRIDE_IN_1X1: False
NUM_GROUPS: 32
WIDTH_PER_GROUP: 8
RETINANET:
SCALES_PER_OCTAVE: 3
STRADDLE_THRESH: -1
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_TRAIN: (800, )
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 4 gpus
BASE_LR: 0.0025
WEIGHT_DECAY: 0.0001
STEPS: (240000, 320000)
MAX_ITER: 360000
IMS_PER_BATCH: 4
@@ -23,6 +23,7 @@ _C = CN()
_C.MODEL = CN()
_C.MODEL.RPN_ONLY = False
_C.MODEL.MASK_ON = False
_C.MODEL.RETINANET_ON = False
_C.MODEL.KEYPOINT_ON = False
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
@@ -273,6 +274,67 @@ _C.MODEL.RESNETS.RES5_DILATION = 1
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
# ---------------------------------------------------------------------------- #
# RetinaNet Options (following the Detectron version)
# ---------------------------------------------------------------------------- #
_C.MODEL.RETINANET = CN()
# This is the number of foreground classes plus the background class.
_C.MODEL.RETINANET.NUM_CLASSES = 81
# Anchor sizes, aspect ratios and strides to use
_C.MODEL.RETINANET.ANCHOR_SIZES = (32, 64, 128, 256, 512)
_C.MODEL.RETINANET.ASPECT_RATIOS = (0.5, 1.0, 2.0)
_C.MODEL.RETINANET.ANCHOR_STRIDES = (8, 16, 32, 64, 128)
_C.MODEL.RETINANET.STRADDLE_THRESH = 0
# Anchor scales per octave
_C.MODEL.RETINANET.OCTAVE = 2.0
_C.MODEL.RETINANET.SCALES_PER_OCTAVE = 3
# Use C5 or P5 to generate P6
_C.MODEL.RETINANET.USE_C5 = True
# Number of convolutions in the cls and bbox towers
# NOTE: this doesn't include the last conv for logits
_C.MODEL.RETINANET.NUM_CONVS = 4
# Weight for bbox_regression loss
_C.MODEL.RETINANET.BBOX_REG_WEIGHT = 4.0
# Smooth L1 loss beta for bbox regression
_C.MODEL.RETINANET.BBOX_REG_BETA = 0.11
# During inference, number of locations to select per FPN level,
# based on cls score, before NMS is performed
_C.MODEL.RETINANET.PRE_NMS_TOP_N = 1000
# IoU overlap ratio for labeling an anchor as positive
# Anchors with >= iou overlap are labeled positive
_C.MODEL.RETINANET.FG_IOU_THRESHOLD = 0.5
# IoU overlap ratio for labeling an anchor as negative
# Anchors with < iou overlap are labeled negative
_C.MODEL.RETINANET.BG_IOU_THRESHOLD = 0.4
# Focal loss parameter: alpha
_C.MODEL.RETINANET.LOSS_ALPHA = 0.25
# Focal loss parameter: gamma
_C.MODEL.RETINANET.LOSS_GAMMA = 2.0
# Prior prob for the positives at the beginning of training. This is used to set
# the bias init for the logits layer
_C.MODEL.RETINANET.PRIOR_PROB = 0.01
# Inference cls score threshold, anchors with score > INFERENCE_TH are
# considered for inference
_C.MODEL.RETINANET.INFERENCE_TH = 0.05
# NMS threshold used in RetinaNet
_C.MODEL.RETINANET.NMS_TH = 0.4
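# Illustration (assumed arithmetic, matching the bias init used in the
# RetinaNet head below): PRIOR_PROB sets the bias of the classification
# logits so that every anchor starts out predicting roughly a 0.01
# foreground probability, which keeps the initial loss from being
# dominated by the many easy negatives. A minimal sketch:
#
#   import math
#   prior_prob = 0.01
#   bias_value = -math.log((1 - prior_prob) / prior_prob)  # ~ -4.595
#   1.0 / (1.0 + math.exp(-bias_value))                    # ~ 0.01 again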
# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
@@ -311,6 +373,8 @@ _C.TEST.EXPECTED_RESULTS_SIGMA_TOL = 4
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
_C.TEST.IMS_PER_BATCH = 8
# Number of detections per image
_C.TEST.DETECTIONS_PER_IMG = 100
# ---------------------------------------------------------------------------- #
......
#pragma once
#include "cpu/vision.h"
#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif
// Interface for Python
at::Tensor SigmoidFocalLoss_forward(
const at::Tensor& logits,
const at::Tensor& targets,
const int num_classes,
const float gamma,
const float alpha) {
if (logits.type().is_cuda()) {
#ifdef WITH_CUDA
return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma, alpha);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
at::Tensor SigmoidFocalLoss_backward(
const at::Tensor& logits,
const at::Tensor& targets,
const at::Tensor& d_losses,
const int num_classes,
const float gamma,
const float alpha) {
if (logits.type().is_cuda()) {
#ifdef WITH_CUDA
return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses, num_classes, gamma, alpha);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
// This file is modified from https://github.com/pytorch/pytorch/blob/master/modules/detectron/sigmoid_focal_loss_op.cu
// Cheng-Yang Fu
// cyfu@cs.unc.edu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <cfloat>
// TODO make it in a common file
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void SigmoidFocalLossForward(const int nthreads,
const T* logits,
const int* targets,
const int num_classes,
const float gamma,
const float alpha,
const int num,
T* losses) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
int n = i / num_classes;
int d = i % num_classes; // current class[0~79];
int t = targets[n]; // target class [1~80];
// Decide it is positive or negative case.
T c1 = (t == (d+1));
T c2 = (t>=0 & t != (d+1));
T zn = (1.0 - alpha);
T zp = (alpha);
// p = sigmoid(x) = 1. / (1. + expf(-x))
T p = 1. / (1. + expf(-logits[i]));
// term1 = (1-p)**gamma * log(p)
T term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN));
// p**gamma * log(1-p)
T term2 = powf(p, gamma) *
(-1. * logits[i] * (logits[i] >= 0) -
logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0))));
losses[i] = 0.0;
losses[i] += -c1 * term1 * zp;
losses[i] += -c2 * term2 * zn;
} // CUDA_1D_KERNEL_LOOP
} // SigmoidFocalLossForward
template <typename T>
__global__ void SigmoidFocalLossBackward(const int nthreads,
const T* logits,
const int* targets,
const T* d_losses,
const int num_classes,
const float gamma,
const float alpha,
const int num,
T* d_logits) {
CUDA_1D_KERNEL_LOOP(i, nthreads) {
int n = i / num_classes;
int d = i % num_classes; // current class[0~79];
int t = targets[n]; // target class [1~80], 0 is background;
// Decide it is positive or negative case.
T c1 = (t == (d+1));
T c2 = (t>=0 & t != (d+1));
T zn = (1.0 - alpha);
T zp = (alpha);
// p = sigmoid(x) = 1. / (1. + expf(-x))
T p = 1. / (1. + expf(-logits[i]));
// (1-p)**g * (1 - p - g*p*log(p))
T term1 = powf((1. - p), gamma) *
(1. - p - (p * gamma * logf(max(p, FLT_MIN))));
// (p**g) * (g*(1-p)*log(1-p) - p)
T term2 = powf(p, gamma) *
((-1. * logits[i] * (logits[i] >= 0) -
logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))) *
(1. - p) * gamma - p);
d_logits[i] = 0.0;
d_logits[i] += -c1 * term1 * zp;
d_logits[i] += -c2 * term2 * zn;
d_logits[i] = d_logits[i] * d_losses[i];
} // CUDA_1D_KERNEL_LOOP
} // SigmoidFocalLossBackward
at::Tensor SigmoidFocalLoss_forward_cuda(
const at::Tensor& logits,
const at::Tensor& targets,
const int num_classes,
const float gamma,
const float alpha) {
AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
const int num_samples = logits.size(0);
auto losses = at::empty({num_samples, logits.size(1)}, logits.options());
auto losses_size = num_samples * logits.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(THCCeilDiv(losses_size, 512L), 4096L));
dim3 block(512);
if (losses.numel() == 0) {
THCudaCheck(cudaGetLastError());
return losses;
}
AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_forward", [&] {
SigmoidFocalLossForward<scalar_t><<<grid, block, 0, stream>>>(
losses_size,
logits.contiguous().data<scalar_t>(),
targets.contiguous().data<int>(),
num_classes,
gamma,
alpha,
num_samples,
losses.data<scalar_t>());
});
THCudaCheck(cudaGetLastError());
return losses;
}
at::Tensor SigmoidFocalLoss_backward_cuda(
const at::Tensor& logits,
const at::Tensor& targets,
const at::Tensor& d_losses,
const int num_classes,
const float gamma,
const float alpha) {
AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
AT_ASSERTM(d_losses.type().is_cuda(), "d_losses must be a CUDA tensor");
AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
const int num_samples = logits.size(0);
AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes");
auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
auto d_logits_size = num_samples * logits.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(THCCeilDiv(d_logits_size, 512L), 4096L));
dim3 block(512);
if (d_logits.numel() == 0) {
THCudaCheck(cudaGetLastError());
return d_logits;
}
AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_backward", [&] {
SigmoidFocalLossBackward<scalar_t><<<grid, block, 0, stream>>>(
d_logits_size,
logits.contiguous().data<scalar_t>(),
targets.contiguous().data<int>(),
d_losses.contiguous().data<scalar_t>(),
num_classes,
gamma,
alpha,
num_samples,
d_logits.data<scalar_t>());
});
THCudaCheck(cudaGetLastError());
return d_logits;
}
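For reference, the kernels above evaluate log(1 - p) in the numerically stable form -x * (x >= 0) - log(1 + exp(x - 2 * x * (x >= 0))). A plain-PyTorch sketch of the same two per-element terms (names and shapes assumed, for checking only, not the operator itself):

import torch

def focal_terms(x, gamma=2.0):
    # x: raw logits; p = sigmoid(x)
    p = torch.sigmoid(x)
    pos = (x >= 0).float()
    # term1 = (1 - p)**gamma * log(p), clamped like the kernel's FLT_MIN guard
    term1 = (1 - p) ** gamma * torch.log(p.clamp(min=1e-38))
    # term2 = p**gamma * log(1 - p), written in the kernel's stable form
    term2 = p ** gamma * (-x * pos - torch.log1p(torch.exp(x - 2 * x * pos)))
    return term1, term2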
@@ -3,6 +3,21 @@
#include <torch/extension.h>
at::Tensor SigmoidFocalLoss_forward_cuda(
const at::Tensor& logits,
const at::Tensor& targets,
const int num_classes,
const float gamma,
const float alpha);
at::Tensor SigmoidFocalLoss_backward_cuda(
const at::Tensor& logits,
const at::Tensor& targets,
const at::Tensor& d_losses,
const int num_classes,
const float gamma,
const float alpha);
at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
......
@@ -2,7 +2,7 @@
#include "nms.h"
#include "ROIAlign.h"
#include "ROIPool.h"
#include "SigmoidFocalLoss.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms", &nms, "non-maximum suppression");
@@ -10,4 +10,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
m.def("sigmoid_focalloss_forward", &SigmoidFocalLoss_forward, "SigmoidFocalLoss_forward");
m.def("sigmoid_focalloss_backward", &SigmoidFocalLoss_backward, "SigmoidFocalLoss_backward");
}
@@ -11,8 +11,10 @@ from .roi_align import roi_align
from .roi_pool import ROIPool
from .roi_pool import roi_pool
from .smooth_l1_loss import smooth_l1_loss
from .sigmoid_focal_loss import SigmoidFocalLoss
__all__ = ["nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool",
"smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate",
"FrozenBatchNorm2d", "SigmoidFocalLoss"
]
__all__ = ["nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool",
"smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate",
"FrozenBatchNorm2d",
]
import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from maskrcnn_benchmark import _C
# TODO: Use JIT to replace CUDA implementation in the future.
class _SigmoidFocalLoss(Function):
@staticmethod
def forward(ctx, logits, targets, gamma, alpha):
ctx.save_for_backward(logits, targets)
num_classes = logits.shape[1]
ctx.num_classes = num_classes
ctx.gamma = gamma
ctx.alpha = alpha
losses = _C.sigmoid_focalloss_forward(
logits, targets, num_classes, gamma, alpha
)
return losses
@staticmethod
@once_differentiable
def backward(ctx, d_loss):
logits, targets = ctx.saved_tensors
num_classes = ctx.num_classes
gamma = ctx.gamma
alpha = ctx.alpha
d_loss = d_loss.contiguous()
d_logits = _C.sigmoid_focalloss_backward(
logits, targets, d_loss, num_classes, gamma, alpha
)
return d_logits, None, None, None, None
sigmoid_focal_loss_cuda = _SigmoidFocalLoss.apply
def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
num_classes = logits.shape[1]
dtype = targets.dtype
device = targets.device
class_range = torch.arange(1, num_classes+1, dtype=dtype, device=device).unsqueeze(0)
t = targets.unsqueeze(1)
p = torch.sigmoid(logits)
term1 = (1 - p) ** gamma * torch.log(p)
term2 = p ** gamma * torch.log(1 - p)
return -(t == class_range).float() * term1 * alpha - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha)
class SigmoidFocalLoss(nn.Module):
def __init__(self, gamma, alpha):
super(SigmoidFocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
def forward(self, logits, targets):
device = logits.device
if logits.is_cuda:
loss_func = sigmoid_focal_loss_cuda
else:
loss_func = sigmoid_focal_loss_cpu
loss = loss_func(logits, targets, self.gamma, self.alpha)
return loss.sum()
def __repr__(self):
tmpstr = self.__class__.__name__ + "("
tmpstr += "gamma=" + str(self.gamma)
tmpstr += ", alpha=" + str(self.alpha)
tmpstr += ")"
return tmpstr
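A minimal usage sketch of the module above (shapes assumed): targets use 0 for background and 1..C for foreground classes, matching the kernels' convention, and the returned loss is summed over all anchor/class entries:

import torch

loss_fn = SigmoidFocalLoss(gamma=2.0, alpha=0.25)
logits = torch.randn(8, 80)                 # 8 anchors, 80 foreground classes
targets = torch.randint(0, 81, (8,)).int()  # 0 = background, 1..80 = classes
loss = loss_fn(logits, targets)             # scalar tensor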
@@ -42,6 +42,29 @@ def build_resnet_fpn_backbone(cfg):
model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
return model
@registry.BACKBONES.register("R-50-FPN-RETINANET")
@registry.BACKBONES.register("R-101-FPN-RETINANET")
def build_resnet_fpn_p3p7_backbone(cfg):
body = resnet.ResNet(cfg)
in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
in_channels_p6p7 = in_channels_stage2 * 8 if cfg.MODEL.RETINANET.USE_C5 \
else out_channels
fpn = fpn_module.FPN(
in_channels_list=[
0,
in_channels_stage2 * 2,
in_channels_stage2 * 4,
in_channels_stage2 * 8,
],
out_channels=out_channels,
conv_block=conv_with_kaiming_uniform(
cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
),
top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels),
)
model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
return model
def build_backbone(cfg):
assert cfg.MODEL.BACKBONE.CONV_BODY in registry.BACKBONES, \
......
@@ -29,6 +29,9 @@ class FPN(nn.Module):
for idx, in_channels in enumerate(in_channels_list, 1):
inner_block = "fpn_inner{}".format(idx)
layer_block = "fpn_layer{}".format(idx)
if in_channels == 0:
continue
inner_block_module = conv_block(in_channels, out_channels, 1)
layer_block_module = conv_block(out_channels, out_channels, 3, 1)
self.add_module(inner_block, inner_block_module)
@@ -51,6 +54,8 @@
for feature, inner_block, layer_block in zip(
x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
):
if not inner_block:
continue
inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest")
inner_lateral = getattr(self, inner_block)(feature)
# TODO use size instead of scale to make it robust to different sizes
@@ -59,7 +64,10 @@
last_inner = inner_lateral + inner_top_down
results.insert(0, getattr(self, layer_block)(last_inner))
if self.top_blocks is not None:
if isinstance(self.top_blocks, LastLevelP6P7):
last_results = self.top_blocks(x[-1], results[-1])
results.extend(last_results)
elif isinstance(self.top_blocks, LastLevelMaxPool):
last_results = self.top_blocks(results[-1])
results.extend(last_results)
@@ -69,3 +77,23 @@ class FPN(nn.Module):
class LastLevelMaxPool(nn.Module):
def forward(self, x):
return [F.max_pool2d(x, 1, 2, 0)]
class LastLevelP6P7(nn.Module):
"""
This module is used in RetinaNet to generate extra layers, P6 and P7.
"""
def __init__(self, in_channels, out_channels):
super(LastLevelP6P7, self).__init__()
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
for module in [self.p6, self.p7]:
nn.init.kaiming_uniform_(module.weight, a=1)
nn.init.constant_(module.bias, 0)
self.use_P5 = in_channels == out_channels
def forward(self, c5, p5):
x = p5 if self.use_P5 else c5
p6 = self.p6(x)
p7 = self.p7(F.relu(p6))
return [p6, p7]
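A quick shape sketch for LastLevelP6P7 (toy sizes assumed): with a ResNet-50 C5 of 2048 channels, use_P5 is False, so P6 is computed from C5, and each stride-2 conv roughly halves the spatial size:

import torch

p6p7 = LastLevelP6P7(in_channels=2048, out_channels=256)
c5 = torch.randn(1, 2048, 25, 32)  # last backbone stage
p5 = torch.randn(1, 256, 25, 32)   # last FPN output
p6, p7 = p6p7(c5, p5)
print(p6.shape, p7.shape)          # (1, 256, 13, 16) and (1, 256, 7, 8)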
@@ -410,6 +410,8 @@ _STAGE_SPECS = Registry({
"R-101-C4": ResNet101StagesTo4,
"R-101-C5": ResNet101StagesTo5,
"R-50-FPN": ResNet50FPNStagesTo5,
"R-50-FPN-RETINANET": ResNet50FPNStagesTo5,
"R-101-FPN": ResNet101FPNStagesTo5,
"R-101-FPN-RETINANET": ResNet101FPNStagesTo5,
"R-152-FPN": ResNet152FPNStagesTo5,
})
@@ -59,6 +59,9 @@ def build_roi_heads(cfg):
# individually create the heads, that will be combined together
# afterwards
roi_heads = []
if cfg.MODEL.RETINANET_ON:
return []
if not cfg.MODEL.RPN_ONLY:
roi_heads.append(("box", build_roi_box_head(cfg)))
if cfg.MODEL.MASK_ON:
......
@@ -54,8 +54,13 @@ class AnchorGenerator(nn.Module):
else:
if len(anchor_strides) != len(sizes):
raise RuntimeError("FPN should have #anchor_strides == #sizes")
cell_anchors = [
generate_anchors(anchor_stride, (size,), aspect_ratios).float()
generate_anchors(
anchor_stride,
size if isinstance(size, (tuple, list)) else (size,),
aspect_ratios
).float()
for anchor_stride, size in zip(anchor_strides, sizes)
]
self.strides = anchor_strides
@@ -138,6 +143,28 @@ def make_anchor_generator(config):
return anchor_generator
def make_anchor_generator_retinanet(config):
anchor_sizes = config.MODEL.RETINANET.ANCHOR_SIZES
aspect_ratios = config.MODEL.RETINANET.ASPECT_RATIOS
anchor_strides = config.MODEL.RETINANET.ANCHOR_STRIDES
straddle_thresh = config.MODEL.RETINANET.STRADDLE_THRESH
octave = config.MODEL.RETINANET.OCTAVE
scales_per_octave = config.MODEL.RETINANET.SCALES_PER_OCTAVE
assert len(anchor_strides) == len(anchor_sizes), "Only support FPN now"
new_anchor_sizes = []
for size in anchor_sizes:
per_layer_anchor_sizes = []
for scale_per_octave in range(scales_per_octave):
octave_scale = octave ** (scale_per_octave / float(scales_per_octave))
per_layer_anchor_sizes.append(octave_scale * size)
new_anchor_sizes.append(tuple(per_layer_anchor_sizes))
anchor_generator = AnchorGenerator(
tuple(new_anchor_sizes), aspect_ratios, anchor_strides, straddle_thresh
)
return anchor_generator
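Concretely, with the defaults OCTAVE = 2.0 and SCALES_PER_OCTAVE = 3, each base anchor size expands into three sizes spaced a third of an octave apart (a worked sketch using the default sizes):

octave, scales_per_octave = 2.0, 3
for size in (32, 64, 128, 256, 512):
    print(tuple(round(size * octave ** (i / float(scales_per_octave)), 1)
                for i in range(scales_per_octave)))
# (32.0, 40.3, 50.8), (64.0, 80.6, 101.6), ..., (512.0, 645.1, 812.7)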
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
@@ -8,7 +8,7 @@ from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes
from ..utils import cat
from .utils import permute_and_flatten
class RPNPostProcessor(torch.nn.Module):
"""
@@ -82,10 +82,10 @@ class RPNPostProcessor(torch.nn.Module):
N, A, H, W = objectness.shape
# put in the same format as anchors
objectness = objectness.permute(0, 2, 3, 1).reshape(N, -1)
objectness = permute_and_flatten(objectness, N, A, 1, H, W).view(N, -1)
objectness = objectness.sigmoid()
box_regression = box_regression.view(N, -1, 4, H, W).permute(0, 3, 4, 1, 2)
box_regression = box_regression.reshape(N, -1, 4)
box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
num_anchors = A * H * W
......
@@ -7,6 +7,8 @@ file
import torch
from torch.nn import functional as F
from .utils import concat_box_prediction_layers
from ..balanced_positive_negative_sampler import BalancedPositiveNegativeSampler
from ..utils import cat
@@ -21,7 +23,8 @@ class RPNLossComputation(object):
This class computes the RPN loss.
"""
def __init__(self, proposal_matcher, fg_bg_sampler, box_coder):
def __init__(self, proposal_matcher, fg_bg_sampler, box_coder,
generate_labels_func):
"""
Arguments:
proposal_matcher (Matcher)
@@ -32,13 +35,16 @@ class RPNLossComputation(object):
self.proposal_matcher = proposal_matcher
self.fg_bg_sampler = fg_bg_sampler
self.box_coder = box_coder
self.copied_fields = []
self.generate_labels_func = generate_labels_func
self.discard_cases = ['not_visibility', 'between_thresholds']
def match_targets_to_anchors(self, anchor, target):
def match_targets_to_anchors(self, anchor, target, copied_fields=[]):
match_quality_matrix = boxlist_iou(target, anchor)
matched_idxs = self.proposal_matcher(match_quality_matrix)
# RPN doesn't need any fields from target
# for creating the labels, so clear them all
target = target.copy_with_fields([])
target = target.copy_with_fields(copied_fields)
# get the targets corresponding GT for each anchor
# NB: need to clamp the indices because we can have a single
# GT in the image, and matched_idxs can be -2, which goes
@@ -52,18 +58,25 @@
regression_targets = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
matched_targets = self.match_targets_to_anchors(
anchors_per_image, targets_per_image
anchors_per_image, targets_per_image, self.copied_fields
)
matched_idxs = matched_targets.get_field("matched_idxs")
labels_per_image = matched_idxs >= 0
labels_per_image = self.generate_labels_func(matched_targets)
labels_per_image = labels_per_image.to(dtype=torch.float32)
# Background (negative examples)
bg_indices = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
labels_per_image[bg_indices] = 0
# discard anchors that go out of the boundaries of the image
labels_per_image[~anchors_per_image.get_field("visibility")] = -1
if "not_visibility" in self.discard_cases:
labels_per_image[~anchors_per_image.get_field("visibility")] = -1
# discard indices that are between thresholds
inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
labels_per_image[inds_to_discard] = -1
if "between_thresholds" in self.discard_cases:
inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
labels_per_image[inds_to_discard] = -1
# compute regression targets
regression_targets_per_image = self.box_coder.encode(
@@ -75,6 +88,7 @@
return labels, regression_targets
def __call__(self, anchors, objectness, box_regression, targets):
"""
Arguments:
@@ -95,29 +109,10 @@
sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
objectness_flattened = []
box_regression_flattened = []
# for each feature level, permute the outputs to make them be in the
# same format as the labels. Note that the labels are computed for
# all feature levels concatenated, so we keep the same representation
# for the objectness and the box_regression
for objectness_per_level, box_regression_per_level in zip(
objectness, box_regression
):
N, A, H, W = objectness_per_level.shape
objectness_per_level = objectness_per_level.permute(0, 2, 3, 1).reshape(
N, -1
)
box_regression_per_level = box_regression_per_level.view(N, -1, 4, H, W)
box_regression_per_level = box_regression_per_level.permute(0, 3, 4, 1, 2)
box_regression_per_level = box_regression_per_level.reshape(N, -1, 4)
objectness_flattened.append(objectness_per_level)
box_regression_flattened.append(box_regression_per_level)
# concatenate on the first dimension (representing the feature levels), to
# take into account the way the labels were generated (with all feature maps
# being concatenated as well)
objectness = cat(objectness_flattened, dim=1).reshape(-1)
box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)
objectness, box_regression = \
concat_box_prediction_layers(objectness, box_regression)
objectness = objectness.squeeze()
labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)
@@ -135,6 +130,12 @@
return objectness_loss, box_loss
# This function should be overwritten in RetinaNet
def generate_rpn_labels(matched_targets):
matched_idxs = matched_targets.get_field("matched_idxs")
labels_per_image = matched_idxs >= 0
return labels_per_image
def make_rpn_loss_evaluator(cfg, box_coder):
matcher = Matcher(
@@ -147,5 +148,10 @@ def make_rpn_loss_evaluator(cfg, box_coder):
cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE, cfg.MODEL.RPN.POSITIVE_FRACTION
)
loss_evaluator = RPNLossComputation(matcher, fg_bg_sampler, box_coder)
loss_evaluator = RPNLossComputation(
matcher,
fg_bg_sampler,
box_coder,
generate_rpn_labels
)
return loss_evaluator
import torch
from ..inference import RPNPostProcessor
from ..utils import permute_and_flatten
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes
class RetinaNetPostProcessor(RPNPostProcessor):
"""
Performs post-processing on the box outputs of RetinaNet.
This is only used during testing.
"""
def __init__(
self,
pre_nms_thresh,
pre_nms_top_n,
nms_thresh,
fpn_post_nms_top_n,
min_size,
num_classes,
box_coder=None,
):
"""
Arguments:
pre_nms_thresh (float)
pre_nms_top_n (int)
nms_thresh (float)
fpn_post_nms_top_n (int)
min_size (int)
num_classes (int)
box_coder (BoxCoder)
"""
super(RetinaNetPostProcessor, self).__init__(
pre_nms_thresh, 0, nms_thresh, min_size
)
self.pre_nms_thresh = pre_nms_thresh
self.pre_nms_top_n = pre_nms_top_n
self.nms_thresh = nms_thresh
self.fpn_post_nms_top_n = fpn_post_nms_top_n
self.min_size = min_size
self.num_classes = num_classes
if box_coder is None:
box_coder = BoxCoder(weights=(10., 10., 5., 5.))
self.box_coder = box_coder
def add_gt_proposals(self, proposals, targets):
"""
This function is not used in RetinaNet
"""
pass
def forward_for_single_feature_map(
self, anchors, box_cls, box_regression):
"""
Arguments:
anchors: list[BoxList]
box_cls: tensor of size N, A * C, H, W
box_regression: tensor of size N, A * 4, H, W
"""
device = box_cls.device
N, _, H, W = box_cls.shape
A = box_regression.size(1) // 4
C = box_cls.size(1) // A
# put in the same format as anchors
box_cls = permute_and_flatten(box_cls, N, A, C, H, W)
box_cls = box_cls.sigmoid()
box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
box_regression = box_regression.reshape(N, -1, 4)
num_anchors = A * H * W
candidate_inds = box_cls > self.pre_nms_thresh
pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)
results = []
for per_box_cls, per_box_regression, per_pre_nms_top_n, \
per_candidate_inds, per_anchors in zip(
box_cls,
box_regression,
pre_nms_top_n,
candidate_inds,
anchors):
# Sort and select TopN
# TODO most of this can be moved out of the loop for
# all images.
# TODO (Yang): not easy to do, because the number of detections
# differs per image, so this part needs to be done
# per image.
per_box_cls = per_box_cls[per_candidate_inds]
per_box_cls, top_k_indices = \
per_box_cls.topk(per_pre_nms_top_n, sorted=False)
per_candidate_nonzeros = \
per_candidate_inds.nonzero()[top_k_indices, :]
per_box_loc = per_candidate_nonzeros[:, 0]
per_class = per_candidate_nonzeros[:, 1]
per_class += 1
detections = self.box_coder.decode(
per_box_regression[per_box_loc, :].view(-1, 4),
per_anchors.bbox[per_box_loc, :].view(-1, 4)
)
boxlist = BoxList(detections, per_anchors.size, mode="xyxy")
boxlist.add_field("labels", per_class)
boxlist.add_field("scores", per_box_cls)
boxlist = boxlist.clip_to_image(remove_empty=False)
boxlist = remove_small_boxes(boxlist, self.min_size)
results.append(boxlist)
return results
# TODO very similar to filter_results from PostProcessor
# but filter_results is per image
# TODO Yang: solve this issue in the future. No good solution
# right now.
def select_over_all_levels(self, boxlists):
num_images = len(boxlists)
results = []
for i in range(num_images):
scores = boxlists[i].get_field("scores")
labels = boxlists[i].get_field("labels")
boxes = boxlists[i].bbox
boxlist = boxlists[i]
result = []
# skip the background
for j in range(1, self.num_classes):
inds = (labels == j).nonzero().view(-1)
scores_j = scores[inds]
boxes_j = boxes[inds, :].view(-1, 4)
boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
boxlist_for_class.add_field("scores", scores_j)
boxlist_for_class = boxlist_nms(
boxlist_for_class, self.nms_thresh,
score_field="scores"
)
num_labels = len(boxlist_for_class)
boxlist_for_class.add_field(
"labels", torch.full((num_labels,), j,
dtype=torch.int64,
device=scores.device)
)
result.append(boxlist_for_class)
result = cat_boxlist(result)
number_of_detections = len(result)
# Limit to max_per_image detections **over all classes**
if number_of_detections > self.fpn_post_nms_top_n > 0:
cls_scores = result.get_field("scores")
image_thresh, _ = torch.kthvalue(
cls_scores.cpu(),
number_of_detections - self.fpn_post_nms_top_n + 1
)
keep = cls_scores >= image_thresh.item()
keep = torch.nonzero(keep).squeeze(1)
result = result[keep]
results.append(result)
return results
def make_retinanet_postprocessor(config, rpn_box_coder, is_train):
pre_nms_thresh = config.MODEL.RETINANET.INFERENCE_TH
pre_nms_top_n = config.MODEL.RETINANET.PRE_NMS_TOP_N
nms_thresh = config.MODEL.RETINANET.NMS_TH
fpn_post_nms_top_n = config.TEST.DETECTIONS_PER_IMG
min_size = 0
box_selector = RetinaNetPostProcessor(
pre_nms_thresh=pre_nms_thresh,
pre_nms_top_n=pre_nms_top_n,
nms_thresh=nms_thresh,
fpn_post_nms_top_n=fpn_post_nms_top_n,
min_size=min_size,
num_classes=config.MODEL.RETINANET.NUM_CLASSES,
box_coder=rpn_box_coder,
)
return box_selector
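To make the per-level selection concrete, a toy walk-through (all values assumed): entries above the score threshold are counted per image, the count is clamped to the per-level budget, and the top-k indices recover (anchor location, class) pairs:

import torch

box_cls = torch.rand(2, 6, 3)                # N=2 images, 6 locations, 3 classes
candidate_inds = box_cls > 0.05              # INFERENCE_TH
pre_nms_top_n = candidate_inds.view(2, -1).sum(1).clamp(max=4)
per_box_cls = box_cls[0][candidate_inds[0]]  # surviving scores for image 0
scores, top_k = per_box_cls.topk(int(pre_nms_top_n[0]), sorted=False)
loc_class = candidate_inds[0].nonzero()[top_k]  # rows of (location, class)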
"""
This file contains specific functions for computing losses for RetinaNet.
"""
import torch
from torch.nn import functional as F
from ..utils import concat_box_prediction_layers
from maskrcnn_benchmark.layers import smooth_l1_loss
from maskrcnn_benchmark.layers import SigmoidFocalLoss
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.modeling.rpn.loss import RPNLossComputation
class RetinaNetLossComputation(RPNLossComputation):
"""
This class computes the RetinaNet loss.
"""
def __init__(self, proposal_matcher, box_coder,
generate_labels_func,
sigmoid_focal_loss,
bbox_reg_beta=0.11,
regress_norm=1.0):
"""
Arguments:
proposal_matcher (Matcher)
box_coder (BoxCoder)
"""
self.proposal_matcher = proposal_matcher
self.box_coder = box_coder
self.box_cls_loss_func = sigmoid_focal_loss
self.bbox_reg_beta = bbox_reg_beta
self.copied_fields = ['labels']
self.generate_labels_func = generate_labels_func
self.discard_cases = ['between_thresholds']
self.regress_norm = regress_norm
def __call__(self, anchors, box_cls, box_regression, targets):
"""
Arguments:
anchors (list[BoxList])
box_cls (list[Tensor])
box_regression (list[Tensor])
targets (list[BoxList])
Returns:
retinanet_cls_loss (Tensor)
retinanet_regression_loss (Tensor)
"""
anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
labels, regression_targets = self.prepare_targets(anchors, targets)
N = len(labels)
box_cls, box_regression = \
concat_box_prediction_layers(box_cls, box_regression)
labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)
pos_inds = torch.nonzero(labels > 0).squeeze(1)
retinanet_regression_loss = smooth_l1_loss(
box_regression[pos_inds],
regression_targets[pos_inds],
beta=self.bbox_reg_beta,
size_average=False,
) / (max(1, pos_inds.numel() * self.regress_norm))
labels = labels.int()
retinanet_cls_loss = self.box_cls_loss_func(
box_cls,
labels
) / (pos_inds.numel() + N)
return retinanet_cls_loss, retinanet_regression_loss
def generate_retinanet_labels(matched_targets):
labels_per_image = matched_targets.get_field("labels")
return labels_per_image
def make_retinanet_loss_evaluator(cfg, box_coder):
matcher = Matcher(
cfg.MODEL.RETINANET.FG_IOU_THRESHOLD,
cfg.MODEL.RETINANET.BG_IOU_THRESHOLD,
allow_low_quality_matches=True,
)
sigmoid_focal_loss = SigmoidFocalLoss(
cfg.MODEL.RETINANET.LOSS_GAMMA,
cfg.MODEL.RETINANET.LOSS_ALPHA
)
loss_evaluator = RetinaNetLossComputation(
matcher,
box_coder,
generate_retinanet_labels,
sigmoid_focal_loss,
bbox_reg_beta = cfg.MODEL.RETINANET.BBOX_REG_BETA,
regress_norm = cfg.MODEL.RETINANET.BBOX_REG_WEIGHT,
)
return loss_evaluator
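The two normalizers above follow Detectron's setting: the regression loss is divided by the number of positive anchors scaled by BBOX_REG_WEIGHT, and the classification loss by the number of positives plus the number of images. With toy numbers (assumed):

pos_anchors, num_images, reg_weight = 100, 2, 4.0
reg_norm = max(1, int(pos_anchors * reg_weight))  # regression loss / 400
cls_norm = pos_anchors + num_images               # classification loss / 102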
import math
import torch
import torch.nn.functional as F
from torch import nn
from .inference import make_retinanet_postprocessor
from .loss import make_retinanet_loss_evaluator
from ..anchor_generator import make_anchor_generator_retinanet
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
class RetinaNetHead(torch.nn.Module):
"""
Adds a RetinaNet head with classification and regression branches
"""
def __init__(self, cfg):
"""
Arguments:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
"""
super(RetinaNetHead, self).__init__()
# TODO: Implement the sigmoid version first.
num_classes = cfg.MODEL.RETINANET.NUM_CLASSES - 1
in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
num_anchors = len(cfg.MODEL.RETINANET.ASPECT_RATIOS) \
* cfg.MODEL.RETINANET.SCALES_PER_OCTAVE
cls_tower = []
bbox_tower = []
for i in range(cfg.MODEL.RETINANET.NUM_CONVS):
cls_tower.append(
nn.Conv2d(
in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1
)
)
cls_tower.append(nn.ReLU())
bbox_tower.append(
nn.Conv2d(
in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1
)
)
bbox_tower.append(nn.ReLU())
self.add_module('cls_tower', nn.Sequential(*cls_tower))
self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
self.cls_logits = nn.Conv2d(
in_channels, num_anchors * num_classes, kernel_size=3, stride=1,
padding=1
)
self.bbox_pred = nn.Conv2d(
in_channels, num_anchors * 4, kernel_size=3, stride=1,
padding=1
)
# Initialization
for modules in [self.cls_tower, self.bbox_tower, self.cls_logits,
self.bbox_pred]:
for l in modules.modules():
if isinstance(l, nn.Conv2d):
torch.nn.init.normal_(l.weight, std=0.01)
torch.nn.init.constant_(l.bias, 0)
# retinanet_bias_init
prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB
bias_value = -math.log((1 - prior_prob) / prior_prob)
torch.nn.init.constant_(self.cls_logits.bias, bias_value)
def forward(self, x):
logits = []
bbox_reg = []
for feature in x:
logits.append(self.cls_logits(self.cls_tower(feature)))
bbox_reg.append(self.bbox_pred(self.bbox_tower(feature)))
return logits, bbox_reg
class RetinaNetModule(torch.nn.Module):
"""
Module for RetinaNet computation. Takes feature maps from the backbone and
computes RetinaNet outputs and losses. Only tested on FPN so far.
"""
def __init__(self, cfg):
super(RetinaNetModule, self).__init__()
self.cfg = cfg.clone()
anchor_generator = make_anchor_generator_retinanet(cfg)
head = RetinaNetHead(cfg)
box_coder = BoxCoder(weights=(10., 10., 5., 5.))
box_selector_test = make_retinanet_postprocessor(cfg, box_coder, is_train=False)
loss_evaluator = make_retinanet_loss_evaluator(cfg, box_coder)
self.anchor_generator = anchor_generator
self.head = head
self.box_selector_test = box_selector_test
self.loss_evaluator = loss_evaluator
def forward(self, images, features, targets=None):
"""
Arguments:
images (ImageList): images for which we want to compute the predictions
features (list[Tensor]): features computed from the images that are
used for computing the predictions. Each tensor in the list
corresponds to a different feature level
targets (list[BoxList]): ground-truth boxes present in the image (optional)
Returns:
boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
image.
losses (dict[Tensor]): the losses for the model during training. During
testing, it is an empty dict.
"""
box_cls, box_regression = self.head(features)
anchors = self.anchor_generator(images, features)
if self.training:
return self._forward_train(anchors, box_cls, box_regression, targets)
else:
return self._forward_test(anchors, box_cls, box_regression)
def _forward_train(self, anchors, box_cls, box_regression, targets):
loss_box_cls, loss_box_reg = self.loss_evaluator(
anchors, box_cls, box_regression, targets
)
losses = {
"loss_retina_cls": loss_box_cls,
"loss_retina_reg": loss_box_reg,
}
return anchors, losses
def _forward_test(self, anchors, box_cls, box_regression):
boxes = self.box_selector_test(anchors, box_cls, box_regression)
return boxes, {}
def build_retinanet(cfg):
return RetinaNetModule(cfg)
@@ -5,11 +5,11 @@ from torch import nn
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.modeling.rpn.retinanet.retinanet import build_retinanet
from .loss import make_rpn_loss_evaluator
from .anchor_generator import make_anchor_generator
from .inference import make_rpn_postprocessor
@registry.RPN_HEADS.register("SingleConvRPNHead")
class RPNHead(nn.Module):
"""
@@ -142,4 +142,7 @@ def build_rpn(cfg):
"""
This gives the gist of it. Not super important because it doesn't change as much
"""
if cfg.MODEL.RETINANET_ON:
return build_retinanet(cfg)
return RPNModule(cfg)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Utility functions for manipulating the prediction layers
"""
from ..utils import cat
import torch
def permute_and_flatten(layer, N, A, C, H, W):
layer = layer.view(N, -1, C, H, W)
layer = layer.permute(0, 3, 4, 1, 2)
layer = layer.reshape(N, -1, C)
return layer
def concat_box_prediction_layers(box_cls, box_regression):
box_cls_flattened = []
box_regression_flattened = []
# for each feature level, permute the outputs to make them be in the
# same format as the labels. Note that the labels are computed for
# all feature levels concatenated, so we keep the same representation
# for the objectness and the box_regression
for box_cls_per_level, box_regression_per_level in zip(
box_cls, box_regression
):
N, AxC, H, W = box_cls_per_level.shape
Ax4 = box_regression_per_level.shape[1]
A = Ax4 // 4
C = AxC // A
box_cls_per_level = permute_and_flatten(
box_cls_per_level, N, A, C, H, W
)
box_cls_flattened.append(box_cls_per_level)
box_regression_per_level = permute_and_flatten(
box_regression_per_level, N, A, 4, H, W
)
box_regression_flattened.append(box_regression_per_level)
# concatenate on the first dimension (representing the feature levels), to
# take into account the way the labels were generated (with all feature maps
# being concatenated as well)
box_cls = cat(box_cls_flattened, dim=1).reshape(-1, C)
box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)
return box_cls, box_regression
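A shape sketch for the helpers above (toy values assumed): a head output of shape (N, A*C, H, W) becomes (N, A*H*W, C), so predictions from all FPN levels can be concatenated along dimension 1:

import torch

layer = torch.randn(2, 9 * 80, 5, 5)              # N=2, A=9 anchors, C=80, H=W=5
flat = permute_and_flatten(layer, 2, 9, 80, 5, 5)
print(flat.shape)                                 # torch.Size([2, 225, 80])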
@@ -157,12 +157,15 @@ C2_FORMAT_LOADER = Registry()
@C2_FORMAT_LOADER.register("R-101-C4")
@C2_FORMAT_LOADER.register("R-101-C5")
@C2_FORMAT_LOADER.register("R-50-FPN")
@C2_FORMAT_LOADER.register("R-50-FPN-RETINANET")
@C2_FORMAT_LOADER.register("R-101-FPN")
@C2_FORMAT_LOADER.register("R-101-FPN-RETINANET")
@C2_FORMAT_LOADER.register("R-152-FPN")
def load_resnet_c2_format(cfg, f):
state_dict = _load_c2_pickled_weights(f)
conv_body = cfg.MODEL.BACKBONE.CONV_BODY
arch = conv_body.replace("-C4", "").replace("-C5", "").replace("-FPN", "")
arch = arch.replace("-RETINANET", "")
stages = _C2_STAGE_NAMES[arch]
state_dict = _rename_weights_for_resnet(state_dict, stages)
return dict(model=state_dict)
......
@@ -84,7 +84,7 @@ def main():
data_loader_val,
dataset_name=dataset_name,
iou_types=iou_types,
box_only=cfg.MODEL.RPN_ONLY,
box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
device=cfg.MODEL.DEVICE,
expected_results=cfg.TEST.EXPECTED_RESULTS,
expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
......
@@ -99,7 +99,7 @@ def test(cfg, model, distributed):
data_loader_val,
dataset_name=dataset_name,
iou_types=iou_types,
box_only=cfg.MODEL.RPN_ONLY,
box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
device=cfg.MODEL.DEVICE,
expected_results=cfg.TEST.EXPECTED_RESULTS,
expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
......