Commit 6b1ab017 authored by Cheng-Yang Fu, committed by Francisco Massa

Add RetinaNet Implementation (#102)

* Add RetinaNet parameters in cfg.

* Hot fix.

* Add the retinanet head module now.

* Add the function to generate the anchors for RetinaNet.

* Add the SigmoidFocalLoss cuda operator.

* Fix the bug in the extra layers.

* Change the normalizer for SigmoidFocalLoss

* Support multiscale in training.

* Add RetinaNet training script.

* Add the inference part of RetinaNet.

* Fix the bug when building the extra layers in retinanet.
Update the matching part in retinanet_loss.

* Add the first version of the inference of RetinaNet.
Need to check it again to see if there is any room for speed
improvement.

* Remove the retinanet_R-50-FPN_2x.yaml first.

* Optimize the retinanet postprocessing.

* Quick fix.

* Add script for training RetinaNet with ResNet101 backbone.

* Move cfg.RETINANET to cfg.MODEL.RETINANET

* Remove the variables which are not used.

* Revert boxlist_ops.
Generate empty BoxLists instead of [] in retinanet_infer.

* Remove unused commented lines.
Add NUM_DETECTIONS_PER_IMAGE

* Remove unused code.

* Move retinanet related files under Modeling/rpn/retinanet

* Add retinanet_X_101_32x8d_FPN_1x.yaml script.
This model is not fully validated; I only trained it for around 5000
iterations and everything looked fine.

* Set RETINANET.PRE_NMS_TOP_N to 0 at level 5 (P7), because the previous
setting could generate zero detections and break the program.
This follows the original Detectron setting.

* Fix the RPN-only bug when training ends.

* Minor improvements

* Comments and add Python-only implementation

* Bugfix and remove commented code

* Keep generalized_rcnn the same.
Move build_retinanet inside build_rpn.

* Add USE_C5 in the MODEL.RETINANET

* Add two configs using P5 to generate P6.

* Fix the bug when loading the Caffe2 ImageNet pretrained model.

* Reduce the code duplication between the RPN loss and the RetinaNet loss.

* Remove an unused comment.

* Remove the hard coded number of classes.

* Share the forward part of RPN inference.

* Fix the bug in RPN inference.

* Remove the conditional part in the inference.

* Bug fix: add the utils file for permute and flatten of the box
prediction layers.

* Update the comment.

* Quick fix: add the missing import of cat.

* Quick fix: add a forgotten import.

* Adjust the normalization part according to Detectron's setting.

* Use the bbox reg normalization term.

* Clean the code according to recent review.

* Use the CUDA version for training on GPU, and the Python version for
training on CPU.

* Rename the directory to retinanet.

* Make the train and val datasets consistent with the Mask R-CNN setting.

* Add comment.
parent 595694cb
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-101"
  RPN_ONLY: True
  RETINANET_ON: True
  BACKBONE:
    CONV_BODY: "R-101-FPN-RETINANET"
    OUT_CHANNELS: 256
  RPN:
    USE_FPN: True
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
    BATCH_SIZE_PER_IMAGE: 256
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  RETINANET:
    SCALES_PER_OCTAVE: 3
    STRADDLE_THRESH: -1
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 4 gpus
  BASE_LR: 0.005
  WEIGHT_DECAY: 0.0001
  STEPS: (120000, 160000)
  MAX_ITER: 180000
  IMS_PER_BATCH: 8

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-101"
  RPN_ONLY: True
  RETINANET_ON: True
  BACKBONE:
    CONV_BODY: "R-101-FPN-RETINANET"
    OUT_CHANNELS: 256
  RPN:
    USE_FPN: True
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
    BATCH_SIZE_PER_IMAGE: 256
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  RETINANET:
    SCALES_PER_OCTAVE: 3
    STRADDLE_THRESH: -1
    USE_C5: False
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 4 gpus
  BASE_LR: 0.005
  WEIGHT_DECAY: 0.0001
  STEPS: (120000, 160000)
  MAX_ITER: 180000
  IMS_PER_BATCH: 8

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  RPN_ONLY: True
  RETINANET_ON: True
  BACKBONE:
    CONV_BODY: "R-50-FPN-RETINANET"
    OUT_CHANNELS: 256
  RPN:
    USE_FPN: True
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
    BATCH_SIZE_PER_IMAGE: 256
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  RETINANET:
    SCALES_PER_OCTAVE: 3
    STRADDLE_THRESH: -1
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 4 gpus
  BASE_LR: 0.005
  WEIGHT_DECAY: 0.0001
  STEPS: (120000, 160000)
  MAX_ITER: 180000
  IMS_PER_BATCH: 8

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  RPN_ONLY: True
  RETINANET_ON: True
  BACKBONE:
    CONV_BODY: "R-50-FPN-RETINANET"
    OUT_CHANNELS: 256
  RPN:
    USE_FPN: True
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
    BATCH_SIZE_PER_IMAGE: 256
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  RETINANET:
    SCALES_PER_OCTAVE: 3
    STRADDLE_THRESH: -1
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
DATASETS:
  TRAIN: ("coco_2014_minival",)
  TEST: ("coco_2014_minival",)
INPUT:
  MIN_SIZE_TRAIN: (600,)
  MAX_SIZE_TRAIN: 1000
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1000
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.005
  WEIGHT_DECAY: 0.0001
  STEPS: (3500,)
  MAX_ITER: 4000
  IMS_PER_BATCH: 4

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  RPN_ONLY: True
  RETINANET_ON: True
  BACKBONE:
    CONV_BODY: "R-50-FPN-RETINANET"
    OUT_CHANNELS: 256
  RPN:
    USE_FPN: True
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
    BATCH_SIZE_PER_IMAGE: 256
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  RETINANET:
    SCALES_PER_OCTAVE: 3
    STRADDLE_THRESH: -1
    USE_C5: False
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 4 gpus
  BASE_LR: 0.005
  WEIGHT_DECAY: 0.0001
  STEPS: (120000, 160000)
  MAX_ITER: 180000
  IMS_PER_BATCH: 8

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/FAIR/20171220/X-101-32x8d"
  RPN_ONLY: True
  RETINANET_ON: True
  BACKBONE:
    CONV_BODY: "R-101-FPN-RETINANET"
    OUT_CHANNELS: 256
  RPN:
    USE_FPN: True
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
    BATCH_SIZE_PER_IMAGE: 256
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
  RESNETS:
    STRIDE_IN_1X1: False
    NUM_GROUPS: 32
    WIDTH_PER_GROUP: 8
  RETINANET:
    SCALES_PER_OCTAVE: 3
    STRADDLE_THRESH: -1
    FG_IOU_THRESHOLD: 0.5
    BG_IOU_THRESHOLD: 0.4
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
INPUT:
  MIN_SIZE_TRAIN: (800,)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1333
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 4 gpus
  BASE_LR: 0.0025
  WEIGHT_DECAY: 0.0001
  STEPS: (240000, 320000)
  MAX_ITER: 360000
  IMS_PER_BATCH: 4
@@ -23,6 +23,7 @@ _C = CN()
_C.MODEL = CN()
_C.MODEL.RPN_ONLY = False
_C.MODEL.MASK_ON = False
_C.MODEL.RETINANET_ON = False
_C.MODEL.KEYPOINT_ON = False
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
@@ -273,6 +274,67 @@ _C.MODEL.RESNETS.RES5_DILATION = 1
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
# ---------------------------------------------------------------------------- #
# RetinaNet Options (Follow the Detectron version)
# ---------------------------------------------------------------------------- #
_C.MODEL.RETINANET = CN()
# Number of classes, counting the background (COCO: 80 foreground + 1 background)
_C.MODEL.RETINANET.NUM_CLASSES = 81
# Base anchor sizes, aspect ratios and strides, one entry per FPN level
_C.MODEL.RETINANET.ANCHOR_SIZES = (32, 64, 128, 256, 512)
_C.MODEL.RETINANET.ASPECT_RATIOS = (0.5, 1.0, 2.0)
_C.MODEL.RETINANET.ANCHOR_STRIDES = (8, 16, 32, 64, 128)
_C.MODEL.RETINANET.STRADDLE_THRESH = 0
# Anchor scales per octave
_C.MODEL.RETINANET.OCTAVE = 2.0
_C.MODEL.RETINANET.SCALES_PER_OCTAVE = 3
# Use C5 or P5 to generate P6
_C.MODEL.RETINANET.USE_C5 = True
# Convolutions to use in the cls and bbox tower
# NOTE: this doesn't include the last conv for logits
_C.MODEL.RETINANET.NUM_CONVS = 4
# Weight for bbox_regression loss
_C.MODEL.RETINANET.BBOX_REG_WEIGHT = 4.0
# Smooth L1 loss beta for bbox regression
_C.MODEL.RETINANET.BBOX_REG_BETA = 0.11
# During inference, the number of locations (per FPN level) to select by
# cls score before NMS is performed
_C.MODEL.RETINANET.PRE_NMS_TOP_N = 1000
# IoU overlap ratio for labeling an anchor as positive
# Anchors with >= iou overlap are labeled positive
_C.MODEL.RETINANET.FG_IOU_THRESHOLD = 0.5
# IoU overlap ratio for labeling an anchor as negative
# Anchors with < iou overlap are labeled negative
_C.MODEL.RETINANET.BG_IOU_THRESHOLD = 0.4
# Focal loss parameter: alpha
_C.MODEL.RETINANET.LOSS_ALPHA = 0.25
# Focal loss parameter: gamma
_C.MODEL.RETINANET.LOSS_GAMMA = 2.0
# Prior prob for the positives at the beginning of training. This is used to set
# the bias init for the logits layer
_C.MODEL.RETINANET.PRIOR_PROB = 0.01
# Inference cls score threshold, anchors with score > INFERENCE_TH are
# considered for inference
_C.MODEL.RETINANET.INFERENCE_TH = 0.05
# NMS threshold used in RetinaNet
_C.MODEL.RETINANET.NMS_TH = 0.4
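# For intuition (a sketch, not part of this commit): how PRIOR_PROB becomes
# the bias init of the classification logits, as done in RetinaNetHead below.
#   import math
#   bias_value = -math.log((1 - 0.01) / 0.01)  # ~ -4.595; sigmoid(-4.595) ~ 0.01
# Every anchor thus starts out predicting ~1% foreground probability per class,
# which keeps the focal loss from being swamped by easy negatives early on.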
# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
@@ -311,6 +373,8 @@ _C.TEST.EXPECTED_RESULTS_SIGMA_TOL = 4
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
_C.TEST.IMS_PER_BATCH = 8
# Number of detections per image
_C.TEST.DETECTIONS_PER_IMG = 100
# ---------------------------------------------------------------------------- #
...
#pragma once
#include "cpu/vision.h"

#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif

// Interface for Python
at::Tensor SigmoidFocalLoss_forward(
    const at::Tensor& logits,
    const at::Tensor& targets,
    const int num_classes,
    const float gamma,
    const float alpha) {
  if (logits.type().is_cuda()) {
#ifdef WITH_CUDA
    return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma, alpha);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  AT_ERROR("Not implemented on the CPU");
}

at::Tensor SigmoidFocalLoss_backward(
    const at::Tensor& logits,
    const at::Tensor& targets,
    const at::Tensor& d_losses,
    const int num_classes,
    const float gamma,
    const float alpha) {
  if (logits.type().is_cuda()) {
#ifdef WITH_CUDA
    return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses, num_classes, gamma, alpha);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  AT_ERROR("Not implemented on the CPU");
}
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
// This file is modified from https://github.com/pytorch/pytorch/blob/master/modules/detectron/sigmoid_focal_loss_op.cu
// Cheng-Yang Fu
// cyfu@cs.unc.edu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

#include <cfloat>

// TODO make it in a common file
#define CUDA_1D_KERNEL_LOOP(i, n)                            \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
       i += blockDim.x * gridDim.x)

template <typename T>
__global__ void SigmoidFocalLossForward(const int nthreads,
    const T* logits,
    const int* targets,
    const int num_classes,
    const float gamma,
    const float alpha,
    const int num,
    T* losses) {
  CUDA_1D_KERNEL_LOOP(i, nthreads) {
    int n = i / num_classes;
    int d = i % num_classes; // current class [0~79]
    int t = targets[n];      // target class [1~80]

    // Decide whether it is a positive or negative case.
    T c1 = (t == (d + 1));
    T c2 = (t >= 0 & t != (d + 1));

    T zn = (1.0 - alpha);
    T zp = (alpha);

    // p = 1. / (1. + expf(-x)) = sigmoid(x)
    T p = 1. / (1. + expf(-logits[i]));

    // (1 - p)**gamma * log(p)
    T term1 = powf((1. - p), gamma) * logf(max(p, FLT_MIN));

    // p**gamma * log(1 - p), in a numerically stable form
    T term2 = powf(p, gamma) *
        (-1. * logits[i] * (logits[i] >= 0) -
         logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0))));

    losses[i] = 0.0;
    losses[i] += -c1 * term1 * zp;
    losses[i] += -c2 * term2 * zn;
  } // CUDA_1D_KERNEL_LOOP
} // SigmoidFocalLossForward

template <typename T>
__global__ void SigmoidFocalLossBackward(const int nthreads,
    const T* logits,
    const int* targets,
    const T* d_losses,
    const int num_classes,
    const float gamma,
    const float alpha,
    const int num,
    T* d_logits) {
  CUDA_1D_KERNEL_LOOP(i, nthreads) {
    int n = i / num_classes;
    int d = i % num_classes; // current class [0~79]
    int t = targets[n];      // target class [1~80], 0 is background

    // Decide whether it is a positive or negative case.
    T c1 = (t == (d + 1));
    T c2 = (t >= 0 & t != (d + 1));

    T zn = (1.0 - alpha);
    T zp = (alpha);

    // p = 1. / (1. + expf(-x)) = sigmoid(x)
    T p = 1. / (1. + expf(-logits[i]));

    // (1 - p)**g * (1 - p - g * p * log(p))
    T term1 = powf((1. - p), gamma) *
        (1. - p - (p * gamma * logf(max(p, FLT_MIN))));

    // (p**g) * (g * (1 - p) * log(1 - p) - p)
    T term2 = powf(p, gamma) *
        ((-1. * logits[i] * (logits[i] >= 0) -
          logf(1. + expf(logits[i] - 2. * logits[i] * (logits[i] >= 0)))) *
         (1. - p) * gamma - p);

    d_logits[i] = 0.0;
    d_logits[i] += -c1 * term1 * zp;
    d_logits[i] += -c2 * term2 * zn;
    d_logits[i] = d_logits[i] * d_losses[i];
  } // CUDA_1D_KERNEL_LOOP
} // SigmoidFocalLossBackward

at::Tensor SigmoidFocalLoss_forward_cuda(
    const at::Tensor& logits,
    const at::Tensor& targets,
    const int num_classes,
    const float gamma,
    const float alpha) {
  AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
  AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
  AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");

  const int num_samples = logits.size(0);

  auto losses = at::empty({num_samples, logits.size(1)}, logits.options());
  auto losses_size = num_samples * logits.size(1);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(THCCeilDiv(losses_size, 512L), 4096L));
  dim3 block(512);

  if (losses.numel() == 0) {
    THCudaCheck(cudaGetLastError());
    return losses;
  }

  AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_forward", [&] {
    SigmoidFocalLossForward<scalar_t><<<grid, block, 0, stream>>>(
        losses_size,
        logits.contiguous().data<scalar_t>(),
        targets.contiguous().data<int>(),
        num_classes,
        gamma,
        alpha,
        num_samples,
        losses.data<scalar_t>());
  });
  THCudaCheck(cudaGetLastError());
  return losses;
}

at::Tensor SigmoidFocalLoss_backward_cuda(
    const at::Tensor& logits,
    const at::Tensor& targets,
    const at::Tensor& d_losses,
    const int num_classes,
    const float gamma,
    const float alpha) {
  AT_ASSERTM(logits.type().is_cuda(), "logits must be a CUDA tensor");
  AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
  AT_ASSERTM(d_losses.type().is_cuda(), "d_losses must be a CUDA tensor");
  AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");

  const int num_samples = logits.size(0);
  AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes");

  auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
  auto d_logits_size = num_samples * logits.size(1);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(THCCeilDiv(d_logits_size, 512L), 4096L));
  dim3 block(512);

  if (d_logits.numel() == 0) {
    THCudaCheck(cudaGetLastError());
    return d_logits;
  }

  AT_DISPATCH_FLOATING_TYPES(logits.type(), "SigmoidFocalLoss_backward", [&] {
    SigmoidFocalLossBackward<scalar_t><<<grid, block, 0, stream>>>(
        d_logits_size,
        logits.contiguous().data<scalar_t>(),
        targets.contiguous().data<int>(),
        d_losses.contiguous().data<scalar_t>(),
        num_classes,
        gamma,
        alpha,
        num_samples,
        d_logits.data<scalar_t>());
  });

  THCudaCheck(cudaGetLastError());
  return d_logits;
}
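The term2 expression in both kernels is the numerically stable form of log(1 - sigmoid(x)). A quick sanity check of the identity (a sketch, not part of the commit):

import torch

x = torch.tensor([-50., -5., 0., 5., 50.])
ge = (x >= 0).to(x.dtype)
stable = -x * ge - torch.log1p(torch.exp(x - 2. * x * ge))
naive = torch.log(1. - torch.sigmoid(x))
print(stable)  # finite everywhere; stable[-1] is about -50
print(naive)   # -inf at x = 50, because 1 - sigmoid(50) underflows to 0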
@@ -3,6 +3,21 @@
#include <torch/extension.h>

at::Tensor SigmoidFocalLoss_forward_cuda(
    const at::Tensor& logits,
    const at::Tensor& targets,
    const int num_classes,
    const float gamma,
    const float alpha);

at::Tensor SigmoidFocalLoss_backward_cuda(
    const at::Tensor& logits,
    const at::Tensor& targets,
    const at::Tensor& d_losses,
    const int num_classes,
    const float gamma,
    const float alpha);

at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
                                 const at::Tensor& rois,
                                 const float spatial_scale,
...
@@ -2,7 +2,7 @@
#include "nms.h"
#include "ROIAlign.h"
#include "ROIPool.h"
#include "SigmoidFocalLoss.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("nms", &nms, "non-maximum suppression");
@@ -10,4 +10,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
  m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
  m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
  m.def("sigmoid_focalloss_forward", &SigmoidFocalLoss_forward, "SigmoidFocalLoss_forward");
  m.def("sigmoid_focalloss_backward", &SigmoidFocalLoss_backward, "SigmoidFocalLoss_backward");
}
@@ -11,8 +11,10 @@ from .roi_align import roi_align
from .roi_pool import ROIPool
from .roi_pool import roi_pool
from .smooth_l1_loss import smooth_l1_loss
from .sigmoid_focal_loss import SigmoidFocalLoss

__all__ = ["nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool",
           "smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate",
           "FrozenBatchNorm2d", "SigmoidFocalLoss",
          ]
import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from maskrcnn_benchmark import _C


# TODO: Use JIT to replace CUDA implementation in the future.
class _SigmoidFocalLoss(Function):
    @staticmethod
    def forward(ctx, logits, targets, gamma, alpha):
        ctx.save_for_backward(logits, targets)
        num_classes = logits.shape[1]
        ctx.num_classes = num_classes
        ctx.gamma = gamma
        ctx.alpha = alpha

        losses = _C.sigmoid_focalloss_forward(
            logits, targets, num_classes, gamma, alpha
        )
        return losses

    @staticmethod
    @once_differentiable
    def backward(ctx, d_loss):
        logits, targets = ctx.saved_tensors
        num_classes = ctx.num_classes
        gamma = ctx.gamma
        alpha = ctx.alpha
        d_loss = d_loss.contiguous()
        d_logits = _C.sigmoid_focalloss_backward(
            logits, targets, d_loss, num_classes, gamma, alpha
        )
        return d_logits, None, None, None, None


sigmoid_focal_loss_cuda = _SigmoidFocalLoss.apply


def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
    num_classes = logits.shape[1]
    dtype = targets.dtype
    device = targets.device
    # Class ids are 1-based; 0 is reserved for the background.
    class_range = torch.arange(1, num_classes + 1, dtype=dtype, device=device).unsqueeze(0)

    t = targets.unsqueeze(1)
    p = torch.sigmoid(logits)
    term1 = (1 - p) ** gamma * torch.log(p)
    term2 = p ** gamma * torch.log(1 - p)
    return -(t == class_range).float() * term1 * alpha - ((t != class_range) * (t >= 0)).float() * term2 * (1 - alpha)


class SigmoidFocalLoss(nn.Module):
    def __init__(self, gamma, alpha):
        super(SigmoidFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        device = logits.device
        if logits.is_cuda:
            loss_func = sigmoid_focal_loss_cuda
        else:
            loss_func = sigmoid_focal_loss_cpu

        loss = loss_func(logits, targets, self.gamma, self.alpha)
        return loss.sum()

    def __repr__(self):
        tmpstr = self.__class__.__name__ + "("
        tmpstr += "gamma=" + str(self.gamma)
        tmpstr += ", alpha=" + str(self.alpha)
        tmpstr += ")"
        return tmpstr
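A usage sketch (assuming the C++ extension above is built and a GPU is available): targets hold 1-based class ids with 0 for background, matching the CUDA kernel's convention, and the module returns the loss summed over all entries.

import torch
from maskrcnn_benchmark.layers import SigmoidFocalLoss

loss_fn = SigmoidFocalLoss(gamma=2.0, alpha=0.25)
logits = torch.randn(8, 80, device="cuda")           # 8 anchors x 80 foreground classes
targets = torch.randint(0, 81, (8,), device="cuda")  # 0 = background, 1..80 = classes
loss = loss_fn(logits, targets.int())                # scalar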
@@ -42,6 +42,29 @@ def build_resnet_fpn_backbone(cfg):
    model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
    return model


@registry.BACKBONES.register("R-50-FPN-RETINANET")
@registry.BACKBONES.register("R-101-FPN-RETINANET")
def build_resnet_fpn_p3p7_backbone(cfg):
    body = resnet.ResNet(cfg)
    in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    out_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
    # P6 and P7 are computed from C5 (in_channels_stage2 * 8 channels) or
    # from P5 (the FPN output width), depending on USE_C5.
    in_channels_p6p7 = in_channels_stage2 * 8 if cfg.MODEL.RETINANET.USE_C5 \
        else out_channels
    fpn = fpn_module.FPN(
        in_channels_list=[
            0,
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ],
        out_channels=out_channels,
        conv_block=conv_with_kaiming_uniform(
            cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
        ),
        top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels),
    )
    model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
    return model


def build_backbone(cfg):
    assert cfg.MODEL.BACKBONE.CONV_BODY in registry.BACKBONES, \
...
@@ -29,6 +29,9 @@ class FPN(nn.Module):
        for idx, in_channels in enumerate(in_channels_list, 1):
            inner_block = "fpn_inner{}".format(idx)
            layer_block = "fpn_layer{}".format(idx)

            if in_channels == 0:
                continue

            inner_block_module = conv_block(in_channels, out_channels, 1)
            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
            self.add_module(inner_block, inner_block_module)
@@ -51,6 +54,8 @@ class FPN(nn.Module):
        for feature, inner_block, layer_block in zip(
            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
        ):
            if not inner_block:
                continue
            inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest")
            inner_lateral = getattr(self, inner_block)(feature)
            # TODO use size instead of scale to make it robust to different sizes
@@ -59,7 +64,10 @@ class FPN(nn.Module):
            last_inner = inner_lateral + inner_top_down
            results.insert(0, getattr(self, layer_block)(last_inner))

        if isinstance(self.top_blocks, LastLevelP6P7):
            last_results = self.top_blocks(x[-1], results[-1])
            results.extend(last_results)
        elif isinstance(self.top_blocks, LastLevelMaxPool):
            last_results = self.top_blocks(results[-1])
            results.extend(last_results)
@@ -69,3 +77,23 @@ class FPN(nn.Module):
class LastLevelMaxPool(nn.Module):
    def forward(self, x):
        return [F.max_pool2d(x, 1, 2, 0)]


class LastLevelP6P7(nn.Module):
    """
    This module is used in RetinaNet to generate extra layers, P6 and P7.
    """
    def __init__(self, in_channels, out_channels):
        super(LastLevelP6P7, self).__init__()
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for module in [self.p6, self.p7]:
            nn.init.kaiming_uniform_(module.weight, a=1)
            nn.init.constant_(module.bias, 0)
        self.use_P5 = in_channels == out_channels

    def forward(self, c5, p5):
        x = p5 if self.use_P5 else c5
        p6 = self.p6(x)
        p7 = self.p7(F.relu(p6))
        return [p6, p7]
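A small shape sketch (toy sizes, assumed for illustration) of how LastLevelP6P7 extends the pyramid; whether it reads C5 or P5 is inferred from the channel counts:

import torch

# in_channels != out_channels, so use_P5 is False and the block reads C5 (USE_C5: True)
p6p7 = LastLevelP6P7(in_channels=2048, out_channels=256)

c5 = torch.randn(1, 2048, 25, 34)
p5 = torch.randn(1, 256, 25, 34)
p6, p7 = p6p7(c5, p5)
print(p6.shape, p7.shape)  # stride-2 convs: (1, 256, 13, 17), (1, 256, 7, 9)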
@@ -410,6 +410,8 @@ _STAGE_SPECS = Registry({
    "R-101-C4": ResNet101StagesTo4,
    "R-101-C5": ResNet101StagesTo5,
    "R-50-FPN": ResNet50FPNStagesTo5,
    "R-50-FPN-RETINANET": ResNet50FPNStagesTo5,
    "R-101-FPN": ResNet101FPNStagesTo5,
    "R-101-FPN-RETINANET": ResNet101FPNStagesTo5,
    "R-152-FPN": ResNet152FPNStagesTo5,
})
@@ -59,6 +59,9 @@ def build_roi_heads(cfg):
    # individually create the heads, that will be combined together
    # afterwards
    roi_heads = []
    if cfg.MODEL.RETINANET_ON:
        return []

    if not cfg.MODEL.RPN_ONLY:
        roi_heads.append(("box", build_roi_box_head(cfg)))
    if cfg.MODEL.MASK_ON:
...
@@ -54,8 +54,13 @@ class AnchorGenerator(nn.Module):
        else:
            if len(anchor_strides) != len(sizes):
                raise RuntimeError("FPN should have #anchor_strides == #sizes")
            cell_anchors = [
                generate_anchors(
                    anchor_stride,
                    size if isinstance(size, (tuple, list)) else (size,),
                    aspect_ratios
                ).float()
                for anchor_stride, size in zip(anchor_strides, sizes)
            ]
        self.strides = anchor_strides
@@ -138,6 +143,28 @@ def make_anchor_generator(config):
    return anchor_generator


def make_anchor_generator_retinanet(config):
    anchor_sizes = config.MODEL.RETINANET.ANCHOR_SIZES
    aspect_ratios = config.MODEL.RETINANET.ASPECT_RATIOS
    anchor_strides = config.MODEL.RETINANET.ANCHOR_STRIDES
    straddle_thresh = config.MODEL.RETINANET.STRADDLE_THRESH
    octave = config.MODEL.RETINANET.OCTAVE
    scales_per_octave = config.MODEL.RETINANET.SCALES_PER_OCTAVE

    assert len(anchor_strides) == len(anchor_sizes), "Only support FPN now"
    new_anchor_sizes = []
    for size in anchor_sizes:
        per_layer_anchor_sizes = []
        for scale_per_octave in range(scales_per_octave):
            octave_scale = octave ** (scale_per_octave / float(scales_per_octave))
            per_layer_anchor_sizes.append(octave_scale * size)
        new_anchor_sizes.append(tuple(per_layer_anchor_sizes))

    anchor_generator = AnchorGenerator(
        tuple(new_anchor_sizes), aspect_ratios, anchor_strides, straddle_thresh
    )
    return anchor_generator
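A worked sketch with the default OCTAVE = 2.0 and SCALES_PER_OCTAVE = 3: each base size expands into three per-level sizes spaced a third of an octave apart.

octave, scales_per_octave = 2.0, 3
anchor_sizes = (32, 64, 128, 256, 512)

new_anchor_sizes = [
    tuple(size * octave ** (i / float(scales_per_octave))
          for i in range(scales_per_octave))
    for size in anchor_sizes
]
print(new_anchor_sizes[0])  # (32.0, 40.31..., 50.79...)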
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -8,7 +8,7 @@ from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes

from ..utils import cat
from .utils import permute_and_flatten


class RPNPostProcessor(torch.nn.Module):
    """
@@ -82,10 +82,10 @@ class RPNPostProcessor(torch.nn.Module):
        N, A, H, W = objectness.shape

        # put in the same format as anchors
        objectness = permute_and_flatten(objectness, N, A, 1, H, W).view(N, -1)
        objectness = objectness.sigmoid()

        box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)

        num_anchors = A * H * W
...
@@ -7,6 +7,8 @@ file
import torch
from torch.nn import functional as F

from .utils import concat_box_prediction_layers

from ..balanced_positive_negative_sampler import BalancedPositiveNegativeSampler
from ..utils import cat
@@ -21,7 +23,8 @@ class RPNLossComputation(object):
    This class computes the RPN loss.
    """

    def __init__(self, proposal_matcher, fg_bg_sampler, box_coder,
                 generate_labels_func):
        """
        Arguments:
            proposal_matcher (Matcher)
@@ -32,13 +35,16 @@ class RPNLossComputation(object):
        self.proposal_matcher = proposal_matcher
        self.fg_bg_sampler = fg_bg_sampler
        self.box_coder = box_coder
        self.copied_fields = []
        self.generate_labels_func = generate_labels_func
        self.discard_cases = ['not_visibility', 'between_thresholds']

    def match_targets_to_anchors(self, anchor, target, copied_fields=[]):
        match_quality_matrix = boxlist_iou(target, anchor)
        matched_idxs = self.proposal_matcher(match_quality_matrix)
        # RPN doesn't need any fields from target
        # for creating the labels, so clear them all
        target = target.copy_with_fields(copied_fields)
        # get the targets corresponding GT for each anchor
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
@@ -52,18 +58,25 @@ class RPNLossComputation(object):
        regression_targets = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            matched_targets = self.match_targets_to_anchors(
                anchors_per_image, targets_per_image, self.copied_fields
            )

            matched_idxs = matched_targets.get_field("matched_idxs")
            labels_per_image = self.generate_labels_func(matched_targets)
            labels_per_image = labels_per_image.to(dtype=torch.float32)

            # Background (negative examples)
            bg_indices = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            labels_per_image[bg_indices] = 0

            # discard anchors that go out of the boundaries of the image
            if "not_visibility" in self.discard_cases:
                labels_per_image[~anchors_per_image.get_field("visibility")] = -1

            # discard indices that are between thresholds
            if "between_thresholds" in self.discard_cases:
                inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1

            # compute regression targets
            regression_targets_per_image = self.box_coder.encode(
@@ -75,6 +88,7 @@ class RPNLossComputation(object):
        return labels, regression_targets

    def __call__(self, anchors, objectness, box_regression, targets):
        """
        Arguments:
@@ -95,29 +109,10 @@ class RPNLossComputation(object):
        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness, box_regression = \
            concat_box_prediction_layers(objectness, box_regression)

        objectness = objectness.squeeze()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)
@@ -135,6 +130,12 @@ class RPNLossComputation(object):
    return objectness_loss, box_loss


# This function should be overwritten in RetinaNet
def generate_rpn_labels(matched_targets):
    matched_idxs = matched_targets.get_field("matched_idxs")
    labels_per_image = matched_idxs >= 0
    return labels_per_image


def make_rpn_loss_evaluator(cfg, box_coder):
    matcher = Matcher(
@@ -147,5 +148,10 @@ def make_rpn_loss_evaluator(cfg, box_coder):
        cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE, cfg.MODEL.RPN.POSITIVE_FRACTION
    )

    loss_evaluator = RPNLossComputation(
        matcher,
        fg_bg_sampler,
        box_coder,
        generate_rpn_labels
    )
    return loss_evaluator
import torch

from ..inference import RPNPostProcessor
from ..utils import permute_and_flatten

from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes


class RetinaNetPostProcessor(RPNPostProcessor):
    """
    Performs post-processing on the outputs of the RetinaNet boxes.
    This is only used in the testing.
    """
    def __init__(
        self,
        pre_nms_thresh,
        pre_nms_top_n,
        nms_thresh,
        fpn_post_nms_top_n,
        min_size,
        num_classes,
        box_coder=None,
    ):
        """
        Arguments:
            pre_nms_thresh (float)
            pre_nms_top_n (int)
            nms_thresh (float)
            fpn_post_nms_top_n (int)
            min_size (int)
            num_classes (int)
            box_coder (BoxCoder)
        """
        super(RetinaNetPostProcessor, self).__init__(
            pre_nms_thresh, 0, nms_thresh, min_size
        )
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes

        if box_coder is None:
            box_coder = BoxCoder(weights=(10., 10., 5., 5.))
        self.box_coder = box_coder

    def add_gt_proposals(self, proposals, targets):
        """
        This function is not used in RetinaNet
        """
        pass

    def forward_for_single_feature_map(self, anchors, box_cls, box_regression):
        """
        Arguments:
            anchors: list[BoxList]
            box_cls: tensor of size N, A * C, H, W
            box_regression: tensor of size N, A * 4, H, W
        """
        device = box_cls.device
        N, _, H, W = box_cls.shape
        A = box_regression.size(1) // 4
        C = box_cls.size(1) // A

        # put in the same format as anchors
        box_cls = permute_and_flatten(box_cls, N, A, C, H, W)
        box_cls = box_cls.sigmoid()

        box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)
        box_regression = box_regression.reshape(N, -1, 4)

        num_anchors = A * H * W

        candidate_inds = box_cls > self.pre_nms_thresh

        pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        results = []
        for per_box_cls, per_box_regression, per_pre_nms_top_n, \
                per_candidate_inds, per_anchors in zip(
                    box_cls, box_regression, pre_nms_top_n,
                    candidate_inds, anchors):

            # Sort and select TopN
            # TODO most of this can be made out of the loop for
            # all images.
            # TODO: Yang: Not easy to do. Because the numbers of detections are
            # different in each image. Therefore, this part needs to be done
            # per image.
            per_box_cls = per_box_cls[per_candidate_inds]

            per_box_cls, top_k_indices = \
                per_box_cls.topk(per_pre_nms_top_n, sorted=False)

            per_candidate_nonzeros = \
                per_candidate_inds.nonzero()[top_k_indices, :]

            per_box_loc = per_candidate_nonzeros[:, 0]
            per_class = per_candidate_nonzeros[:, 1]
            per_class += 1

            detections = self.box_coder.decode(
                per_box_regression[per_box_loc, :].view(-1, 4),
                per_anchors.bbox[per_box_loc, :].view(-1, 4)
            )

            boxlist = BoxList(detections, per_anchors.size, mode="xyxy")
            boxlist.add_field("labels", per_class)
            boxlist.add_field("scores", per_box_cls)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results

    # TODO very similar to filter_results from PostProcessor
    # but filter_results is per image
    # TODO Yang: solve this issue in the future. No good solution
    # right now.
    def select_over_all_levels(self, boxlists):
        num_images = len(boxlists)
        results = []
        for i in range(num_images):
            scores = boxlists[i].get_field("scores")
            labels = boxlists[i].get_field("labels")
            boxes = boxlists[i].bbox
            boxlist = boxlists[i]
            result = []
            # skip the background
            for j in range(1, self.num_classes):
                inds = (labels == j).nonzero().view(-1)

                scores_j = scores[inds]
                boxes_j = boxes[inds, :].view(-1, 4)
                boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
                boxlist_for_class.add_field("scores", scores_j)
                boxlist_for_class = boxlist_nms(
                    boxlist_for_class, self.nms_thresh,
                    score_field="scores"
                )
                num_labels = len(boxlist_for_class)
                boxlist_for_class.add_field(
                    "labels", torch.full((num_labels,), j,
                                         dtype=torch.int64,
                                         device=scores.device)
                )
                result.append(boxlist_for_class)

            result = cat_boxlist(result)
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes**
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                image_thresh, _ = torch.kthvalue(
                    cls_scores.cpu(),
                    number_of_detections - self.fpn_post_nms_top_n + 1
                )
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results


def make_retinanet_postprocessor(config, rpn_box_coder, is_train):
    pre_nms_thresh = config.MODEL.RETINANET.INFERENCE_TH
    pre_nms_top_n = config.MODEL.RETINANET.PRE_NMS_TOP_N
    nms_thresh = config.MODEL.RETINANET.NMS_TH
    fpn_post_nms_top_n = config.TEST.DETECTIONS_PER_IMG
    min_size = 0

    box_selector = RetinaNetPostProcessor(
        pre_nms_thresh=pre_nms_thresh,
        pre_nms_top_n=pre_nms_top_n,
        nms_thresh=nms_thresh,
        fpn_post_nms_top_n=fpn_post_nms_top_n,
        min_size=min_size,
        num_classes=config.MODEL.RETINANET.NUM_CLASSES,
        box_coder=rpn_box_coder,
    )
    return box_selector
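A wiring sketch (assuming a loaded cfg): the selector consumes the per-level anchors and head outputs and returns one BoxList per image, capped at cfg.TEST.DETECTIONS_PER_IMG boxes.

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.modeling.box_coder import BoxCoder

box_coder = BoxCoder(weights=(10., 10., 5., 5.))
box_selector_test = make_retinanet_postprocessor(cfg, box_coder, is_train=False)
# detections = box_selector_test(anchors, box_cls, box_regression)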
"""
This file contains specific functions for computing losses on the RetinaNet
file
"""
import torch
from torch.nn import functional as F
from ..utils import concat_box_prediction_layers
from maskrcnn_benchmark.layers import smooth_l1_loss
from maskrcnn_benchmark.layers import SigmoidFocalLoss
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.modeling.rpn.loss import RPNLossComputation
class RetinaNetLossComputation(RPNLossComputation):
"""
This class computes the RetinaNet loss.
"""
def __init__(self, proposal_matcher, box_coder,
generate_labels_func,
sigmoid_focal_loss,
bbox_reg_beta=0.11,
regress_norm=1.0):
"""
Arguments:
proposal_matcher (Matcher)
box_coder (BoxCoder)
"""
self.proposal_matcher = proposal_matcher
self.box_coder = box_coder
self.box_cls_loss_func = sigmoid_focal_loss
self.bbox_reg_beta = bbox_reg_beta
self.copied_fields = ['labels']
self.generate_labels_func = generate_labels_func
self.discard_cases = ['between_thresholds']
self.regress_norm = regress_norm
def __call__(self, anchors, box_cls, box_regression, targets):
"""
Arguments:
anchors (list[BoxList])
box_cls (list[Tensor])
box_regression (list[Tensor])
targets (list[BoxList])
Returns:
retinanet_cls_loss (Tensor)
retinanet_regression_loss (Tensor
"""
anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
labels, regression_targets = self.prepare_targets(anchors, targets)
N = len(labels)
box_cls, box_regression = \
concat_box_prediction_layers(box_cls, box_regression)
labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)
pos_inds = torch.nonzero(labels > 0).squeeze(1)
retinanet_regression_loss = smooth_l1_loss(
box_regression[pos_inds],
regression_targets[pos_inds],
beta=self.bbox_reg_beta,
size_average=False,
) / (max(1, pos_inds.numel() * self.regress_norm))
labels = labels.int()
retinanet_cls_loss = self.box_cls_loss_func(
box_cls,
labels
) / (pos_inds.numel() + N)
return retinanet_cls_loss, retinanet_regression_loss
def generate_retinanet_labels(matched_targets):
labels_per_image = matched_targets.get_field("labels")
return labels_per_image
def make_retinanet_loss_evaluator(cfg, box_coder):
matcher = Matcher(
cfg.MODEL.RETINANET.FG_IOU_THRESHOLD,
cfg.MODEL.RETINANET.BG_IOU_THRESHOLD,
allow_low_quality_matches=True,
)
sigmoid_focal_loss = SigmoidFocalLoss(
cfg.MODEL.RETINANET.LOSS_GAMMA,
cfg.MODEL.RETINANET.LOSS_ALPHA
)
loss_evaluator = RetinaNetLossComputation(
matcher,
box_coder,
generate_retinanet_labels,
sigmoid_focal_loss,
bbox_reg_beta = cfg.MODEL.RETINANET.BBOX_REG_BETA,
regress_norm = cfg.MODEL.RETINANET.BBOX_REG_WEIGHT,
)
return loss_evaluator
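Illustrative arithmetic for the two normalizers (values made up): the classification loss is divided by the number of foreground anchors plus the number of images, and the regression loss by the foreground count scaled by BBOX_REG_WEIGHT, mirroring the Detectron port.

num_images = 2         # N above
num_pos = 150          # pos_inds.numel()
bbox_reg_weight = 4.0  # cfg.MODEL.RETINANET.BBOX_REG_WEIGHT, passed as regress_norm

cls_normalizer = num_pos + num_images               # 152
reg_normalizer = max(1, num_pos * bbox_reg_weight)  # 600.0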
import math
import torch
import torch.nn.functional as F
from torch import nn

from .inference import make_retinanet_postprocessor
from .loss import make_retinanet_loss_evaluator
from ..anchor_generator import make_anchor_generator_retinanet

from maskrcnn_benchmark.modeling.box_coder import BoxCoder


class RetinaNetHead(torch.nn.Module):
    """
    Adds a RetinaNet head with classification and regression heads
    """

    def __init__(self, cfg):
        """
        Arguments:
            in_channels (int): number of channels of the input feature
            num_anchors (int): number of anchors to be predicted
        """
        super(RetinaNetHead, self).__init__()
        # TODO: Implement the sigmoid version first.
        num_classes = cfg.MODEL.RETINANET.NUM_CLASSES - 1
        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        num_anchors = len(cfg.MODEL.RETINANET.ASPECT_RATIOS) \
            * cfg.MODEL.RETINANET.SCALES_PER_OCTAVE

        cls_tower = []
        bbox_tower = []
        for i in range(cfg.MODEL.RETINANET.NUM_CONVS):
            cls_tower.append(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            )
            cls_tower.append(nn.ReLU())
            bbox_tower.append(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            )
            bbox_tower.append(nn.ReLU())

        self.add_module('cls_tower', nn.Sequential(*cls_tower))
        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
        self.cls_logits = nn.Conv2d(
            in_channels, num_anchors * num_classes, kernel_size=3, stride=1,
            padding=1
        )
        self.bbox_pred = nn.Conv2d(
            in_channels, num_anchors * 4, kernel_size=3, stride=1,
            padding=1
        )

        # Initialization
        for modules in [self.cls_tower, self.bbox_tower, self.cls_logits,
                        self.bbox_pred]:
            for l in modules.modules():
                if isinstance(l, nn.Conv2d):
                    torch.nn.init.normal_(l.weight, std=0.01)
                    torch.nn.init.constant_(l.bias, 0)

        # retinanet_bias_init
        prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        torch.nn.init.constant_(self.cls_logits.bias, bias_value)

    def forward(self, x):
        logits = []
        bbox_reg = []
        for feature in x:
            logits.append(self.cls_logits(self.cls_tower(feature)))
            bbox_reg.append(self.bbox_pred(self.bbox_tower(feature)))
        return logits, bbox_reg


class RetinaNetModule(torch.nn.Module):
    """
    Module for RetinaNet computation. Takes feature maps from the backbone and
    computes RetinaNet outputs and losses. Only tested on FPN for now.
    """

    def __init__(self, cfg):
        super(RetinaNetModule, self).__init__()

        self.cfg = cfg.clone()

        anchor_generator = make_anchor_generator_retinanet(cfg)
        head = RetinaNetHead(cfg)
        box_coder = BoxCoder(weights=(10., 10., 5., 5.))

        box_selector_test = make_retinanet_postprocessor(cfg, box_coder, is_train=False)

        loss_evaluator = make_retinanet_loss_evaluator(cfg, box_coder)

        self.anchor_generator = anchor_generator
        self.head = head
        self.box_selector_test = box_selector_test
        self.loss_evaluator = loss_evaluator

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                correspond to different feature levels
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
                image.
            losses (dict[Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        box_cls, box_regression = self.head(features)
        anchors = self.anchor_generator(images, features)

        if self.training:
            return self._forward_train(anchors, box_cls, box_regression, targets)
        else:
            return self._forward_test(anchors, box_cls, box_regression)

    def _forward_train(self, anchors, box_cls, box_regression, targets):
        loss_box_cls, loss_box_reg = self.loss_evaluator(
            anchors, box_cls, box_regression, targets
        )
        losses = {
            "loss_retina_cls": loss_box_cls,
            "loss_retina_reg": loss_box_reg,
        }
        return anchors, losses

    def _forward_test(self, anchors, box_cls, box_regression):
        boxes = self.box_selector_test(anchors, box_cls, box_regression)
        return boxes, {}


def build_retinanet(cfg):
    return RetinaNetModule(cfg)
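A usage sketch (cfg is the project config object; all other names come from this diff): the module drops in where the RPN would, as build_rpn below shows.

from maskrcnn_benchmark.config import cfg

retinanet = build_retinanet(cfg.clone())
# training:  returns (anchors, {"loss_retina_cls": ..., "loss_retina_reg": ...})
# inference: returns (list[BoxList] of detections, {})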
@@ -5,11 +5,11 @@ from torch import nn
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.modeling.rpn.retinanet.retinanet import build_retinanet
from .loss import make_rpn_loss_evaluator
from .anchor_generator import make_anchor_generator
from .inference import make_rpn_postprocessor


@registry.RPN_HEADS.register("SingleConvRPNHead")
class RPNHead(nn.Module):
    """
@@ -142,4 +142,7 @@ def build_rpn(cfg):
    """
    This gives the gist of it. Not super important because it doesn't change as much
    """
    if cfg.MODEL.RETINANET_ON:
        return build_retinanet(cfg)

    return RPNModule(cfg)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Utility functions for manipulating the prediction layers
"""

from ..utils import cat

import torch


def permute_and_flatten(layer, N, A, C, H, W):
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer


def concat_box_prediction_layers(box_cls, box_regression):
    box_cls_flattened = []
    box_regression_flattened = []
    # for each feature level, permute the outputs to make them be in the
    # same format as the labels. Note that the labels are computed for
    # all feature levels concatenated, so we keep the same representation
    # for the objectness and the box_regression
    for box_cls_per_level, box_regression_per_level in zip(
        box_cls, box_regression
    ):
        N, AxC, H, W = box_cls_per_level.shape
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        C = AxC // A
        box_cls_per_level = permute_and_flatten(
            box_cls_per_level, N, A, C, H, W
        )
        box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(
            box_regression_per_level, N, A, 4, H, W
        )
        box_regression_flattened.append(box_regression_per_level)
    # concatenate on the first dimension (representing the feature levels), to
    # take into account the way the labels were generated (with all feature maps
    # being concatenated as well)
    box_cls = cat(box_cls_flattened, dim=1).reshape(-1, C)
    box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)
    return box_cls, box_regression
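A shape walk-through with toy sizes (assumed values) for the two helpers above:

import torch

N, A, C = 2, 9, 80  # images, anchors per location, foreground classes
box_cls = [torch.randn(N, A * C, 50, 68), torch.randn(N, A * C, 25, 34)]
box_regression = [torch.randn(N, A * 4, 50, 68), torch.randn(N, A * 4, 25, 34)]

cls_flat, reg_flat = concat_box_prediction_layers(box_cls, box_regression)
print(cls_flat.shape)  # torch.Size([76500, 80]); 76500 == N * A * (50*68 + 25*34)
print(reg_flat.shape)  # torch.Size([76500, 4])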
@@ -157,12 +157,15 @@ C2_FORMAT_LOADER = Registry()
@C2_FORMAT_LOADER.register("R-101-C4")
@C2_FORMAT_LOADER.register("R-101-C5")
@C2_FORMAT_LOADER.register("R-50-FPN")
@C2_FORMAT_LOADER.register("R-50-FPN-RETINANET")
@C2_FORMAT_LOADER.register("R-101-FPN")
@C2_FORMAT_LOADER.register("R-101-FPN-RETINANET")
@C2_FORMAT_LOADER.register("R-152-FPN")
def load_resnet_c2_format(cfg, f):
    state_dict = _load_c2_pickled_weights(f)
    conv_body = cfg.MODEL.BACKBONE.CONV_BODY
    arch = conv_body.replace("-C4", "").replace("-C5", "").replace("-FPN", "")
    arch = arch.replace("-RETINANET", "")
    stages = _C2_STAGE_NAMES[arch]
    state_dict = _rename_weights_for_resnet(state_dict, stages)
    return dict(model=state_dict)
...
@@ -84,7 +84,7 @@ def main():
        data_loader_val,
        dataset_name=dataset_name,
        iou_types=iou_types,
        box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
        device=cfg.MODEL.DEVICE,
        expected_results=cfg.TEST.EXPECTED_RESULTS,
        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
...
@@ -99,7 +99,7 @@ def test(cfg, model, distributed):
        data_loader_val,
        dataset_name=dataset_name,
        iou_types=iou_types,
        box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
        device=cfg.MODEL.DEVICE,
        expected_results=cfg.TEST.EXPECTED_RESULTS,
        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
...