Commit b7159125 authored by Ilija Radosavovic's avatar Ilija Radosavovic Committed by Facebook Github Bot

Re-use test_engine roidb range testing logic for RetinaNet

Reviewed By: rbgirshick

Differential Revision: D6829350

fbshipit-source-id: a5fbaa8900fab28602d7a993d63206ca67b24975
parent 03544cd8
......@@ -38,6 +38,8 @@ from datasets.json_dataset import JsonDataset
from modeling import model_builder
from utils.io import save_object
from utils.timer import Timer
import core.test_retinanet as test_retinanet
import utils.c2 as c2_utils
import utils.env as envu
import utils.net as net_utils
......@@ -119,9 +121,11 @@ def test_net(ind_range=None):
'Use rpn_generate to generate proposals from RPN-only models'
assert cfg.TEST.DATASET != '', \
'TEST.DATASET must be set to the dataset name to test'
# Create anchors for RetinaNet
if cfg.RETINANET.RETINANET_ON:
import core.test_retinanet as test_retinanet
return test_retinanet.test_retinanet(ind_range)
anchors = test_retinanet.create_cell_anchors()
else:
anchors = None
output_dir = get_output_dir(training=False)
roidb, dataset, start_ind, end_ind, total_num_images = get_roidb_and_dataset(
ind_range
......@@ -132,11 +136,7 @@ def test_net(ind_range=None):
all_boxes, all_segms, all_keyps = empty_results(num_classes, num_images)
timers = defaultdict(Timer)
for i, entry in enumerate(roidb):
if cfg.MODEL.FASTER_RCNN:
# Faster R-CNN type models generate proposals on-the-fly with an
# in-network RPN
box_proposals = None
else:
if cfg.TEST.PRECOMPUTED_PROPOSALS:
# The roidb may contain ground-truth rois (for example, if the roidb
# comes from the training or val split). We only want to evaluate
# detection on the *non*-ground-truth rois. We select only the rois
......@@ -145,12 +145,20 @@ def test_net(ind_range=None):
box_proposals = entry['boxes'][entry['gt_classes'] == 0]
if len(box_proposals) == 0:
continue
else:
# Faster R-CNN type models generate proposals on-the-fly with an
# in-network RPN; 1-stage models don't require proposals.
box_proposals = None
im = cv2.imread(entry['image'])
with c2_utils.NamedCudaScope(0):
cls_boxes_i, cls_segms_i, cls_keyps_i = im_detect_all(
model, im, box_proposals, timers
)
if cfg.RETINANET.RETINANET_ON:
cls_boxes_i, cls_segms_i, cls_keyps_i = \
test_retinanet.im_detections(model, im, anchors, timers)
else:
cls_boxes_i, cls_segms_i, cls_keyps_i = im_detect_all(
model, im, box_proposals, timers
)
extend_results(i, all_boxes, cls_boxes_i)
if cls_segms_i is not None:
......@@ -238,13 +246,13 @@ def get_roidb_and_dataset(ind_range):
restrict it to a range of indices if ind_range is a pair of integers.
"""
dataset = JsonDataset(cfg.TEST.DATASET)
if cfg.MODEL.FASTER_RCNN:
roidb = dataset.get_roidb()
else:
if cfg.TEST.PRECOMPUTED_PROPOSALS:
roidb = dataset.get_roidb(
proposal_file=cfg.TEST.PROPOSAL_FILE,
proposal_limit=cfg.TEST.PROPOSAL_LIMIT
)
else:
roidb = dataset.get_roidb()
if ind_range is not None:
total_num_images = len(roidb)
......
......@@ -21,26 +21,17 @@ from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
import os
import yaml
import logging
from collections import defaultdict
from caffe2.python import core, workspace
from core.config import cfg, get_output_dir
from core.config import cfg
from core.rpn_generator import _get_image_blob
from datasets.json_dataset import JsonDataset
from modeling import model_builder
from modeling.generate_anchors import generate_anchors
from utils.io import save_object
from utils.timer import Timer
import core.test_engine as test_engine
import utils.boxes as box_utils
import utils.c2 as c2_utils
import utils.net as nu
logger = logging.getLogger(__name__)
......@@ -74,8 +65,12 @@ def create_cell_anchors():
return anchors
def im_detections(model, im, anchors):
def im_detections(model, im, anchors, timers=None):
"""Generate RetinaNet detections on a single image."""
if timers is None:
timers = defaultdict(Timer)
timers['im_detect_bbox'].tic()
k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS)
inputs = {}
......@@ -160,8 +155,10 @@ def im_detections(model, im, anchors):
inds = np.where(classes == cls - 1)[0]
if len(inds) > 0:
boxes_all[cls].extend(box_scores[inds, :])
timers['im_detect_bbox'].toc()
# Combine predictions across all levels and retain the top scoring by class
timers['misc_bbox'].tic()
detections = []
for cls, boxes in boxes_all.items():
cls_dets = np.vstack(boxes).astype(dtype=np.float32)
......@@ -188,70 +185,6 @@ def im_detections(model, im, anchors):
for c in range(1, num_classes):
inds = np.where(detections[:, 5] == c)[0]
cls_boxes[c] = detections[inds, :5]
timers['misc_bbox'].toc()
return cls_boxes
def im_list_detections(model, roidb):
"""Generate RetinaNet detections on all images in an imdb."""
_t = Timer()
num_images = len(roidb)
num_classes = cfg.MODEL.NUM_CLASSES
all_boxes, all_segms, all_keyps = test_engine.empty_results(
num_classes, num_images
)
# create anchors for each level
anchors = create_cell_anchors()
for i, entry in enumerate(roidb):
im = cv2.imread(entry['image'])
with c2_utils.NamedCudaScope(0):
_t.tic()
cls_boxes_i = im_detections(model, im, anchors)
_t.toc()
test_engine.extend_results(i, all_boxes, cls_boxes_i)
logger.info(
'im_detections: {:d}/{:d} {:.3f}s'.format(
i + 1, num_images, _t.average_time))
return all_boxes, all_segms, all_keyps
def test_retinanet(ind_range=None):
"""
Test RetinaNet model either on the entire dataset or the subset of dataset
specified by the index range
"""
assert cfg.RETINANET.RETINANET_ON, \
'RETINANET_ON must be set for testing RetinaNet model'
output_dir = get_output_dir(training=False)
dataset = JsonDataset(cfg.TEST.DATASET)
roidb = dataset.get_roidb()
if ind_range is not None:
start, end = ind_range
roidb = roidb[start:end]
# Create and load the model
model = model_builder.create(cfg.MODEL.TYPE, train=False)
if cfg.TEST.WEIGHTS:
nu.initialize_from_weights_file(
model, cfg.TEST.WEIGHTS, broadcast=False
)
model_builder.add_inference_inputs(model)
workspace.CreateNet(model.net)
# Compute the detections
all_boxes, all_segms, all_keyps = im_list_detections(model, roidb)
# Save the detections
cfg_yaml = yaml.dump(cfg)
if ind_range is not None:
det_name = 'detection_range_%s_%s.pkl' % tuple(ind_range)
else:
det_name = 'detections.pkl'
det_file = os.path.join(output_dir, det_name)
save_object(
dict(
all_boxes=all_boxes,
all_segms=all_segms,
all_keyps=all_keyps,
cfg=cfg_yaml
), det_file
)
logger.info('Wrote detections to: {}'.format(os.path.abspath(det_file)))
return all_boxes, all_segms, all_keyps
return cls_boxes, None, None
......@@ -339,6 +339,8 @@ def build_generic_retinanet_model(
"""Builds the model on a single GPU. Can be called in a loop over GPUs
with name and device scoping to create a data parallel model."""
blobs, dim, spatial_scales = add_conv_body_func(model)
if not model.train:
model.conv_body_net = model.net.Clone('conv_body_net')
retinanet_heads.add_fpn_retinanet_outputs(
model, blobs, dim, spatial_scales
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment