Commit eddb1301 authored by Ross Girshick's avatar Ross Girshick Committed by Facebook Github Bot

assert_and_infer_cfg makes cfg immutable by default

Reviewed By: ashwinb

Differential Revision: D7148434

fbshipit-source-id: 4f5e6386d484b0012ed14dbe1445ee8010d2cc26
parent 3f888a7e
...@@ -1013,13 +1013,22 @@ _RENAMED_KEYS = { ...@@ -1013,13 +1013,22 @@ _RENAMED_KEYS = {
} }
def assert_and_infer_cfg(cache_urls=True): def assert_and_infer_cfg(cache_urls=True, make_immutable=True):
"""Call this function in your script after you have finished setting all cfg
values that are necessary (e.g., merging a config from a file, merging
command line config options, etc.). By default, this function will also
mark the global cfg as immutable to prevent changing the global cfg settings
during script execution (which can lead to hard to debug errors or code
that's harder to understand than is necessary).
"""
if __C.MODEL.RPN_ONLY or __C.MODEL.FASTER_RCNN: if __C.MODEL.RPN_ONLY or __C.MODEL.FASTER_RCNN:
__C.RPN.RPN_ON = True __C.RPN.RPN_ON = True
if __C.RPN.RPN_ON or __C.RETINANET.RETINANET_ON: if __C.RPN.RPN_ON or __C.RETINANET.RETINANET_ON:
__C.TEST.PRECOMPUTED_PROPOSALS = False __C.TEST.PRECOMPUTED_PROPOSALS = False
if cache_urls: if cache_urls:
cache_cfg_urls() cache_cfg_urls()
if make_immutable:
cfg.immutable(True)
def cache_cfg_urls(): def cache_cfg_urls():
...@@ -1029,10 +1038,10 @@ def cache_cfg_urls(): ...@@ -1029,10 +1038,10 @@ def cache_cfg_urls():
__C.TRAIN.WEIGHTS = cache_url(__C.TRAIN.WEIGHTS, __C.DOWNLOAD_CACHE) __C.TRAIN.WEIGHTS = cache_url(__C.TRAIN.WEIGHTS, __C.DOWNLOAD_CACHE)
__C.TEST.WEIGHTS = cache_url(__C.TEST.WEIGHTS, __C.DOWNLOAD_CACHE) __C.TEST.WEIGHTS = cache_url(__C.TEST.WEIGHTS, __C.DOWNLOAD_CACHE)
__C.TRAIN.PROPOSAL_FILES = tuple( __C.TRAIN.PROPOSAL_FILES = tuple(
[cache_url(f, __C.DOWNLOAD_CACHE) for f in __C.TRAIN.PROPOSAL_FILES] cache_url(f, __C.DOWNLOAD_CACHE) for f in __C.TRAIN.PROPOSAL_FILES
) )
__C.TEST.PROPOSAL_FILES = tuple( __C.TEST.PROPOSAL_FILES = tuple(
[cache_url(f, __C.DOWNLOAD_CACHE) for f in __C.TEST.PROPOSAL_FILES] cache_url(f, __C.DOWNLOAD_CACHE) for f in __C.TEST.PROPOSAL_FILES
) )
......
...@@ -23,6 +23,12 @@ from __future__ import unicode_literals ...@@ -23,6 +23,12 @@ from __future__ import unicode_literals
class AttrDict(dict): class AttrDict(dict):
IMMUTABLE = '__immutable__'
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__[AttrDict.IMMUTABLE] = False
def __getattr__(self, name): def __getattr__(self, name):
if name in self.__dict__: if name in self.__dict__:
return self.__dict__[name] return self.__dict__[name]
...@@ -32,7 +38,29 @@ class AttrDict(dict): ...@@ -32,7 +38,29 @@ class AttrDict(dict):
raise AttributeError(name) raise AttributeError(name)
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name in self.__dict__: if not self.__dict__[AttrDict.IMMUTABLE]:
self.__dict__[name] = value if name in self.__dict__:
self.__dict__[name] = value
else:
self[name] = value
else: else:
self[name] = value raise AttributeError(
'Attempted to set "{}" to "{}", but AttrDict is immutable'.
format(name, value)
)
def immutable(self, is_immutable):
"""Set immutability to is_immutable and recursively apply the setting
to all nested AttrDicts.
"""
self.__dict__[AttrDict.IMMUTABLE] = is_immutable
# Recursively set immutable state
for v in self.__dict__.values():
if isinstance(v, AttrDict):
v.immutable(is_immutable)
for v in self.values():
if isinstance(v, AttrDict):
v.immutable(is_immutable)
def is_immutable(self):
return self.__dict__[AttrDict.IMMUTABLE]
...@@ -253,7 +253,12 @@ def configure_bbox_reg_weights(model, saved_cfg): ...@@ -253,7 +253,12 @@ def configure_bbox_reg_weights(model, saved_cfg):
'MODEL.BBOX_REG_WEIGHTS was added. Forcing ' 'MODEL.BBOX_REG_WEIGHTS was added. Forcing '
'MODEL.BBOX_REG_WEIGHTS = (1., 1., 1., 1.) to ensure ' 'MODEL.BBOX_REG_WEIGHTS = (1., 1., 1., 1.) to ensure '
'correct **inference** behavior.') 'correct **inference** behavior.')
# Generally we don't allow modifying the config, but this is a one-off
# hack to support some very old models
is_immutable = cfg.is_immutable()
cfg.immutable(False)
cfg.MODEL.BBOX_REG_WEIGHTS = (1., 1., 1., 1.) cfg.MODEL.BBOX_REG_WEIGHTS = (1., 1., 1., 1.)
cfg.immutable(is_immutable)
logger.info('New config:') logger.info('New config:')
logger.info(pprint.pformat(cfg)) logger.info(pprint.pformat(cfg))
assert not model.train, ( assert not model.train, (
......
...@@ -63,7 +63,7 @@ def parse_args(): ...@@ -63,7 +63,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--num-batches', dest='num_batches', '--num-batches', dest='num_batches',
help='Number of minibatches to run', help='Number of minibatches to run',
default=500, type=int) default=200, type=int)
parser.add_argument( parser.add_argument(
'--sleep', dest='sleep_time', '--sleep', dest='sleep_time',
help='Seconds sleep to emulate a network running', help='Seconds sleep to emulate a network running',
...@@ -150,7 +150,7 @@ def main(opts): ...@@ -150,7 +150,7 @@ def main(opts):
# To inspect: # To inspect:
# blobs = workspace.FetchBlobs(all_blobs) # blobs = workspace.FetchBlobs(all_blobs)
# from IPython import embed; embed() # from IPython import embed; embed()
logger.info('Shutting down data loader (EnqueueBlob errors are ok)...') logger.info('Shutting down data loader...')
roi_data_loader.shutdown() roi_data_loader.shutdown()
......
...@@ -29,6 +29,41 @@ import core.config ...@@ -29,6 +29,41 @@ import core.config
import utils.logging import utils.logging
class TestAttrDict(unittest.TestCase):
def test_immutability(self):
# Top level immutable
a = AttrDict()
a.foo = 0
a.immutable(True)
with self.assertRaises(AttributeError):
a.foo = 1
a.bar = 1
assert a.is_immutable()
assert a.foo == 0
a.immutable(False)
assert not a.is_immutable()
a.foo = 1
assert a.foo == 1
# Recursively immutable
a.level1 = AttrDict()
a.level1.foo = 0
a.level1.level2 = AttrDict()
a.level1.level2.foo = 0
a.immutable(True)
assert a.is_immutable()
with self.assertRaises(AttributeError):
a.level1.level2.foo = 1
a.level1.bar = 1
assert a.level1.level2.foo == 0
# Serialize immutability state
a.immutable(True)
a2 = yaml.load(yaml.dump(a))
assert a.is_immutable()
assert a2.is_immutable()
class TestCfg(unittest.TestCase): class TestCfg(unittest.TestCase):
def test_copy_cfg(self): def test_copy_cfg(self):
cfg2 = copy.deepcopy(cfg) cfg2 = copy.deepcopy(cfg)
......
...@@ -106,7 +106,7 @@ if __name__ == '__main__': ...@@ -106,7 +106,7 @@ if __name__ == '__main__':
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
logging.getLogger('roi_data.loader').setLevel(logging.INFO) logging.getLogger('roi_data.loader').setLevel(logging.INFO)
np.random.seed(cfg.RNG_SEED) np.random.seed(cfg.RNG_SEED)
assert_and_infer_cfg()
cfg.TRAIN.ASPECT_GROUPING = False cfg.TRAIN.ASPECT_GROUPING = False
cfg.NUM_GPUS = 2 cfg.NUM_GPUS = 2
assert_and_infer_cfg()
unittest.main() unittest.main()
...@@ -97,6 +97,7 @@ def parse_args(): ...@@ -97,6 +97,7 @@ def parse_args():
def get_rpn_box_proposals(im, args): def get_rpn_box_proposals(im, args):
cfg.immutable(False)
merge_cfg_from_file(args.rpn_cfg) merge_cfg_from_file(args.rpn_cfg)
cfg.NUM_GPUS = 1 cfg.NUM_GPUS = 1
cfg.MODEL.RPN_ONLY = True cfg.MODEL.RPN_ONLY = True
...@@ -125,6 +126,7 @@ def main(args): ...@@ -125,6 +126,7 @@ def main(args):
for i in range(0, len(args.models_to_run), 2): for i in range(0, len(args.models_to_run), 2):
pkl = args.models_to_run[i] pkl = args.models_to_run[i]
yml = args.models_to_run[i + 1] yml = args.models_to_run[i + 1]
cfg.immutable(False)
merge_cfg_from_cfg(cfg_orig) merge_cfg_from_cfg(cfg_orig)
merge_cfg_from_file(yml) merge_cfg_from_file(yml)
if len(pkl) > 0: if len(pkl) > 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment