Commit 22636d44 authored by Ashwin Bharambe's avatar Ashwin Bharambe Committed by Facebook Github Bot

Move `test_net.main()` wrapper inside `test_engine.run_inference`

Summary:
Another small refactor trying to reduce non-trivial logic existing
inside the `tools` scripts. I considered moving `main` as
`run_inference_and_check_expected_results` but that just sounded uglier.

Reviewed By: rbgirshick

Differential Revision: D7432976

fbshipit-source-id: a31106e11cf0c1d93781fc9a696d2afee553f3d6
parent 31738256
...@@ -83,11 +83,14 @@ def get_inference_dataset(index, is_parent=True): ...@@ -83,11 +83,14 @@ def get_inference_dataset(index, is_parent=True):
def run_inference( def run_inference(
weights_file, ind_range=None, multi_gpu_testing=False, gpu_id=0 weights_file, ind_range=None,
multi_gpu_testing=False, gpu_id=0,
check_expected_results=False,
): ):
parent_func, child_func = get_eval_functions() parent_func, child_func = get_eval_functions()
is_parent = ind_range is None is_parent = ind_range is None
def result_getter():
if is_parent: if is_parent:
# Parent case: # Parent case:
# In this case we're either running inference on the entire dataset in a # In this case we're either running inference on the entire dataset in a
...@@ -122,6 +125,17 @@ def run_inference( ...@@ -122,6 +125,17 @@ def run_inference(
gpu_id=gpu_id gpu_id=gpu_id
) )
all_results = result_getter()
if check_expected_results and is_parent:
task_evaluation.check_expected_results(
all_results,
atol=cfg.EXPECTED_RESULTS_ATOL,
rtol=cfg.EXPECTED_RESULTS_RTOL
)
task_evaluation.log_copy_paste_friendly_results(all_results)
return all_results
def test_net_on_dataset( def test_net_on_dataset(
weights_file, weights_file,
......
...@@ -36,7 +36,6 @@ from core.config import cfg ...@@ -36,7 +36,6 @@ from core.config import cfg
from core.config import merge_cfg_from_file from core.config import merge_cfg_from_file
from core.config import merge_cfg_from_list from core.config import merge_cfg_from_list
from core.test_engine import run_inference from core.test_engine import run_inference
from datasets import task_evaluation
import utils.c2 import utils.c2
import utils.logging import utils.logging
...@@ -91,21 +90,6 @@ def parse_args(): ...@@ -91,21 +90,6 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def main(weights_file, ind_range=None, multi_gpu_testing=False):
all_results = run_inference(
weights_file,
ind_range=ind_range,
multi_gpu_testing=multi_gpu_testing,
)
if not ind_range:
task_evaluation.check_expected_results(
all_results,
atol=cfg.EXPECTED_RESULTS_ATOL,
rtol=cfg.EXPECTED_RESULTS_RTOL
)
task_evaluation.log_copy_paste_friendly_results(all_results)
if __name__ == '__main__': if __name__ == '__main__':
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0']) workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
logger = utils.logging.setup_logging(__name__) logger = utils.logging.setup_logging(__name__)
...@@ -124,8 +108,9 @@ if __name__ == '__main__': ...@@ -124,8 +108,9 @@ if __name__ == '__main__':
logger.info('Waiting for \'{}\' to exist...'.format(cfg.TEST.WEIGHTS)) logger.info('Waiting for \'{}\' to exist...'.format(cfg.TEST.WEIGHTS))
time.sleep(10) time.sleep(10)
main( run_inference(
cfg.TEST.WEIGHTS, cfg.TEST.WEIGHTS,
ind_range=args.range, ind_range=args.range,
multi_gpu_testing=args.multi_gpu_testing multi_gpu_testing=args.multi_gpu_testing,
check_expected_results=True,
) )
...@@ -28,7 +28,6 @@ import logging ...@@ -28,7 +28,6 @@ import logging
import numpy as np import numpy as np
import pprint import pprint
import sys import sys
import test_net
from caffe2.python import workspace from caffe2.python import workspace
...@@ -36,6 +35,7 @@ from core.config import assert_and_infer_cfg ...@@ -36,6 +35,7 @@ from core.config import assert_and_infer_cfg
from core.config import cfg from core.config import cfg
from core.config import merge_cfg_from_file from core.config import merge_cfg_from_file
from core.config import merge_cfg_from_list from core.config import merge_cfg_from_list
from core.test_engine import run_inference
from utils.logging import setup_logging from utils.logging import setup_logging
import utils.c2 import utils.c2
import utils.train import utils.train
...@@ -118,7 +118,10 @@ def test_model(model_file, multi_gpu_testing, opts=None): ...@@ -118,7 +118,10 @@ def test_model(model_file, multi_gpu_testing, opts=None):
# Clear memory before inference # Clear memory before inference
workspace.ResetWorkspace() workspace.ResetWorkspace()
# Run inference # Run inference
test_net.main(model_file, multi_gpu_testing=multi_gpu_testing) run_inference(
model_file, multi_gpu_testing=multi_gpu_testing,
check_expected_results=True,
)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment