Commit b7159125 authored by Ilija Radosavovic's avatar Ilija Radosavovic Committed by Facebook Github Bot

Re-use test_engine roidb range testing logic for RetinaNet

Reviewed By: rbgirshick

Differential Revision: D6829350

fbshipit-source-id: a5fbaa8900fab28602d7a993d63206ca67b24975
parent 03544cd8
...@@ -38,6 +38,8 @@ from datasets.json_dataset import JsonDataset ...@@ -38,6 +38,8 @@ from datasets.json_dataset import JsonDataset
from modeling import model_builder from modeling import model_builder
from utils.io import save_object from utils.io import save_object
from utils.timer import Timer from utils.timer import Timer
import core.test_retinanet as test_retinanet
import utils.c2 as c2_utils import utils.c2 as c2_utils
import utils.env as envu import utils.env as envu
import utils.net as net_utils import utils.net as net_utils
...@@ -119,9 +121,11 @@ def test_net(ind_range=None): ...@@ -119,9 +121,11 @@ def test_net(ind_range=None):
'Use rpn_generate to generate proposals from RPN-only models' 'Use rpn_generate to generate proposals from RPN-only models'
assert cfg.TEST.DATASET != '', \ assert cfg.TEST.DATASET != '', \
'TEST.DATASET must be set to the dataset name to test' 'TEST.DATASET must be set to the dataset name to test'
# Create anchors for RetinaNet
if cfg.RETINANET.RETINANET_ON: if cfg.RETINANET.RETINANET_ON:
import core.test_retinanet as test_retinanet anchors = test_retinanet.create_cell_anchors()
return test_retinanet.test_retinanet(ind_range) else:
anchors = None
output_dir = get_output_dir(training=False) output_dir = get_output_dir(training=False)
roidb, dataset, start_ind, end_ind, total_num_images = get_roidb_and_dataset( roidb, dataset, start_ind, end_ind, total_num_images = get_roidb_and_dataset(
ind_range ind_range
...@@ -132,11 +136,7 @@ def test_net(ind_range=None): ...@@ -132,11 +136,7 @@ def test_net(ind_range=None):
all_boxes, all_segms, all_keyps = empty_results(num_classes, num_images) all_boxes, all_segms, all_keyps = empty_results(num_classes, num_images)
timers = defaultdict(Timer) timers = defaultdict(Timer)
for i, entry in enumerate(roidb): for i, entry in enumerate(roidb):
if cfg.MODEL.FASTER_RCNN: if cfg.TEST.PRECOMPUTED_PROPOSALS:
# Faster R-CNN type models generate proposals on-the-fly with an
# in-network RPN
box_proposals = None
else:
# The roidb may contain ground-truth rois (for example, if the roidb # The roidb may contain ground-truth rois (for example, if the roidb
# comes from the training or val split). We only want to evaluate # comes from the training or val split). We only want to evaluate
# detection on the *non*-ground-truth rois. We select only the rois # detection on the *non*-ground-truth rois. We select only the rois
...@@ -145,12 +145,20 @@ def test_net(ind_range=None): ...@@ -145,12 +145,20 @@ def test_net(ind_range=None):
box_proposals = entry['boxes'][entry['gt_classes'] == 0] box_proposals = entry['boxes'][entry['gt_classes'] == 0]
if len(box_proposals) == 0: if len(box_proposals) == 0:
continue continue
else:
# Faster R-CNN type models generate proposals on-the-fly with an
# in-network RPN; 1-stage models don't require proposals.
box_proposals = None
im = cv2.imread(entry['image']) im = cv2.imread(entry['image'])
with c2_utils.NamedCudaScope(0): with c2_utils.NamedCudaScope(0):
cls_boxes_i, cls_segms_i, cls_keyps_i = im_detect_all( if cfg.RETINANET.RETINANET_ON:
model, im, box_proposals, timers cls_boxes_i, cls_segms_i, cls_keyps_i = \
) test_retinanet.im_detections(model, im, anchors, timers)
else:
cls_boxes_i, cls_segms_i, cls_keyps_i = im_detect_all(
model, im, box_proposals, timers
)
extend_results(i, all_boxes, cls_boxes_i) extend_results(i, all_boxes, cls_boxes_i)
if cls_segms_i is not None: if cls_segms_i is not None:
...@@ -238,13 +246,13 @@ def get_roidb_and_dataset(ind_range): ...@@ -238,13 +246,13 @@ def get_roidb_and_dataset(ind_range):
restrict it to a range of indices if ind_range is a pair of integers. restrict it to a range of indices if ind_range is a pair of integers.
""" """
dataset = JsonDataset(cfg.TEST.DATASET) dataset = JsonDataset(cfg.TEST.DATASET)
if cfg.MODEL.FASTER_RCNN: if cfg.TEST.PRECOMPUTED_PROPOSALS:
roidb = dataset.get_roidb()
else:
roidb = dataset.get_roidb( roidb = dataset.get_roidb(
proposal_file=cfg.TEST.PROPOSAL_FILE, proposal_file=cfg.TEST.PROPOSAL_FILE,
proposal_limit=cfg.TEST.PROPOSAL_LIMIT proposal_limit=cfg.TEST.PROPOSAL_LIMIT
) )
else:
roidb = dataset.get_roidb()
if ind_range is not None: if ind_range is not None:
total_num_images = len(roidb) total_num_images = len(roidb)
......
...@@ -21,26 +21,17 @@ from __future__ import print_function ...@@ -21,26 +21,17 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import numpy as np import numpy as np
import cv2
import os
import yaml
import logging import logging
from collections import defaultdict from collections import defaultdict
from caffe2.python import core, workspace from caffe2.python import core, workspace
from core.config import cfg, get_output_dir from core.config import cfg
from core.rpn_generator import _get_image_blob from core.rpn_generator import _get_image_blob
from datasets.json_dataset import JsonDataset
from modeling import model_builder
from modeling.generate_anchors import generate_anchors from modeling.generate_anchors import generate_anchors
from utils.io import save_object
from utils.timer import Timer from utils.timer import Timer
import core.test_engine as test_engine
import utils.boxes as box_utils import utils.boxes as box_utils
import utils.c2 as c2_utils
import utils.net as nu
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -74,8 +65,12 @@ def create_cell_anchors(): ...@@ -74,8 +65,12 @@ def create_cell_anchors():
return anchors return anchors
def im_detections(model, im, anchors): def im_detections(model, im, anchors, timers=None):
"""Generate RetinaNet detections on a single image.""" """Generate RetinaNet detections on a single image."""
if timers is None:
timers = defaultdict(Timer)
timers['im_detect_bbox'].tic()
k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS) A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS)
inputs = {} inputs = {}
...@@ -160,8 +155,10 @@ def im_detections(model, im, anchors): ...@@ -160,8 +155,10 @@ def im_detections(model, im, anchors):
inds = np.where(classes == cls - 1)[0] inds = np.where(classes == cls - 1)[0]
if len(inds) > 0: if len(inds) > 0:
boxes_all[cls].extend(box_scores[inds, :]) boxes_all[cls].extend(box_scores[inds, :])
timers['im_detect_bbox'].toc()
# Combine predictions across all levels and retain the top scoring by class # Combine predictions across all levels and retain the top scoring by class
timers['misc_bbox'].tic()
detections = [] detections = []
for cls, boxes in boxes_all.items(): for cls, boxes in boxes_all.items():
cls_dets = np.vstack(boxes).astype(dtype=np.float32) cls_dets = np.vstack(boxes).astype(dtype=np.float32)
...@@ -188,70 +185,6 @@ def im_detections(model, im, anchors): ...@@ -188,70 +185,6 @@ def im_detections(model, im, anchors):
for c in range(1, num_classes): for c in range(1, num_classes):
inds = np.where(detections[:, 5] == c)[0] inds = np.where(detections[:, 5] == c)[0]
cls_boxes[c] = detections[inds, :5] cls_boxes[c] = detections[inds, :5]
timers['misc_bbox'].toc()
return cls_boxes return cls_boxes, None, None
def im_list_detections(model, roidb):
"""Generate RetinaNet detections on all images in an imdb."""
_t = Timer()
num_images = len(roidb)
num_classes = cfg.MODEL.NUM_CLASSES
all_boxes, all_segms, all_keyps = test_engine.empty_results(
num_classes, num_images
)
# create anchors for each level
anchors = create_cell_anchors()
for i, entry in enumerate(roidb):
im = cv2.imread(entry['image'])
with c2_utils.NamedCudaScope(0):
_t.tic()
cls_boxes_i = im_detections(model, im, anchors)
_t.toc()
test_engine.extend_results(i, all_boxes, cls_boxes_i)
logger.info(
'im_detections: {:d}/{:d} {:.3f}s'.format(
i + 1, num_images, _t.average_time))
return all_boxes, all_segms, all_keyps
def test_retinanet(ind_range=None):
"""
Test RetinaNet model either on the entire dataset or the subset of dataset
specified by the index range
"""
assert cfg.RETINANET.RETINANET_ON, \
'RETINANET_ON must be set for testing RetinaNet model'
output_dir = get_output_dir(training=False)
dataset = JsonDataset(cfg.TEST.DATASET)
roidb = dataset.get_roidb()
if ind_range is not None:
start, end = ind_range
roidb = roidb[start:end]
# Create and load the model
model = model_builder.create(cfg.MODEL.TYPE, train=False)
if cfg.TEST.WEIGHTS:
nu.initialize_from_weights_file(
model, cfg.TEST.WEIGHTS, broadcast=False
)
model_builder.add_inference_inputs(model)
workspace.CreateNet(model.net)
# Compute the detections
all_boxes, all_segms, all_keyps = im_list_detections(model, roidb)
# Save the detections
cfg_yaml = yaml.dump(cfg)
if ind_range is not None:
det_name = 'detection_range_%s_%s.pkl' % tuple(ind_range)
else:
det_name = 'detections.pkl'
det_file = os.path.join(output_dir, det_name)
save_object(
dict(
all_boxes=all_boxes,
all_segms=all_segms,
all_keyps=all_keyps,
cfg=cfg_yaml
), det_file
)
logger.info('Wrote detections to: {}'.format(os.path.abspath(det_file)))
return all_boxes, all_segms, all_keyps
...@@ -339,6 +339,8 @@ def build_generic_retinanet_model( ...@@ -339,6 +339,8 @@ def build_generic_retinanet_model(
"""Builds the model on a single GPU. Can be called in a loop over GPUs """Builds the model on a single GPU. Can be called in a loop over GPUs
with name and device scoping to create a data parallel model.""" with name and device scoping to create a data parallel model."""
blobs, dim, spatial_scales = add_conv_body_func(model) blobs, dim, spatial_scales = add_conv_body_func(model)
if not model.train:
model.conv_body_net = model.net.Clone('conv_body_net')
retinanet_heads.add_fpn_retinanet_outputs( retinanet_heads.add_fpn_retinanet_outputs(
model, blobs, dim, spatial_scales model, blobs, dim, spatial_scales
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment