Commit aee97b14 authored by Ilija Radosavovic's avatar Ilija Radosavovic Committed by Facebook Github Bot

Unify retinanet results format with the one used in test_engine

Reviewed By: rbgirshick

Differential Revision: D6828915

fbshipit-source-id: dfc8bd970c0555d56bfd7ad313550b0a12c154f2
parent 829a3941
...@@ -22,25 +22,23 @@ from __future__ import unicode_literals ...@@ -22,25 +22,23 @@ from __future__ import unicode_literals
import numpy as np import numpy as np
import cv2 import cv2
import json
import os import os
import uuid
import yaml import yaml
import logging import logging
from collections import defaultdict, OrderedDict from collections import defaultdict
from caffe2.python import core, workspace from caffe2.python import core, workspace
from core.config import cfg, get_output_dir from core.config import cfg, get_output_dir
from core.rpn_generator import _get_image_blob from core.rpn_generator import _get_image_blob
from datasets.json_dataset import JsonDataset
from datasets import task_evaluation from datasets import task_evaluation
from datasets.json_dataset import JsonDataset
from modeling import model_builder from modeling import model_builder
from modeling.generate_anchors import generate_anchors from modeling.generate_anchors import generate_anchors
from pycocotools.cocoeval import COCOeval
from utils.io import save_object from utils.io import save_object
from utils.timer import Timer from utils.timer import Timer
import core.test_engine as test_engine
import utils.boxes as box_utils import utils.boxes as box_utils
import utils.c2 as c2_utils import utils.c2 as c2_utils
import utils.env as envu import utils.env as envu
...@@ -178,38 +176,44 @@ def im_detections(model, im, anchors): ...@@ -178,38 +176,44 @@ def im_detections(model, im, anchors):
out[:, 5].fill(cls) out[:, 5].fill(cls)
detections.append(out) detections.append(out)
# detections (N, 6) format:
# detections[:, :4] - boxes
# detections[:, 4] - scores
# detections[:, 5] - classes
detections = np.vstack(detections) detections = np.vstack(detections)
# sort all again # sort all again
inds = np.argsort(-detections[:, 4]) inds = np.argsort(-detections[:, 4])
detections = detections[inds[0:cfg.TEST.DETECTIONS_PER_IM], :] detections = detections[inds[0:cfg.TEST.DETECTIONS_PER_IM], :]
boxes = detections[:, 0:4]
scores = detections[:, 4] # Convert the detections to image cls_ format (see core/test_engine.py)
classes = detections[:, 5] num_classes = cfg.MODEL.NUM_CLASSES
return boxes, scores, classes cls_boxes = [[] for _ in range(cfg.MODEL.NUM_CLASSES)]
for c in range(1, num_classes):
inds = np.where(detections[:, 5] == c)[0]
cls_boxes[c] = detections[inds, :5]
return cls_boxes
def im_list_detections(model, im_list): def im_list_detections(model, roidb):
"""Generate RetinaNet proposals on all images in im_list.""" """Generate RetinaNet detections on all images in an imdb."""
_t = Timer() _t = Timer()
num_images = len(im_list) num_images = len(roidb)
im_list_boxes = [[] for _ in range(num_images)] num_classes = cfg.MODEL.NUM_CLASSES
im_list_scores = [[] for _ in range(num_images)] all_boxes, _, _ = test_engine.empty_results(num_classes, num_images)
im_list_ids = [[] for _ in range(num_images)]
im_list_classes = [[] for _ in range(num_images)]
# create anchors for each level # create anchors for each level
anchors = create_cell_anchors() anchors = create_cell_anchors()
for i in range(num_images): for i, entry in enumerate(roidb):
im_list_ids[i] = im_list[i]['id'] im = cv2.imread(entry['image'])
im = cv2.imread(im_list[i]['image'])
with c2_utils.NamedCudaScope(0): with c2_utils.NamedCudaScope(0):
_t.tic() _t.tic()
im_list_boxes[i], im_list_scores[i], im_list_classes[i] = \ cls_boxes_i = im_detections(model, im, anchors)
im_detections(model, im, anchors)
_t.toc() _t.toc()
test_engine.extend_results(i, all_boxes, cls_boxes_i)
logger.info( logger.info(
'im_detections: {:d}/{:d} {:.3f}s'.format( 'im_detections: {:d}/{:d} {:.3f}s'.format(
i + 1, num_images, _t.average_time)) i + 1, num_images, _t.average_time))
return im_list_boxes, im_list_scores, im_list_classes, im_list_ids return all_boxes
def test_retinanet(ind_range=None): def test_retinanet(ind_range=None):
...@@ -221,17 +225,17 @@ def test_retinanet(ind_range=None): ...@@ -221,17 +225,17 @@ def test_retinanet(ind_range=None):
'RETINANET_ON must be set for testing RetinaNet model' 'RETINANET_ON must be set for testing RetinaNet model'
output_dir = get_output_dir(training=False) output_dir = get_output_dir(training=False)
dataset = JsonDataset(cfg.TEST.DATASET) dataset = JsonDataset(cfg.TEST.DATASET)
im_list = dataset.get_roidb() roidb = dataset.get_roidb()
if ind_range is not None: if ind_range is not None:
start, end = ind_range start, end = ind_range
im_list = im_list[start:end] roidb = roidb[start:end]
logger.info('Testing on roidb range: {}-{}'.format(start, end)) logger.info('Testing on roidb range: {}-{}'.format(start, end))
else: else:
# if testing over the whole dataset, use the NUM_TEST_IMAGES setting # if testing over the whole dataset, use the NUM_TEST_IMAGES setting
# the NUM_TEST_IMAGES could be over a small set of images for quick # the NUM_TEST_IMAGES could be over a small set of images for quick
# debugging purposes # debugging purposes
im_list = im_list[0:cfg.TEST.NUM_TEST_IMAGES] roidb = roidb[0:cfg.TEST.NUM_TEST_IMAGES]
# Create and load the model
model = model_builder.create(cfg.MODEL.TYPE, train=False) model = model_builder.create(cfg.MODEL.TYPE, train=False)
if cfg.TEST.WEIGHTS: if cfg.TEST.WEIGHTS:
nu.initialize_from_weights_file( nu.initialize_from_weights_file(
...@@ -239,20 +243,20 @@ def test_retinanet(ind_range=None): ...@@ -239,20 +243,20 @@ def test_retinanet(ind_range=None):
) )
model_builder.add_inference_inputs(model) model_builder.add_inference_inputs(model)
workspace.CreateNet(model.net) workspace.CreateNet(model.net)
boxes, scores, classes, image_ids = im_list_detections( # Compute the detections
model, im_list[0:cfg.TEST.NUM_TEST_IMAGES]) all_boxes = im_list_detections(model, roidb)
# Save the detections
cfg_yaml = yaml.dump(cfg) cfg_yaml = yaml.dump(cfg)
if ind_range is not None: if ind_range is not None:
det_name = 'retinanet_detections_range_%s_%s.pkl' % tuple(ind_range) det_name = 'detection_range_%s_%s.pkl' % tuple(ind_range)
else: else:
det_name = 'retinanet_detections.pkl' det_name = 'detections.pkl'
det_file = os.path.join(output_dir, det_name) det_file = os.path.join(output_dir, det_name)
save_object( save_object(
dict(boxes=boxes, scores=scores, classes=classes, ids=image_ids, cfg=cfg_yaml), dict(all_boxes=all_boxes, cfg=cfg_yaml),
det_file) det_file)
logger.info('Wrote detections to: {}'.format(os.path.abspath(det_file))) logger.info('Wrote detections to: {}'.format(os.path.abspath(det_file)))
return boxes, scores, classes, image_ids return all_boxes
def multi_gpu_test_retinanet_on_dataset(num_images, output_dir, dataset): def multi_gpu_test_retinanet_on_dataset(num_images, output_dir, dataset):
...@@ -270,16 +274,25 @@ def multi_gpu_test_retinanet_on_dataset(num_images, output_dir, dataset): ...@@ -270,16 +274,25 @@ def multi_gpu_test_retinanet_on_dataset(num_images, output_dir, dataset):
# Run inference in parallel in subprocesses # Run inference in parallel in subprocesses
outputs = subprocess_utils.process_in_parallel( outputs = subprocess_utils.process_in_parallel(
'retinanet_detections', num_images, binary, output_dir) 'detection', num_images, binary, output_dir)
# Combine the results from each subprocess now # Combine the results from each subprocess
boxes, scores, classes, image_ids = [], [], [], [] all_boxes = [[] for _ in range(cfg.MODEL.NUM_CLASSES)]
for det_data in outputs: for det_data in outputs:
boxes.extend(det_data['boxes']) all_boxes_batch = det_data['all_boxes']
scores.extend(det_data['scores']) for cls_idx in range(1, cfg.MODEL.NUM_CLASSES):
classes.extend(det_data['classes']) all_boxes[cls_idx] += all_boxes_batch[cls_idx]
image_ids.extend(det_data['ids'])
return boxes, scores, classes, image_ids, # Save the computed detections
det_file = os.path.join(output_dir, 'detections.pkl')
cfg_yaml = yaml.dump(cfg)
save_object(
dict(all_boxes=all_boxes, cfg=cfg_yaml),
det_file
)
logger.info('Wrote detections to: {}'.format(os.path.abspath(det_file)))
return all_boxes
def test_retinanet_on_dataset(multi_gpu=False): def test_retinanet_on_dataset(multi_gpu=False):
...@@ -287,86 +300,24 @@ def test_retinanet_on_dataset(multi_gpu=False): ...@@ -287,86 +300,24 @@ def test_retinanet_on_dataset(multi_gpu=False):
Main entry point for testing on a given dataset: whether multi_gpu or not Main entry point for testing on a given dataset: whether multi_gpu or not
""" """
output_dir = get_output_dir(training=False) output_dir = get_output_dir(training=False)
logger.info('Output will be saved to: {:s}'.format(os.path.abspath(output_dir)))
dataset = JsonDataset(cfg.TEST.DATASET) dataset = JsonDataset(cfg.TEST.DATASET)
test_timer = Timer()
test_timer.tic()
# for test-dev or full test dataset, we generate detections for all images # for test-dev or full test dataset, we generate detections for all images
if 'test-dev' in cfg.TEST.DATASET or 'test' in cfg.TEST.DATASET: if 'test-dev' in cfg.TEST.DATASET or 'test' in cfg.TEST.DATASET:
cfg.TEST.NUM_TEST_IMAGES = len(dataset.get_roidb()) cfg.TEST.NUM_TEST_IMAGES = len(dataset.get_roidb())
if multi_gpu: if multi_gpu:
num_images = cfg.TEST.NUM_TEST_IMAGES num_images = cfg.TEST.NUM_TEST_IMAGES
boxes, scores, classes, image_ids = multi_gpu_test_retinanet_on_dataset( all_boxes = multi_gpu_test_retinanet_on_dataset(
num_images, output_dir, dataset num_images, output_dir, dataset
) )
else: else:
boxes, scores, classes, image_ids = test_retinanet() all_boxes = test_retinanet()
test_timer.toc()
# write RetinaNet detections pkl file to be used for various purposes logger.info('Total inference time: {:.3f}s'.format(test_timer.average_time))
# dump the boxes first just in case there are spurious failures results = task_evaluation.evaluate_all(
res_file = os.path.join(output_dir, 'retinanet_detections.pkl') dataset, all_boxes, None, None, output_dir
logger.info(
'Writing roidb detections to file: {}'.
format(os.path.abspath(res_file))
)
save_object(
dict(boxes=boxes, scores=scores, classes=classes, ids=image_ids),
res_file
) )
logger.info('Wrote RetinaNet detections to {}'.format(os.path.abspath(res_file))) return results
# Write the detections to a file that can be uploaded to coco evaluation server
# which takes a json file format
res_file = write_coco_detection_results(
output_dir, dataset, boxes, scores, classes, image_ids)
# Perform coco evaluation
coco_eval = coco_evaluate(dataset, res_file, image_ids)
box_results = task_evaluation._coco_eval_to_box_results(coco_eval)
return OrderedDict([(dataset.name, box_results)])
def coco_evaluate(json_dataset, res_file, image_ids):
coco_dt = json_dataset.COCO.loadRes(str(res_file))
coco_eval = COCOeval(json_dataset.COCO, coco_dt, 'bbox')
coco_eval.params.imgIds = image_ids
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
return coco_eval
def write_coco_detection_results(
output_dir, json_dataset, all_boxes, all_scores, all_classes, image_ids,
use_salt=False
):
res_file = os.path.join(
output_dir, 'detections_' + json_dataset.name + '_results')
if use_salt:
res_file += '_{}'.format(str(uuid.uuid4()))
res_file += '.json'
logger.info('Writing RetinaNet detections for submitting to coco server...')
results = []
for (im_id, dets, cls, score) in zip(image_ids, all_boxes, all_classes, all_scores):
dets = dets.astype(np.float)
score = score.astype(np.float)
classes = np.array(
[json_dataset.contiguous_category_id_to_json_id[c] for c in cls])
xs = dets[:, 0]
ys = dets[:, 1]
ws = dets[:, 2] - xs + 1
hs = dets[:, 3] - ys + 1
results.extend(
[{'image_id': im_id,
'category_id': classes[k],
'bbox': [xs[k], ys[k], ws[k], hs[k]],
'score': score[k]} for k in range(dets.shape[0])])
logger.info('Writing detection results to json: {}'.format(
os.path.abspath(res_file)
))
with open(res_file, 'w') as fid:
json.dump(results, fid)
logger.info('Done!')
return res_file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment