Commit e7ad1b48 authored by Ross Girshick's avatar Ross Girshick Committed by Facebook Github Bot

cache_url fix for inference tools

Reviewed By: ir413

Differential Revision: D7364656

fbshipit-source-id: bfd31bc7c95b9606037c2f5546c9edd0e0318272
parent eddb1301
......@@ -29,6 +29,7 @@ from __future__ import unicode_literals
import argparse
import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2)
import logging
import os
import sys
import yaml
......@@ -39,6 +40,7 @@ from core.config import assert_and_infer_cfg
from core.config import cfg
from core.config import merge_cfg_from_cfg
from core.config import merge_cfg_from_file
from utils.io import cache_url
import core.rpn_generator as rpn_engine
import core.test_engine as model_engine
import datasets.dummy_datasets as dummy_datasets
......@@ -112,6 +114,7 @@ def get_rpn_box_proposals(im, args):
def main(args):
logger = logging.getLogger(__name__)
dummy_coco_dataset = dummy_datasets.get_coco_dataset()
cfg_orig = yaml.load(yaml.dump(cfg))
im = cv2.imread(args.im_file)
......@@ -144,6 +147,11 @@ def main(args):
cls_keyps = cls_keyps_ if cls_keyps_ is not None else cls_keyps
workspace.ResetWorkspace()
out_name = os.path.join(
args.output_dir, '{}'.format(os.path.basename(args.im_file) + '.pdf')
)
logger.info('Processing {} -> {}'.format(args.im_file, out_name))
vis_utils.vis_one_image(
im[:, :, ::-1],
args.im_file,
......@@ -165,13 +173,18 @@ def check_args(args):
(args.rpn_pkl is None and args.rpn_cfg is None)
)
if args.rpn_pkl is not None:
args.rpn_pkl = cache_url(args.rpn_pkl, cfg.DOWNLOAD_CACHE)
assert os.path.exists(args.rpn_pkl)
assert os.path.exists(args.rpn_cfg)
if args.models_to_run is not None:
assert len(args.models_to_run) % 2 == 0
for model_file in args.models_to_run:
for i, model_file in enumerate(args.models_to_run):
if len(model_file) > 0:
assert os.path.exists(model_file)
if i % 2 == 0:
model_file = cache_url(model_file, cfg.DOWNLOAD_CACHE)
args.models_to_run[i] = model_file
assert os.path.exists(model_file), \
'\'{}\' does not exist'.format(model_file)
if __name__ == '__main__':
......
......@@ -38,6 +38,7 @@ from caffe2.python import workspace
from core.config import assert_and_infer_cfg
from core.config import cfg
from core.config import merge_cfg_from_file
from utils.io import cache_url
from utils.timer import Timer
import core.test_engine as infer_engine
import datasets.dummy_datasets as dummy_datasets
......@@ -94,6 +95,7 @@ def main(args):
logger = logging.getLogger(__name__)
merge_cfg_from_file(args.cfg)
cfg.NUM_GPUS = 1
args.weights = cache_url(args.weights, cfg.DOWNLOAD_CACHE)
assert_and_infer_cfg()
model = infer_engine.initialize_model_from_cfg(args.weights)
dummy_coco_dataset = dummy_datasets.get_coco_dataset()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment