Commit 83849b4e authored by wat3rBro's avatar wat3rBro Committed by Francisco Massa

Add registry for model builder functions (#153)

* adding registry to hook custom building blocks

* adding customizable rpn head

* support customizable c2 weight loading
parent 1276d20b
...@@ -134,6 +134,8 @@ _C.MODEL.RPN.MIN_SIZE = 0 ...@@ -134,6 +134,8 @@ _C.MODEL.RPN.MIN_SIZE = 0
# all FPN levels # all FPN levels
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000 _C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000 _C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000
# Custom rpn head, empty to use default conv or separable conv
_C.MODEL.RPN.RPN_HEAD = "SingleConvRPNHead"
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
......
...@@ -3,16 +3,21 @@ from collections import OrderedDict ...@@ -3,16 +3,21 @@ from collections import OrderedDict
from torch import nn from torch import nn
from maskrcnn_benchmark.modeling import registry
from . import fpn as fpn_module from . import fpn as fpn_module
from . import resnet from . import resnet
@registry.BACKBONES.register("R-50-C4")
def build_resnet_backbone(cfg): def build_resnet_backbone(cfg):
body = resnet.ResNet(cfg) body = resnet.ResNet(cfg)
model = nn.Sequential(OrderedDict([("body", body)])) model = nn.Sequential(OrderedDict([("body", body)]))
return model return model
@registry.BACKBONES.register("R-50-FPN")
@registry.BACKBONES.register("R-101-FPN")
def build_resnet_fpn_backbone(cfg): def build_resnet_fpn_backbone(cfg):
body = resnet.ResNet(cfg) body = resnet.ResNet(cfg)
in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
...@@ -31,14 +36,9 @@ def build_resnet_fpn_backbone(cfg): ...@@ -31,14 +36,9 @@ def build_resnet_fpn_backbone(cfg):
return model return model
_BACKBONES = {"resnet": build_resnet_backbone, "resnet-fpn": build_resnet_fpn_backbone}
def build_backbone(cfg): def build_backbone(cfg):
assert cfg.MODEL.BACKBONE.CONV_BODY.startswith( assert cfg.MODEL.BACKBONE.CONV_BODY in registry.BACKBONES, \
"R-" "cfg.MODEL.BACKBONE.CONV_BODY: {} are not registered in registry".format(
), "Only ResNet and ResNeXt models are currently implemented" cfg.MODEL.BACKBONE.CONV_BODY
# Models using FPN end with "-FPN" )
if cfg.MODEL.BACKBONE.CONV_BODY.endswith("-FPN"): return registry.BACKBONES[cfg.MODEL.BACKBONE.CONV_BODY](cfg)
return build_resnet_fpn_backbone(cfg)
return build_resnet_backbone(cfg)
...@@ -18,6 +18,7 @@ from torch import nn ...@@ -18,6 +18,7 @@ from torch import nn
from maskrcnn_benchmark.layers import FrozenBatchNorm2d from maskrcnn_benchmark.layers import FrozenBatchNorm2d
from maskrcnn_benchmark.layers import Conv2d from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.utils.registry import Registry
# ResNet stage specification # ResNet stage specification
...@@ -290,30 +291,15 @@ class StemWithFixedBatchNorm(nn.Module): ...@@ -290,30 +291,15 @@ class StemWithFixedBatchNorm(nn.Module):
return x return x
_TRANSFORMATION_MODULES = {"BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm} _TRANSFORMATION_MODULES = Registry({
"BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm
})
_STEM_MODULES = {"StemWithFixedBatchNorm": StemWithFixedBatchNorm} _STEM_MODULES = Registry({"StemWithFixedBatchNorm": StemWithFixedBatchNorm})
_STAGE_SPECS = { _STAGE_SPECS = Registry({
"R-50-C4": ResNet50StagesTo4, "R-50-C4": ResNet50StagesTo4,
"R-50-C5": ResNet50StagesTo5, "R-50-C5": ResNet50StagesTo5,
"R-50-FPN": ResNet50FPNStagesTo5, "R-50-FPN": ResNet50FPNStagesTo5,
"R-101-FPN": ResNet101FPNStagesTo5, "R-101-FPN": ResNet101FPNStagesTo5,
} })
def register_transformation_module(module_name, module):
_register_generic(_TRANSFORMATION_MODULES, module_name, module)
def register_stem_module(module_name, module):
_register_generic(_STEM_MODULES, module_name, module)
def register_stage_spec(stage_spec_name, stage_spec):
_register_generic(_STAGE_SPECS, stage_spec_name, stage_spec)
def _register_generic(module_dict, module_name, module):
assert module_name not in module_dict
module_dict[module_name] = module
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from maskrcnn_benchmark.utils.registry import Registry
BACKBONES = Registry()
ROI_BOX_FEATURE_EXTRACTORS = Registry()
RPN_HEADS = Registry()
...@@ -2,10 +2,12 @@ ...@@ -2,10 +2,12 @@
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.backbone import resnet from maskrcnn_benchmark.modeling.backbone import resnet
from maskrcnn_benchmark.modeling.poolers import Pooler from maskrcnn_benchmark.modeling.poolers import Pooler
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("ResNet50Conv5ROIFeatureExtractor")
class ResNet50Conv5ROIFeatureExtractor(nn.Module): class ResNet50Conv5ROIFeatureExtractor(nn.Module):
def __init__(self, config): def __init__(self, config):
super(ResNet50Conv5ROIFeatureExtractor, self).__init__() super(ResNet50Conv5ROIFeatureExtractor, self).__init__()
...@@ -39,6 +41,7 @@ class ResNet50Conv5ROIFeatureExtractor(nn.Module): ...@@ -39,6 +41,7 @@ class ResNet50Conv5ROIFeatureExtractor(nn.Module):
return x return x
@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPN2MLPFeatureExtractor")
class FPN2MLPFeatureExtractor(nn.Module): class FPN2MLPFeatureExtractor(nn.Module):
""" """
Heads for FPN for classification Heads for FPN for classification
...@@ -77,12 +80,8 @@ class FPN2MLPFeatureExtractor(nn.Module): ...@@ -77,12 +80,8 @@ class FPN2MLPFeatureExtractor(nn.Module):
return x return x
_ROI_BOX_FEATURE_EXTRACTORS = {
"ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor,
"FPN2MLPFeatureExtractor": FPN2MLPFeatureExtractor,
}
def make_roi_box_feature_extractor(cfg): def make_roi_box_feature_extractor(cfg):
func = _ROI_BOX_FEATURE_EXTRACTORS[cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR] func = registry.ROI_BOX_FEATURE_EXTRACTORS[
cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR
]
return func(cfg) return func(cfg)
...@@ -3,20 +3,23 @@ import torch ...@@ -3,20 +3,23 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.box_coder import BoxCoder from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from .loss import make_rpn_loss_evaluator from .loss import make_rpn_loss_evaluator
from .anchor_generator import make_anchor_generator from .anchor_generator import make_anchor_generator
from .inference import make_rpn_postprocessor from .inference import make_rpn_postprocessor
@registry.RPN_HEADS.register("SingleConvRPNHead")
class RPNHead(nn.Module): class RPNHead(nn.Module):
""" """
Adds a simple RPN Head with classification and regression heads Adds a simple RPN Head with classification and regression heads
""" """
def __init__(self, in_channels, num_anchors): def __init__(self, cfg, in_channels, num_anchors):
""" """
Arguments: Arguments:
cfg : config
in_channels (int): number of channels of the input feature in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted num_anchors (int): number of anchors to be predicted
""" """
...@@ -57,7 +60,10 @@ class RPNModule(torch.nn.Module): ...@@ -57,7 +60,10 @@ class RPNModule(torch.nn.Module):
anchor_generator = make_anchor_generator(cfg) anchor_generator = make_anchor_generator(cfg)
in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
head = RPNHead(in_channels, anchor_generator.num_anchors_per_location()[0]) rpn_head = registry.RPN_HEADS[cfg.MODEL.RPN.RPN_HEAD]
head = rpn_head(
cfg, in_channels, anchor_generator.num_anchors_per_location()[0]
)
rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
......
...@@ -6,6 +6,7 @@ from collections import OrderedDict ...@@ -6,6 +6,7 @@ from collections import OrderedDict
import torch import torch
from maskrcnn_benchmark.utils.model_serialization import load_state_dict from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from maskrcnn_benchmark.utils.registry import Registry
def _rename_basic_resnet_weights(layer_keys): def _rename_basic_resnet_weights(layer_keys):
...@@ -135,11 +136,20 @@ _C2_STAGE_NAMES = { ...@@ -135,11 +136,20 @@ _C2_STAGE_NAMES = {
"R-101": ["1.2", "2.3", "3.22", "4.2"], "R-101": ["1.2", "2.3", "3.22", "4.2"],
} }
# Maps a CONV_BODY string (e.g. "R-50-FPN") to the function that knows how to
# translate that architecture's Caffe2 pickled weights into this repo's format.
C2_FORMAT_LOADER = Registry()


@C2_FORMAT_LOADER.register("R-50-C4")
@C2_FORMAT_LOADER.register("R-50-FPN")
@C2_FORMAT_LOADER.register("R-101-FPN")
def load_resnet_c2_format(cfg, f):
    """Load Caffe2-pickled ResNet weights from file *f* and rename them to
    match this project's module naming; returns dict(model=state_dict)."""
    state_dict = _load_c2_pickled_weights(f)
    conv_body = cfg.MODEL.BACKBONE.CONV_BODY
    # Strip the head suffix: "R-50-C4"/"R-50-FPN" -> "R-50", "R-101-FPN" -> "R-101",
    # which is the key format used by _C2_STAGE_NAMES.
    arch = conv_body.replace("-C4", "").replace("-FPN", "")
    stages = _C2_STAGE_NAMES[arch]
    state_dict = _rename_weights_for_resnet(state_dict, stages)
    return dict(model=state_dict)
def load_c2_format(cfg, f):
    """Load Caffe2-format weights from file *f*, dispatching to the loader
    registered for cfg.MODEL.BACKBONE.CONV_BODY in C2_FORMAT_LOADER.

    Raises KeyError if no loader is registered for the configured conv body.
    """
    return C2_FORMAT_LOADER[cfg.MODEL.BACKBONE.CONV_BODY](cfg, f)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
def _register_generic(module_dict, module_name, module):
assert module_name not in module_dict
module_dict[module_name] = module
class Registry(dict):
    """A helper class for managing module registration.

    It extends a plain dictionary and provides a ``register`` function
    usable in two ways.

    Creating a registry:
        some_registry = Registry({"default": default_module})

    There are two ways of registering new modules:
    1) a normal function call:
        def foo():
            ...
        some_registry.register("foo_module", foo)
    2) as a decorator when declaring the module (decorators may be stacked
       to register the same object under several names):
        @some_registry.register("foo_module")
        @some_registry.register("foo_module_nickname")
        def foo():
            ...

    Access is plain dictionary lookup, e.g.:
        f = some_registry["foo_module"]
    """
    # NOTE: the redundant __init__ that only forwarded to dict.__init__ was
    # removed; dict's constructor already accepts the same arguments.

    def register(self, module_name, module=None):
        """Register *module* under *module_name*.

        When *module* is omitted, returns a decorator that registers the
        decorated object under *module_name* and returns it unchanged.
        Raises AssertionError if *module_name* is already registered.
        """
        # used as a plain function call
        if module is not None:
            self._do_register(module_name, module)
            return

        # used as a decorator
        def register_fn(fn):
            self._do_register(module_name, fn)
            return fn

        return register_fn

    def _do_register(self, module_name, module):
        # Duplicate names indicate a registration bug; fail loudly with the
        # offending key instead of silently overwriting.
        assert module_name not in self, (
            "module {!r} is already registered".format(module_name)
        )
        self[module_name] = module
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment