Commit 22636d44 authored by Ashwin Bharambe's avatar Ashwin Bharambe Committed by Facebook Github Bot

Move `test_net.main()` wrapper inside `test_engine.run_inference`

Summary:
Another small refactor trying to reduce non-trivial logic existing
inside the `tools` scripts. I considered moving `main` as
`run_inference_and_check_expected_results` but that just sounded uglier.

Reviewed By: rbgirshick

Differential Revision: D7432976

fbshipit-source-id: a31106e11cf0c1d93781fc9a696d2afee553f3d6
parent 31738256
......@@ -83,11 +83,14 @@ def get_inference_dataset(index, is_parent=True):
def run_inference(
weights_file, ind_range=None, multi_gpu_testing=False, gpu_id=0
weights_file, ind_range=None,
multi_gpu_testing=False, gpu_id=0,
check_expected_results=False,
):
parent_func, child_func = get_eval_functions()
is_parent = ind_range is None
def result_getter():
if is_parent:
# Parent case:
# In this case we're either running inference on the entire dataset in a
......@@ -122,6 +125,17 @@ def run_inference(
gpu_id=gpu_id
)
all_results = result_getter()
if check_expected_results and is_parent:
task_evaluation.check_expected_results(
all_results,
atol=cfg.EXPECTED_RESULTS_ATOL,
rtol=cfg.EXPECTED_RESULTS_RTOL
)
task_evaluation.log_copy_paste_friendly_results(all_results)
return all_results
def test_net_on_dataset(
weights_file,
......
......@@ -36,7 +36,6 @@ from core.config import cfg
from core.config import merge_cfg_from_file
from core.config import merge_cfg_from_list
from core.test_engine import run_inference
from datasets import task_evaluation
import utils.c2
import utils.logging
......@@ -91,21 +90,6 @@ def parse_args():
return parser.parse_args()
def main(weights_file, ind_range=None, multi_gpu_testing=False):
all_results = run_inference(
weights_file,
ind_range=ind_range,
multi_gpu_testing=multi_gpu_testing,
)
if not ind_range:
task_evaluation.check_expected_results(
all_results,
atol=cfg.EXPECTED_RESULTS_ATOL,
rtol=cfg.EXPECTED_RESULTS_RTOL
)
task_evaluation.log_copy_paste_friendly_results(all_results)
if __name__ == '__main__':
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
logger = utils.logging.setup_logging(__name__)
......@@ -124,8 +108,9 @@ if __name__ == '__main__':
logger.info('Waiting for \'{}\' to exist...'.format(cfg.TEST.WEIGHTS))
time.sleep(10)
main(
run_inference(
cfg.TEST.WEIGHTS,
ind_range=args.range,
multi_gpu_testing=args.multi_gpu_testing
multi_gpu_testing=args.multi_gpu_testing,
check_expected_results=True,
)
......@@ -28,7 +28,6 @@ import logging
import numpy as np
import pprint
import sys
import test_net
from caffe2.python import workspace
......@@ -36,6 +35,7 @@ from core.config import assert_and_infer_cfg
from core.config import cfg
from core.config import merge_cfg_from_file
from core.config import merge_cfg_from_list
from core.test_engine import run_inference
from utils.logging import setup_logging
import utils.c2
import utils.train
......@@ -118,7 +118,10 @@ def test_model(model_file, multi_gpu_testing, opts=None):
# Clear memory before inference
workspace.ResetWorkspace()
# Run inference
test_net.main(model_file, multi_gpu_testing=multi_gpu_testing)
run_inference(
model_file, multi_gpu_testing=multi_gpu_testing,
check_expected_results=True,
)
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment