Commit eddb1301 authored by Ross Girshick's avatar Ross Girshick Committed by Facebook Github Bot

assert_and_infer_cfg makes cfg immutable by default

Reviewed By: ashwinb

Differential Revision: D7148434

fbshipit-source-id: 4f5e6386d484b0012ed14dbe1445ee8010d2cc26
parent 3f888a7e
......@@ -1013,13 +1013,22 @@ _RENAMED_KEYS = {
}
def assert_and_infer_cfg(cache_urls=True):
def assert_and_infer_cfg(cache_urls=True, make_immutable=True):
"""Call this function in your script after you have finished setting all cfg
values that are necessary (e.g., merging a config from a file, merging
command line config options, etc.). By default, this function will also
mark the global cfg as immutable to prevent changing the global cfg settings
during script execution (which can lead to hard to debug errors or code
that's harder to understand than is necessary).
"""
if __C.MODEL.RPN_ONLY or __C.MODEL.FASTER_RCNN:
__C.RPN.RPN_ON = True
if __C.RPN.RPN_ON or __C.RETINANET.RETINANET_ON:
__C.TEST.PRECOMPUTED_PROPOSALS = False
if cache_urls:
cache_cfg_urls()
if make_immutable:
cfg.immutable(True)
def cache_cfg_urls():
......@@ -1029,10 +1038,10 @@ def cache_cfg_urls():
__C.TRAIN.WEIGHTS = cache_url(__C.TRAIN.WEIGHTS, __C.DOWNLOAD_CACHE)
__C.TEST.WEIGHTS = cache_url(__C.TEST.WEIGHTS, __C.DOWNLOAD_CACHE)
__C.TRAIN.PROPOSAL_FILES = tuple(
[cache_url(f, __C.DOWNLOAD_CACHE) for f in __C.TRAIN.PROPOSAL_FILES]
cache_url(f, __C.DOWNLOAD_CACHE) for f in __C.TRAIN.PROPOSAL_FILES
)
__C.TEST.PROPOSAL_FILES = tuple(
[cache_url(f, __C.DOWNLOAD_CACHE) for f in __C.TEST.PROPOSAL_FILES]
cache_url(f, __C.DOWNLOAD_CACHE) for f in __C.TEST.PROPOSAL_FILES
)
......
......@@ -23,6 +23,12 @@ from __future__ import unicode_literals
class AttrDict(dict):
IMMUTABLE = '__immutable__'
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__[AttrDict.IMMUTABLE] = False
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
......@@ -32,7 +38,29 @@ class AttrDict(dict):
raise AttributeError(name)
def __setattr__(self, name, value):
if name in self.__dict__:
self.__dict__[name] = value
if not self.__dict__[AttrDict.IMMUTABLE]:
if name in self.__dict__:
self.__dict__[name] = value
else:
self[name] = value
else:
self[name] = value
raise AttributeError(
'Attempted to set "{}" to "{}", but AttrDict is immutable'.
format(name, value)
)
def immutable(self, is_immutable):
"""Set immutability to is_immutable and recursively apply the setting
to all nested AttrDicts.
"""
self.__dict__[AttrDict.IMMUTABLE] = is_immutable
# Recursively set immutable state
for v in self.__dict__.values():
if isinstance(v, AttrDict):
v.immutable(is_immutable)
for v in self.values():
if isinstance(v, AttrDict):
v.immutable(is_immutable)
def is_immutable(self):
return self.__dict__[AttrDict.IMMUTABLE]
......@@ -253,7 +253,12 @@ def configure_bbox_reg_weights(model, saved_cfg):
'MODEL.BBOX_REG_WEIGHTS was added. Forcing '
'MODEL.BBOX_REG_WEIGHTS = (1., 1., 1., 1.) to ensure '
'correct **inference** behavior.')
# Generally we don't allow modifying the config, but this is a one-off
# hack to support some very old models
is_immutable = cfg.is_immutable()
cfg.immutable(False)
cfg.MODEL.BBOX_REG_WEIGHTS = (1., 1., 1., 1.)
cfg.immutable(is_immutable)
logger.info('New config:')
logger.info(pprint.pformat(cfg))
assert not model.train, (
......
......@@ -63,7 +63,7 @@ def parse_args():
parser.add_argument(
'--num-batches', dest='num_batches',
help='Number of minibatches to run',
default=500, type=int)
default=200, type=int)
parser.add_argument(
'--sleep', dest='sleep_time',
help='Seconds sleep to emulate a network running',
......@@ -150,7 +150,7 @@ def main(opts):
# To inspect:
# blobs = workspace.FetchBlobs(all_blobs)
# from IPython import embed; embed()
logger.info('Shutting down data loader (EnqueueBlob errors are ok)...')
logger.info('Shutting down data loader...')
roi_data_loader.shutdown()
......
......@@ -29,6 +29,41 @@ import core.config
import utils.logging
class TestAttrDict(unittest.TestCase):
def test_immutability(self):
# Top level immutable
a = AttrDict()
a.foo = 0
a.immutable(True)
with self.assertRaises(AttributeError):
a.foo = 1
a.bar = 1
assert a.is_immutable()
assert a.foo == 0
a.immutable(False)
assert not a.is_immutable()
a.foo = 1
assert a.foo == 1
# Recursively immutable
a.level1 = AttrDict()
a.level1.foo = 0
a.level1.level2 = AttrDict()
a.level1.level2.foo = 0
a.immutable(True)
assert a.is_immutable()
with self.assertRaises(AttributeError):
a.level1.level2.foo = 1
a.level1.bar = 1
assert a.level1.level2.foo == 0
# Serialize immutability state
a.immutable(True)
a2 = yaml.load(yaml.dump(a))
assert a.is_immutable()
assert a2.is_immutable()
class TestCfg(unittest.TestCase):
def test_copy_cfg(self):
cfg2 = copy.deepcopy(cfg)
......
......@@ -106,7 +106,7 @@ if __name__ == '__main__':
logger.setLevel(logging.DEBUG)
logging.getLogger('roi_data.loader').setLevel(logging.INFO)
np.random.seed(cfg.RNG_SEED)
assert_and_infer_cfg()
cfg.TRAIN.ASPECT_GROUPING = False
cfg.NUM_GPUS = 2
assert_and_infer_cfg()
unittest.main()
......@@ -97,6 +97,7 @@ def parse_args():
def get_rpn_box_proposals(im, args):
cfg.immutable(False)
merge_cfg_from_file(args.rpn_cfg)
cfg.NUM_GPUS = 1
cfg.MODEL.RPN_ONLY = True
......@@ -125,6 +126,7 @@ def main(args):
for i in range(0, len(args.models_to_run), 2):
pkl = args.models_to_run[i]
yml = args.models_to_run[i + 1]
cfg.immutable(False)
merge_cfg_from_cfg(cfg_orig)
merge_cfg_from_file(yml)
if len(pkl) > 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment