Commit dd6c6615 authored by Ilija Radosavovic's avatar Ilija Radosavovic Committed by Facebook Github Bot

Support running the demo with RetinaNet

Reviewed By: rbgirshick

Differential Revision: D6829499

fbshipit-source-id: 66fe40ef1b2c8c560ce769faa27d107c3c7be841
parent b7159125
......@@ -39,6 +39,7 @@ import pycocotools.mask as mask_util
from core.config import cfg
from utils.timer import Timer
import core.test_retinanet as test_retinanet
import modeling.FPN as fpn
import utils.blob as blob_utils
import utils.boxes as box_utils
......@@ -52,6 +53,11 @@ def im_detect_all(model, im, box_proposals, timers=None):
if timers is None:
timers = defaultdict(Timer)
# Handle RetinaNet testing separately for now
if cfg.RETINANET.RETINANET_ON:
cls_boxes = test_retinanet.im_detect_bbox(model, im, timers)
return cls_boxes, None, None
timers['im_detect_bbox'].tic()
if cfg.TEST.BBOX_AUG.ENABLED:
scores, boxes, im_scales = im_detect_bbox_aug(model, im, box_proposals)
......
......@@ -38,8 +38,6 @@ from datasets.json_dataset import JsonDataset
from modeling import model_builder
from utils.io import save_object
from utils.timer import Timer
import core.test_retinanet as test_retinanet
import utils.c2 as c2_utils
import utils.env as envu
import utils.net as net_utils
......@@ -121,11 +119,6 @@ def test_net(ind_range=None):
'Use rpn_generate to generate proposals from RPN-only models'
assert cfg.TEST.DATASET != '', \
'TEST.DATASET must be set to the dataset name to test'
# Create anchors for RetinaNet
if cfg.RETINANET.RETINANET_ON:
anchors = test_retinanet.create_cell_anchors()
else:
anchors = None
output_dir = get_output_dir(training=False)
roidb, dataset, start_ind, end_ind, total_num_images = get_roidb_and_dataset(
ind_range
......@@ -152,13 +145,9 @@ def test_net(ind_range=None):
im = cv2.imread(entry['image'])
with c2_utils.NamedCudaScope(0):
if cfg.RETINANET.RETINANET_ON:
cls_boxes_i, cls_segms_i, cls_keyps_i = \
test_retinanet.im_detections(model, im, anchors, timers)
else:
cls_boxes_i, cls_segms_i, cls_keyps_i = im_detect_all(
model, im, box_proposals, timers
)
cls_boxes_i, cls_segms_i, cls_keyps_i = im_detect_all(
model, im, box_proposals, timers
)
extend_results(i, all_boxes, cls_boxes_i)
if cls_segms_i is not None:
......
......@@ -36,7 +36,7 @@ import utils.boxes as box_utils
logger = logging.getLogger(__name__)
def create_cell_anchors():
def _create_cell_anchors():
"""
Generate all types of anchors for all fpn levels/scales/aspect ratios.
This function is called only once at the beginning of inference.
......@@ -65,11 +65,13 @@ def create_cell_anchors():
return anchors
def im_detections(model, im, anchors, timers=None):
def im_detect_bbox(model, im, timers=None):
"""Generate RetinaNet detections on a single image."""
if timers is None:
timers = defaultdict(Timer)
# Although anchors are input independent and could be precomputed,
# recomputing them per image only brings a small overhead
anchors = _create_cell_anchors()
timers['im_detect_bbox'].tic()
k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS)
......@@ -187,4 +189,4 @@ def im_detections(model, im, anchors, timers=None):
cls_boxes[c] = detections[inds, :5]
timers['misc_bbox'].toc()
return cls_boxes, None, None
return cls_boxes
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment