Commit ffb87f4c authored by Ilija Radosavovic's avatar Ilija Radosavovic Committed by Facebook Github Bot

Make yaml load/dump functions specific to the environment

Reviewed By: ashwinb

Differential Revision: D10496498

fbshipit-source-id: eb12fe573ec3270172e27c2fdb39a70fc92d8d99
parent 91b894d8
...@@ -51,7 +51,6 @@ import numpy as np ...@@ -51,7 +51,6 @@ import numpy as np
import os import os
import os.path as osp import os.path as osp
import six import six
import yaml
from detectron.utils.collections import AttrDict from detectron.utils.collections import AttrDict
from detectron.utils.io import cache_url from detectron.utils.io import cache_url
...@@ -63,7 +62,6 @@ __C = AttrDict() ...@@ -63,7 +62,6 @@ __C = AttrDict()
# from detectron.core.config import cfg # from detectron.core.config import cfg
cfg = __C cfg = __C
# Random note: avoid using '.ON' as a config key since yaml converts it to True; # Random note: avoid using '.ON' as a config key since yaml converts it to True;
# prefer 'ENABLED' instead # prefer 'ENABLED' instead
...@@ -1125,7 +1123,9 @@ def load_cfg(cfg_to_load): ...@@ -1125,7 +1123,9 @@ def load_cfg(cfg_to_load):
# yaml object encoding: !!python/object/new:<module>.<object> # yaml object encoding: !!python/object/new:<module>.<object>
old_module, new_module = 'new:' + old_module, 'new:' + new_module old_module, new_module = 'new:' + old_module, 'new:' + new_module
cfg_to_load = cfg_to_load.replace(old_module, new_module) cfg_to_load = cfg_to_load.replace(old_module, new_module)
return yaml.load(cfg_to_load) # Import inline due to a circular dependency between env.py and config.py
import detectron.utils.env as envu
return envu.yaml_load(cfg_to_load)
def merge_cfg_from_file(cfg_filename): def merge_cfg_from_file(cfg_filename):
......
...@@ -33,7 +33,6 @@ import datetime ...@@ -33,7 +33,6 @@ import datetime
import logging import logging
import numpy as np import numpy as np
import os import os
import yaml
from caffe2.python import core from caffe2.python import core
from caffe2.python import workspace from caffe2.python import workspace
...@@ -111,7 +110,7 @@ def multi_gpu_generate_rpn_on_dataset( ...@@ -111,7 +110,7 @@ def multi_gpu_generate_rpn_on_dataset(
scores += rpn_data['scores'] scores += rpn_data['scores']
ids += rpn_data['ids'] ids += rpn_data['ids']
rpn_file = os.path.join(output_dir, 'rpn_proposals.pkl') rpn_file = os.path.join(output_dir, 'rpn_proposals.pkl')
cfg_yaml = yaml.dump(cfg) cfg_yaml = envu.yaml_dump(cfg)
save_object( save_object(
dict(boxes=boxes, scores=scores, ids=ids, cfg=cfg_yaml), rpn_file dict(boxes=boxes, scores=scores, ids=ids, cfg=cfg_yaml), rpn_file
) )
...@@ -155,7 +154,7 @@ def generate_rpn_on_range( ...@@ -155,7 +154,7 @@ def generate_rpn_on_range(
gpu_id=gpu_id, gpu_id=gpu_id,
) )
cfg_yaml = yaml.dump(cfg) cfg_yaml = envu.yaml_dump(cfg)
if ind_range is not None: if ind_range is not None:
rpn_name = 'rpn_proposals_range_%s_%s.pkl' % tuple(ind_range) rpn_name = 'rpn_proposals_range_%s_%s.pkl' % tuple(ind_range)
else: else:
......
...@@ -26,7 +26,6 @@ import datetime ...@@ -26,7 +26,6 @@ import datetime
import logging import logging
import numpy as np import numpy as np
import os import os
import yaml
from caffe2.python import workspace from caffe2.python import workspace
...@@ -201,7 +200,7 @@ def multi_gpu_test_net_on_dataset( ...@@ -201,7 +200,7 @@ def multi_gpu_test_net_on_dataset(
all_segms[cls_idx] += all_segms_batch[cls_idx] all_segms[cls_idx] += all_segms_batch[cls_idx]
all_keyps[cls_idx] += all_keyps_batch[cls_idx] all_keyps[cls_idx] += all_keyps_batch[cls_idx]
det_file = os.path.join(output_dir, 'detections.pkl') det_file = os.path.join(output_dir, 'detections.pkl')
cfg_yaml = yaml.dump(cfg) cfg_yaml = envu.yaml_dump(cfg)
save_object( save_object(
dict( dict(
all_boxes=all_boxes, all_boxes=all_boxes,
...@@ -303,7 +302,7 @@ def test_net( ...@@ -303,7 +302,7 @@ def test_net(
show_class=True show_class=True
) )
cfg_yaml = yaml.dump(cfg) cfg_yaml = envu.yaml_dump(cfg)
if ind_range is not None: if ind_range is not None:
det_name = 'detection_range_%s_%s.pkl' % tuple(ind_range) det_name = 'detection_range_%s_%s.pkl' % tuple(ind_range)
else: else:
......
...@@ -21,11 +21,11 @@ from __future__ import unicode_literals ...@@ -21,11 +21,11 @@ from __future__ import unicode_literals
import copy import copy
import tempfile import tempfile
import unittest import unittest
import yaml
from detectron.core.config import cfg from detectron.core.config import cfg
from detectron.utils.collections import AttrDict from detectron.utils.collections import AttrDict
import detectron.core.config as core_config import detectron.core.config as core_config
import detectron.utils.env as envu
import detectron.utils.logging as logging_utils import detectron.utils.logging as logging_utils
...@@ -59,7 +59,7 @@ class TestAttrDict(unittest.TestCase): ...@@ -59,7 +59,7 @@ class TestAttrDict(unittest.TestCase):
# Serialize immutability state # Serialize immutability state
a.immutable(True) a.immutable(True)
a2 = core_config.load_cfg(yaml.dump(a)) a2 = core_config.load_cfg(envu.yaml_dump(a))
assert a.is_immutable() assert a.is_immutable()
assert a2.is_immutable() assert a2.is_immutable()
...@@ -81,7 +81,7 @@ class TestCfg(unittest.TestCase): ...@@ -81,7 +81,7 @@ class TestCfg(unittest.TestCase):
# Test: merge from yaml # Test: merge from yaml
s = 'dummy1' s = 'dummy1'
cfg2 = core_config.load_cfg(yaml.dump(cfg)) cfg2 = core_config.load_cfg(envu.yaml_dump(cfg))
cfg2.MODEL.TYPE = s cfg2.MODEL.TYPE = s
core_config.merge_cfg_from_cfg(cfg2) core_config.merge_cfg_from_cfg(cfg2)
assert cfg.MODEL.TYPE == s assert cfg.MODEL.TYPE == s
...@@ -119,7 +119,7 @@ class TestCfg(unittest.TestCase): ...@@ -119,7 +119,7 @@ class TestCfg(unittest.TestCase):
def test_merge_cfg_from_file(self): def test_merge_cfg_from_file(self):
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
yaml.dump(cfg, f) envu.yaml_dump(cfg, f)
s = cfg.MODEL.TYPE s = cfg.MODEL.TYPE
cfg.MODEL.TYPE = 'dummy' cfg.MODEL.TYPE = 'dummy'
assert cfg.MODEL.TYPE != s assert cfg.MODEL.TYPE != s
...@@ -161,7 +161,7 @@ class TestCfg(unittest.TestCase): ...@@ -161,7 +161,7 @@ class TestCfg(unittest.TestCase):
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
cfg2 = copy.deepcopy(cfg) cfg2 = copy.deepcopy(cfg)
cfg2.MODEL.DILATION = 2 cfg2.MODEL.DILATION = 2
yaml.dump(cfg2, f) envu.yaml_dump(cfg2, f)
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
_ = cfg.MODEL.DILATION # noqa _ = cfg.MODEL.DILATION # noqa
core_config.merge_cfg_from_file(f.name) core_config.merge_cfg_from_file(f.name)
...@@ -187,7 +187,7 @@ class TestCfg(unittest.TestCase): ...@@ -187,7 +187,7 @@ class TestCfg(unittest.TestCase):
cfg2.EXAMPLE = AttrDict() cfg2.EXAMPLE = AttrDict()
cfg2.EXAMPLE.RENAMED = AttrDict() cfg2.EXAMPLE.RENAMED = AttrDict()
cfg2.EXAMPLE.RENAMED.KEY = 'foobar' cfg2.EXAMPLE.RENAMED.KEY = 'foobar'
yaml.dump(cfg2, f) envu.yaml_dump(cfg2, f)
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
_ = cfg.EXAMPLE.RENAMED.KEY # noqa _ = cfg.EXAMPLE.RENAMED.KEY # noqa
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
......
...@@ -22,6 +22,7 @@ from __future__ import unicode_literals ...@@ -22,6 +22,7 @@ from __future__ import unicode_literals
import os import os
import sys import sys
import yaml
# Default value of the CMake install prefix # Default value of the CMake install prefix
_CMAKE_INSTALL_PREFIX = '/usr/local' _CMAKE_INSTALL_PREFIX = '/usr/local'
...@@ -83,3 +84,8 @@ def get_custom_ops_lib(): ...@@ -83,3 +84,8 @@ def get_custom_ops_lib():
assert os.path.exists(custom_ops_lib), \ assert os.path.exists(custom_ops_lib), \
'Custom ops lib not found at \'{}\''.format(custom_ops_lib) 'Custom ops lib not found at \'{}\''.format(custom_ops_lib)
return custom_ops_lib return custom_ops_lib
# YAML load/dump function aliases
yaml_load = yaml.load
yaml_dump = yaml.dump
...@@ -25,7 +25,6 @@ import logging ...@@ -25,7 +25,6 @@ import logging
import numpy as np import numpy as np
import os import os
import pprint import pprint
import yaml
from caffe2.python import core from caffe2.python import core
from caffe2.python import workspace from caffe2.python import workspace
...@@ -35,6 +34,7 @@ from detectron.core.config import load_cfg ...@@ -35,6 +34,7 @@ from detectron.core.config import load_cfg
from detectron.utils.io import load_object from detectron.utils.io import load_object
from detectron.utils.io import save_object from detectron.utils.io import save_object
import detectron.utils.c2 as c2_utils import detectron.utils.c2 as c2_utils
import detectron.utils.env as envu
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
...@@ -165,7 +165,7 @@ def save_model_to_weights_file(weights_file, model): ...@@ -165,7 +165,7 @@ def save_model_to_weights_file(weights_file, model):
' {:s} -> {:s} (preserved)'.format( ' {:s} -> {:s} (preserved)'.format(
scoped_name, unscoped_name)) scoped_name, unscoped_name))
blobs[unscoped_name] = workspace.FetchBlob(scoped_name) blobs[unscoped_name] = workspace.FetchBlob(scoped_name)
cfg_yaml = yaml.dump(cfg) cfg_yaml = envu.yaml_dump(cfg)
save_object(dict(blobs=blobs, cfg=cfg_yaml), weights_file) save_object(dict(blobs=blobs, cfg=cfg_yaml), weights_file)
......
...@@ -24,13 +24,13 @@ from __future__ import print_function ...@@ -24,13 +24,13 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import os import os
import yaml
import numpy as np import numpy as np
import subprocess import subprocess
from six.moves import shlex_quote from six.moves import shlex_quote
from detectron.core.config import cfg from detectron.core.config import cfg
from detectron.utils.io import load_object from detectron.utils.io import load_object
import detectron.utils.env as envu
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -47,7 +47,7 @@ def process_in_parallel( ...@@ -47,7 +47,7 @@ def process_in_parallel(
# subprocesses # subprocesses
cfg_file = os.path.join(output_dir, '{}_range_config.yaml'.format(tag)) cfg_file = os.path.join(output_dir, '{}_range_config.yaml'.format(tag))
with open(cfg_file, 'w') as f: with open(cfg_file, 'w') as f:
yaml.dump(cfg, stream=f) envu.yaml_dump(cfg, stream=f)
subprocess_env = os.environ.copy() subprocess_env = os.environ.copy()
processes = [] processes = []
subinds = np.array_split(range(total_range_size), cfg.NUM_GPUS) subinds = np.array_split(range(total_range_size), cfg.NUM_GPUS)
......
...@@ -32,7 +32,6 @@ import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2) ...@@ -32,7 +32,6 @@ import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2)
import logging import logging
import os import os
import sys import sys
import yaml
from caffe2.python import workspace from caffe2.python import workspace
...@@ -47,6 +46,7 @@ import detectron.core.rpn_generator as rpn_engine ...@@ -47,6 +46,7 @@ import detectron.core.rpn_generator as rpn_engine
import detectron.core.test_engine as model_engine import detectron.core.test_engine as model_engine
import detectron.datasets.dummy_datasets as dummy_datasets import detectron.datasets.dummy_datasets as dummy_datasets
import detectron.utils.c2 as c2_utils import detectron.utils.c2 as c2_utils
import detectron.utils.env as envu
import detectron.utils.vis as vis_utils import detectron.utils.vis as vis_utils
c2_utils.import_detectron_ops() c2_utils.import_detectron_ops()
...@@ -119,7 +119,7 @@ def get_rpn_box_proposals(im, args): ...@@ -119,7 +119,7 @@ def get_rpn_box_proposals(im, args):
def main(args): def main(args):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
dummy_coco_dataset = dummy_datasets.get_coco_dataset() dummy_coco_dataset = dummy_datasets.get_coco_dataset()
cfg_orig = load_cfg(yaml.dump(cfg)) cfg_orig = load_cfg(envu.yaml_dump(cfg))
im = cv2.imread(args.im_file) im = cv2.imread(args.im_file)
if args.rpn_pkl is not None: if args.rpn_pkl is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment