Commit ffb87f4c authored by Ilija Radosavovic's avatar Ilija Radosavovic Committed by Facebook Github Bot

Make yaml load/dump functions specific to the environment

Reviewed By: ashwinb

Differential Revision: D10496498

fbshipit-source-id: eb12fe573ec3270172e27c2fdb39a70fc92d8d99
parent 91b894d8
......@@ -51,7 +51,6 @@ import numpy as np
import os
import os.path as osp
import six
import yaml
from detectron.utils.collections import AttrDict
from detectron.utils.io import cache_url
......@@ -63,7 +62,6 @@ __C = AttrDict()
# from detectron.core.config import cfg
cfg = __C
# Random note: avoid using '.ON' as a config key since yaml converts it to True;
# prefer 'ENABLED' instead
......@@ -1125,7 +1123,9 @@ def load_cfg(cfg_to_load):
# yaml object encoding: !!python/object/new:<module>.<object>
old_module, new_module = 'new:' + old_module, 'new:' + new_module
cfg_to_load = cfg_to_load.replace(old_module, new_module)
return yaml.load(cfg_to_load)
# Import inline due to a circular dependency between env.py and config.py
import detectron.utils.env as envu
return envu.yaml_load(cfg_to_load)
def merge_cfg_from_file(cfg_filename):
......
......@@ -33,7 +33,6 @@ import datetime
import logging
import numpy as np
import os
import yaml
from caffe2.python import core
from caffe2.python import workspace
......@@ -111,7 +110,7 @@ def multi_gpu_generate_rpn_on_dataset(
scores += rpn_data['scores']
ids += rpn_data['ids']
rpn_file = os.path.join(output_dir, 'rpn_proposals.pkl')
cfg_yaml = yaml.dump(cfg)
cfg_yaml = envu.yaml_dump(cfg)
save_object(
dict(boxes=boxes, scores=scores, ids=ids, cfg=cfg_yaml), rpn_file
)
......@@ -155,7 +154,7 @@ def generate_rpn_on_range(
gpu_id=gpu_id,
)
cfg_yaml = yaml.dump(cfg)
cfg_yaml = envu.yaml_dump(cfg)
if ind_range is not None:
rpn_name = 'rpn_proposals_range_%s_%s.pkl' % tuple(ind_range)
else:
......
......@@ -26,7 +26,6 @@ import datetime
import logging
import numpy as np
import os
import yaml
from caffe2.python import workspace
......@@ -201,7 +200,7 @@ def multi_gpu_test_net_on_dataset(
all_segms[cls_idx] += all_segms_batch[cls_idx]
all_keyps[cls_idx] += all_keyps_batch[cls_idx]
det_file = os.path.join(output_dir, 'detections.pkl')
cfg_yaml = yaml.dump(cfg)
cfg_yaml = envu.yaml_dump(cfg)
save_object(
dict(
all_boxes=all_boxes,
......@@ -303,7 +302,7 @@ def test_net(
show_class=True
)
cfg_yaml = yaml.dump(cfg)
cfg_yaml = envu.yaml_dump(cfg)
if ind_range is not None:
det_name = 'detection_range_%s_%s.pkl' % tuple(ind_range)
else:
......
......@@ -21,11 +21,11 @@ from __future__ import unicode_literals
import copy
import tempfile
import unittest
import yaml
from detectron.core.config import cfg
from detectron.utils.collections import AttrDict
import detectron.core.config as core_config
import detectron.utils.env as envu
import detectron.utils.logging as logging_utils
......@@ -59,7 +59,7 @@ class TestAttrDict(unittest.TestCase):
# Serialize immutability state
a.immutable(True)
a2 = core_config.load_cfg(yaml.dump(a))
a2 = core_config.load_cfg(envu.yaml_dump(a))
assert a.is_immutable()
assert a2.is_immutable()
......@@ -81,7 +81,7 @@ class TestCfg(unittest.TestCase):
# Test: merge from yaml
s = 'dummy1'
cfg2 = core_config.load_cfg(yaml.dump(cfg))
cfg2 = core_config.load_cfg(envu.yaml_dump(cfg))
cfg2.MODEL.TYPE = s
core_config.merge_cfg_from_cfg(cfg2)
assert cfg.MODEL.TYPE == s
......@@ -119,7 +119,7 @@ class TestCfg(unittest.TestCase):
def test_merge_cfg_from_file(self):
with tempfile.NamedTemporaryFile() as f:
yaml.dump(cfg, f)
envu.yaml_dump(cfg, f)
s = cfg.MODEL.TYPE
cfg.MODEL.TYPE = 'dummy'
assert cfg.MODEL.TYPE != s
......@@ -161,7 +161,7 @@ class TestCfg(unittest.TestCase):
with tempfile.NamedTemporaryFile() as f:
cfg2 = copy.deepcopy(cfg)
cfg2.MODEL.DILATION = 2
yaml.dump(cfg2, f)
envu.yaml_dump(cfg2, f)
with self.assertRaises(AttributeError):
_ = cfg.MODEL.DILATION # noqa
core_config.merge_cfg_from_file(f.name)
......@@ -187,7 +187,7 @@ class TestCfg(unittest.TestCase):
cfg2.EXAMPLE = AttrDict()
cfg2.EXAMPLE.RENAMED = AttrDict()
cfg2.EXAMPLE.RENAMED.KEY = 'foobar'
yaml.dump(cfg2, f)
envu.yaml_dump(cfg2, f)
with self.assertRaises(AttributeError):
_ = cfg.EXAMPLE.RENAMED.KEY # noqa
with self.assertRaises(KeyError):
......
......@@ -22,6 +22,7 @@ from __future__ import unicode_literals
import os
import sys
import yaml
# Default value of the CMake install prefix
_CMAKE_INSTALL_PREFIX = '/usr/local'
......@@ -83,3 +84,8 @@ def get_custom_ops_lib():
assert os.path.exists(custom_ops_lib), \
'Custom ops lib not found at \'{}\''.format(custom_ops_lib)
return custom_ops_lib
# YAML load/dump function aliases
yaml_load = yaml.load
yaml_dump = yaml.dump
......@@ -25,7 +25,6 @@ import logging
import numpy as np
import os
import pprint
import yaml
from caffe2.python import core
from caffe2.python import workspace
......@@ -35,6 +34,7 @@ from detectron.core.config import load_cfg
from detectron.utils.io import load_object
from detectron.utils.io import save_object
import detectron.utils.c2 as c2_utils
import detectron.utils.env as envu
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
......@@ -165,7 +165,7 @@ def save_model_to_weights_file(weights_file, model):
' {:s} -> {:s} (preserved)'.format(
scoped_name, unscoped_name))
blobs[unscoped_name] = workspace.FetchBlob(scoped_name)
cfg_yaml = yaml.dump(cfg)
cfg_yaml = envu.yaml_dump(cfg)
save_object(dict(blobs=blobs, cfg=cfg_yaml), weights_file)
......
......@@ -24,13 +24,13 @@ from __future__ import print_function
from __future__ import unicode_literals
import os
import yaml
import numpy as np
import subprocess
from six.moves import shlex_quote
from detectron.core.config import cfg
from detectron.utils.io import load_object
import detectron.utils.env as envu
import logging
logger = logging.getLogger(__name__)
......@@ -47,7 +47,7 @@ def process_in_parallel(
# subprocesses
cfg_file = os.path.join(output_dir, '{}_range_config.yaml'.format(tag))
with open(cfg_file, 'w') as f:
yaml.dump(cfg, stream=f)
envu.yaml_dump(cfg, stream=f)
subprocess_env = os.environ.copy()
processes = []
subinds = np.array_split(range(total_range_size), cfg.NUM_GPUS)
......
......@@ -32,7 +32,6 @@ import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2)
import logging
import os
import sys
import yaml
from caffe2.python import workspace
......@@ -47,6 +46,7 @@ import detectron.core.rpn_generator as rpn_engine
import detectron.core.test_engine as model_engine
import detectron.datasets.dummy_datasets as dummy_datasets
import detectron.utils.c2 as c2_utils
import detectron.utils.env as envu
import detectron.utils.vis as vis_utils
c2_utils.import_detectron_ops()
......@@ -119,7 +119,7 @@ def get_rpn_box_proposals(im, args):
def main(args):
logger = logging.getLogger(__name__)
dummy_coco_dataset = dummy_datasets.get_coco_dataset()
cfg_orig = load_cfg(yaml.dump(cfg))
cfg_orig = load_cfg(envu.yaml_dump(cfg))
im = cv2.imread(args.im_file)
if args.rpn_pkl is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment