Commit 1589ce09 authored by Tong Xiao's avatar Tong Xiao Committed by Francisco Massa

Registry for RoI Box Predictors (#402)

* Registry for RoI Box Predictors

- Add a registry ROI_BOX_PREDICTOR
- Use the registry in roi_box_predictors.py, replacing the local factory
- Minor changes in structures/bounding_box.py: when copying a box with
fields, check if the field exists
- Minor changes in logger.py: make filename a optional argument with
default value of "log.txt"

* Add Argument skip_missing=False
parent d3fed42a
......@@ -4,4 +4,5 @@ from maskrcnn_benchmark.utils.registry import Registry
BACKBONES = Registry()
ROI_BOX_FEATURE_EXTRACTORS = Registry()
ROI_BOX_PREDICTOR = Registry()
RPN_HEADS = Registry()
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from maskrcnn_benchmark.modeling import registry
from torch import nn
@registry.ROI_BOX_PREDICTOR.register("FastRCNNPredictor")
class FastRCNNPredictor(nn.Module):
def __init__(self, config, pretrained=None):
super(FastRCNNPredictor, self).__init__()
......@@ -31,6 +33,7 @@ class FastRCNNPredictor(nn.Module):
return cls_logit, bbox_pred
@registry.ROI_BOX_PREDICTOR.register("FPNPredictor")
class FPNPredictor(nn.Module):
def __init__(self, cfg):
super(FPNPredictor, self).__init__()
......@@ -53,12 +56,6 @@ class FPNPredictor(nn.Module):
return scores, bbox_deltas
_ROI_BOX_PREDICTOR = {
"FastRCNNPredictor": FastRCNNPredictor,
"FPNPredictor": FPNPredictor,
}
def make_roi_box_predictor(cfg):
func = _ROI_BOX_PREDICTOR[cfg.MODEL.ROI_BOX_HEAD.PREDICTOR]
func = registry.ROI_BOX_PREDICTOR[cfg.MODEL.ROI_BOX_HEAD.PREDICTOR]
return func(cfg)
......@@ -235,12 +235,15 @@ class BoxList(object):
return area
def copy_with_fields(self, fields):
def copy_with_fields(self, fields, skip_missing=False):
bbox = BoxList(self.bbox, self.size, self.mode)
if not isinstance(fields, (list, tuple)):
fields = [fields]
for field in fields:
if self.has_field(field):
bbox.add_field(field, self.get_field(field))
elif not skip_missing:
raise KeyError("Field '{}' not found in {}".format(field, self))
return bbox
def __repr__(self):
......
......@@ -4,7 +4,7 @@ import os
import sys
def setup_logger(name, save_dir, distributed_rank):
def setup_logger(name, save_dir, distributed_rank, filename="log.txt"):
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
# don't log results for the non-master process
......@@ -17,7 +17,7 @@ def setup_logger(name, save_dir, distributed_rank):
logger.addHandler(ch)
if save_dir:
fh = logging.FileHandler(os.path.join(save_dir, "log.txt"))
fh = logging.FileHandler(os.path.join(save_dir, filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment