Commit 1b1a8a4f authored by Ross Girshick, committed by Facebook Github Bot

Do not mutate TEST.DATASET, TEST.PROPOSAL_FILE (also remove them)

Reviewed By: ir413

Differential Revision: D7148430

fbshipit-source-id: 71f25716d157daf4082a35dd2138debd9a43c7f1
parent 4db30576
@@ -272,14 +272,6 @@ __C.TEST.FORCE_JSON_DATASET_EVAL = False
# Not set for 1-stage models and 2-stage models with RPN subnetwork enabled
__C.TEST.PRECOMPUTED_PROPOSALS = True
# [Inferred value; do not set directly in a config]
# Active dataset to test on
__C.TEST.DATASET = b''
# [Inferred value; do not set directly in a config]
# Active proposal file to use
__C.TEST.PROPOSAL_FILE = b''
# ---------------------------------------------------------------------------- #
# Test-time augmentations for bounding box detection
@@ -1008,6 +1000,16 @@ _RENAMED_KEYS = {
"Also convert from a tuple, e.g. (600, ), " +
"to a integer, e.g. 600."
),
'TEST.DATASET': (
'TEST.DATASETS',
"Also convert from a string, e.g 'coco_2014_minival', " +
"to a tuple, e.g. ('coco_2014_minival', )."
),
'TEST.PROPOSAL_FILE': (
'TEST.PROPOSAL_FILES',
"Also convert from a string, e.g. '/path/to/props.pkl', " +
"to a tuple, e.g. ('/path/to/props.pkl', )."
),
}
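For configs written against the removed keys, the rename entries above describe a mechanical migration: the plural keys take tuples with one entry per dataset. A minimal sketch of the before/after (the dataset name and proposal path are illustrative; the import path assumes Detectron's lib/ directory is on PYTHONPATH):

# Old style, now reported as a renamed key when the config is loaded:
#   TEST:
#     DATASET: 'coco_2014_minival'
#     PROPOSAL_FILE: '/path/to/props.pkl'
# New style:
#   TEST:
#     DATASETS: ('coco_2014_minival',)
#     PROPOSAL_FILES: ('/path/to/props.pkl',)

# The same migration expressed against the config object directly:
from core.config import cfg  # assumed import path

cfg.TEST.DATASETS = ('coco_2014_minival',)
cfg.TEST.PROPOSAL_FILES = ('/path/to/props.pkl',)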
@@ -53,27 +53,31 @@ import utils.subprocess as subprocess_utils
logger = logging.getLogger(__name__)
def generate_rpn_on_dataset(output_dir, multi_gpu=False, gpu_id=0):
def generate_rpn_on_dataset(
dataset_name, _proposal_file_ignored, output_dir, multi_gpu=False, gpu_id=0
):
"""Run inference on a dataset."""
dataset = JsonDataset(cfg.TEST.DATASET)
dataset = JsonDataset(dataset_name)
test_timer = Timer()
test_timer.tic()
if multi_gpu:
num_images = len(dataset.get_roidb())
_boxes, _scores, _ids, rpn_file = multi_gpu_generate_rpn_on_dataset(
num_images, output_dir
dataset_name, _proposal_file_ignored, num_images, output_dir
)
else:
# Processes entire dataset range by default
_boxes, _scores, _ids, rpn_file = generate_rpn_on_range(
output_dir, gpu_id=gpu_id
dataset_name, _proposal_file_ignored, output_dir, gpu_id=gpu_id
)
test_timer.toc()
logger.info('Total inference time: {:.3f}s'.format(test_timer.average_time))
return evaluate_proposal_file(dataset, rpn_file, output_dir)
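For illustration only (not part of the diff; the output path is invented), a caller of the new signature passes the dataset explicitly rather than relying on a mutated cfg.TEST.DATASET:

results = generate_rpn_on_dataset(
    'coco_2014_minival',       # dataset_name, e.g. one entry of cfg.TEST.DATASETS
    None,                      # _proposal_file_ignored: RPN generation needs no proposals
    '/tmp/detectron/rpn_out',  # output_dir (hypothetical path)
    multi_gpu=False
)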
def multi_gpu_generate_rpn_on_dataset(num_images, output_dir):
def multi_gpu_generate_rpn_on_dataset(
dataset_name, _proposal_file_ignored, num_images, output_dir
):
"""Multi-gpu inference on a dataset."""
# Retrieve the test_net binary path
binary_dir = envu.get_runtime_dir()
@@ -81,9 +85,12 @@ def multi_gpu_generate_rpn_on_dataset(num_images, output_dir):
binary = os.path.join(binary_dir, 'test_net' + binary_ext)
assert os.path.exists(binary), 'Binary \'{}\' not found'.format(binary)
# Pass the target dataset via the command line
opts = ['TEST.DATASETS', '("{}",)'.format(dataset_name)]
# Run inference in parallel in subprocesses
outputs = subprocess_utils.process_in_parallel(
'rpn_proposals', num_images, binary, output_dir
'rpn_proposals', num_images, binary, output_dir, opts
)
# Collate the results from each subprocess
@@ -101,17 +108,19 @@ def multi_gpu_generate_rpn_on_dataset(num_images, output_dir):
return boxes, scores, ids, rpn_file
def generate_rpn_on_range(output_dir, ind_range=None, gpu_id=0):
def generate_rpn_on_range(
dataset_name, _proposal_file_ignored, output_dir, ind_range=None, gpu_id=0
):
"""Run inference on all images in a dataset or over an index range of images
in a dataset using a single GPU.
"""
assert cfg.TEST.WEIGHTS != '', \
'TEST.WEIGHTS must be set to the model file to test'
assert cfg.TEST.DATASET != '', \
'TEST.DATASET must be set to the dataset name to test'
assert cfg.MODEL.RPN_ONLY or cfg.MODEL.FASTER_RCNN
roidb, start_ind, end_ind, total_num_images = get_roidb(ind_range)
roidb, start_ind, end_ind, total_num_images = get_roidb(
dataset_name, ind_range
)
logger.info(
'Output will be saved to: {:s}'.format(os.path.abspath(output_dir))
)
@@ -228,11 +237,11 @@ def im_proposals(model, im):
return boxes, scores
def get_roidb(ind_range):
def get_roidb(dataset_name, ind_range):
"""Get the roidb for the dataset specified in the global cfg. Optionally
restrict it to a range of indices if ind_range is a pair of integers.
"""
dataset = JsonDataset(cfg.TEST.DATASET)
dataset = JsonDataset(dataset_name)
roidb = dataset.get_roidb()
if ind_range is not None:
@@ -63,6 +63,25 @@ def get_eval_functions():
return parent_func, child_func
def get_inference_dataset(index, is_parent=True):
assert is_parent or len(cfg.TEST.DATASETS) == 1, \
'The child inference process can only work on a single dataset'
dataset_name = cfg.TEST.DATASETS[index]
if cfg.TEST.PRECOMPUTED_PROPOSALS:
assert is_parent or len(cfg.TEST.PROPOSAL_FILES) == 1, \
'The child inference process can only work on a single proposal file'
assert len(cfg.TEST.PROPOSAL_FILES) == len(cfg.TEST.DATASETS), \
'If proposals are used, one proposal file must be specified for ' \
'each dataset'
proposal_file = cfg.TEST.PROPOSAL_FILES[index]
else:
proposal_file = None
return dataset_name, proposal_file
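A minimal usage sketch of the helper above, assuming two test datasets with precomputed proposals (names and paths are illustrative):

cfg.TEST.DATASETS = ('coco_2014_minival', 'coco_2014_valminusminival')
cfg.TEST.PROPOSAL_FILES = ('/path/to/props_a.pkl', '/path/to/props_b.pkl')
cfg.TEST.PRECOMPUTED_PROPOSALS = True

# Parent process: each loop index selects one (dataset, proposal file) pair.
dataset_name, proposal_file = get_inference_dataset(1)
# -> ('coco_2014_valminusminival', '/path/to/props_b.pkl')

# Child subprocess: the command-line opts narrow cfg.TEST.DATASETS to a single
# entry, so only index 0 is valid and is_parent=False enforces that.
dataset_name, proposal_file = get_inference_dataset(0, is_parent=False)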
def run_inference(ind_range=None, multi_gpu_testing=False, gpu_id=0):
parent_func, child_func = get_eval_functions()
@@ -72,41 +91,50 @@ def run_inference(ind_range=None, multi_gpu_testing=False, gpu_id=0):
# In this case we're either running inference on the entire dataset in a
# single process or (if multi_gpu_testing is True) using this process to
# launch subprocesses that each run inference on a range of the dataset
if len(cfg.TEST.DATASETS) == 0:
cfg.TEST.DATASETS = (cfg.TEST.DATASET, )
cfg.TEST.PROPOSAL_FILES = (cfg.TEST.PROPOSAL_FILE, )
all_results = {}
for i in range(len(cfg.TEST.DATASETS)):
cfg.TEST.DATASET = cfg.TEST.DATASETS[i]
if cfg.TEST.PRECOMPUTED_PROPOSALS:
cfg.TEST.PROPOSAL_FILE = cfg.TEST.PROPOSAL_FILES[i]
output_dir = get_output_dir(cfg.TEST.DATASET, training=False)
results = parent_func(output_dir, multi_gpu=multi_gpu_testing)
dataset_name, proposal_file = get_inference_dataset(i)
output_dir = get_output_dir(dataset_name, training=False)
results = parent_func(
dataset_name,
proposal_file,
output_dir,
multi_gpu=multi_gpu_testing
)
all_results.update(results)
return all_results
else:
# Subprocess child case:
# In this case test_net was called via subprocess.Popen to execute on a
# range of inputs on a single dataset (i.e., use cfg.TEST.DATASET and
# don't loop over cfg.TEST.DATASETS)
output_dir = get_output_dir(cfg.TEST.DATASET, training=False)
return child_func(output_dir, ind_range=ind_range, gpu_id=gpu_id)
# range of inputs on a single dataset
dataset_name, proposal_file = get_inference_dataset(0, is_parent=False)
output_dir = get_output_dir(dataset_name, training=False)
return child_func(
dataset_name,
proposal_file,
output_dir,
ind_range=ind_range,
gpu_id=gpu_id
)
def test_net_on_dataset(output_dir, multi_gpu=False, gpu_id=0):
def test_net_on_dataset(
dataset_name, proposal_file, output_dir, multi_gpu=False, gpu_id=0
):
"""Run inference on a dataset."""
dataset = JsonDataset(cfg.TEST.DATASET)
dataset = JsonDataset(dataset_name)
test_timer = Timer()
test_timer.tic()
if multi_gpu:
num_images = len(dataset.get_roidb())
all_boxes, all_segms, all_keyps = multi_gpu_test_net_on_dataset(
num_images, output_dir
dataset_name, proposal_file, num_images, output_dir
)
else:
all_boxes, all_segms, all_keyps = test_net(output_dir, gpu_id=gpu_id)
all_boxes, all_segms, all_keyps = test_net(
dataset_name, proposal_file, output_dir, gpu_id=gpu_id
)
test_timer.toc()
logger.info('Total inference time: {:.3f}s'.format(test_timer.average_time))
results = task_evaluation.evaluate_all(
@@ -115,18 +143,25 @@ def test_net_on_dataset(output_dir, multi_gpu=False, gpu_id=0):
return results
def multi_gpu_test_net_on_dataset(num_images, output_dir):
def multi_gpu_test_net_on_dataset(
dataset_name, proposal_file, num_images, output_dir
):
"""Multi-gpu inference on a dataset."""
binary_dir = envu.get_runtime_dir()
binary_ext = envu.get_py_bin_ext()
binary = os.path.join(binary_dir, 'test_net' + binary_ext)
assert os.path.exists(binary), 'Binary \'{}\' not found'.format(binary)
# Pass the target dataset and proposal file (if any) via the command line
opts = ['TEST.DATASETS', '("{}",)'.format(dataset_name)]
if proposal_file:
opts += ['TEST.PROPOSAL_FILES', '("{}",)'.format(proposal_file)]
# Run inference in parallel in subprocesses
# Outputs will be a list of outputs from each subprocess, where the output
# of each subprocess is the dictionary saved by test_net().
outputs = subprocess_utils.process_in_parallel(
'detection', num_images, binary, output_dir
'detection', num_images, binary, output_dir, opts
)
# Collate the results from each subprocess
@@ -156,7 +191,7 @@ def multi_gpu_test_net_on_dataset(num_images, output_dir):
return all_boxes, all_segms, all_keyps
def test_net(output_dir, ind_range=None, gpu_id=0):
def test_net(dataset_name, proposal_file, output_dir, ind_range=None, gpu_id=0):
"""Run inference on all images in a dataset or over an index range of images
in a dataset using a single GPU.
"""
@@ -164,11 +199,9 @@ def test_net(output_dir, ind_range=None, gpu_id=0):
'TEST.WEIGHTS must be set to the model file to test'
assert not cfg.MODEL.RPN_ONLY, \
'Use rpn_generate to generate proposals from RPN-only models'
assert cfg.TEST.DATASET != '', \
'TEST.DATASET must be set to the dataset name to test'
roidb, dataset, start_ind, end_ind, total_num_images = get_roidb_and_dataset(
ind_range
dataset_name, proposal_file, ind_range
)
model = initialize_model_from_cfg(gpu_id=gpu_id)
num_images = len(roidb)
@@ -277,14 +310,15 @@ def initialize_model_from_cfg(gpu_id=0):
return model
def get_roidb_and_dataset(ind_range):
def get_roidb_and_dataset(dataset_name, proposal_file, ind_range):
"""Get the roidb for the dataset specified in the global cfg. Optionally
restrict it to a range of indices if ind_range is a pair of integers.
"""
dataset = JsonDataset(cfg.TEST.DATASET)
dataset = JsonDataset(dataset_name)
if cfg.TEST.PRECOMPUTED_PROPOSALS:
assert proposal_file, 'No proposal file given'
roidb = dataset.get_roidb(
proposal_file=cfg.TEST.PROPOSAL_FILE,
proposal_file=proposal_file,
proposal_limit=cfg.TEST.PROPOSAL_LIMIT
)
else:
@@ -36,7 +36,9 @@ import logging
logger = logging.getLogger(__name__)
def process_in_parallel(tag, total_range_size, binary, output_dir):
def process_in_parallel(
tag, total_range_size, binary, output_dir, opts=''
):
"""Run the specified binary cfg.NUM_GPUS times in parallel, each time as a
subprocess that uses one GPU. The binary must accept the command line
arguments `--range {start} {end}` that specify a data processing range.
@@ -62,12 +64,13 @@ def process_in_parallel(tag, total_range_size, binary, output_dir):
start = subinds[i][0]
end = subinds[i][-1] + 1
subprocess_env['CUDA_VISIBLE_DEVICES'] = str(gpu_ind)
cmd = '{binary} --range {start} {end} --cfg {cfg_file} NUM_GPUS 1'
cmd = '{binary} --range {start} {end} --cfg {cfg_file} NUM_GPUS 1 {opts}'
cmd = cmd.format(
binary=shlex_quote(binary),
start=int(start),
end=int(end),
cfg_file=shlex_quote(cfg_file)
cfg_file=shlex_quote(cfg_file),
opts=' '.join([shlex_quote(opt) for opt in opts])
)
logger.info('{} range command {}: {}'.format(tag, i, cmd))
if i == 0:
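For context, an illustrative rendering of the per-GPU command that the updated format string produces (binary location, config path, index range, dataset, and proposal file are all made up for the example):

# tools/test_net.py --range 0 1250 --cfg /tmp/detectron/config.yaml NUM_GPUS 1 \
#     TEST.DATASETS '("coco_2014_minival",)' TEST.PROPOSAL_FILES '("/path/to/props.pkl",)'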