Commit 464b1af1 authored by Stzpz, committed by Francisco Massa

Fbnet benchmark (#507)

* Added a timer to benchmark model inference time in addition to total runtime.

* Updated FBNet configs and included some baselines benchmark results.

* Added a unit test for detectors.

* Add links to the models
parent fd204722
......@@ -33,7 +33,28 @@ backbone | type | lr sched | im / gpu | train mem(GB) | train time (s/iter) | to
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
R-50-FPN | Keypoint | 1x | 2 | 5.7 | 0.3771 | 9.4 | 0.10941 | 53.7 | 64.3 | 9981060
### Light-weight Model baselines
We provide pre-trained models for selected FBNet models.
* All models are trained from scratch with BN, using the training schedule specified below.
* Evaluation is performed on a single NVIDIA V100 GPU with `MODEL.RPN.POST_NMS_TOP_N_TEST` set to `200`.

The following inference times are reported:
* inference total batch=8: total inference time, including data loading, model inference, and pre/post-processing, using 8 images per batch.
* inference model batch=8: model inference time only, using 8 images per batch.
* inference model batch=1: model inference time only, using 1 image per batch.
* inference caffe2 batch=1: model inference time for the model in Caffe2 format, using 1 image per batch. The Caffe2 models fuse BN into Conv and run purely in C++/CUDA, using Caffe2 ops for RPN/detection post-processing.

The pre-trained models are available through the links in the model id column.
backbone | type | resolution | lr sched | im / gpu | train mem(GB) | train time (s/iter) | total train time (hr) | inference total batch=8 (s/im) | inference model batch=8 (s/im) | inference model batch=1 (s/im) | inference caffe2 batch=1 (s/im) | box AP | mask AP | model id
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
[R-50-C4](configs/e2e_faster_rcnn_R_50_C4_1x.yaml) (reference) | Fast | 800 | 1x | 1 | 5.8 | 0.4036 | 20.2 | 0.0875 | **0.0793** | 0.0831 | **0.0625** | 34.4 | - | f35857197
[fbnet_chamv1a](configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml) | Fast | 600 | 0.75x | 12 | 13.6 | 0.5444 | 20.5 | 0.0315 | **0.0260** | 0.0376 | **0.0188** | 33.5 | - | [f100940543](https://download.pytorch.org/models/maskrcnn/e2e_faster_rcnn_fbnet_chamv1a_600.pth)
[fbnet_default](configs/e2e_faster_rcnn_fbnet_600.yaml) | Fast | 600 | 0.5x | 16 | 11.1 | 0.4872 | 12.5 | 0.0316 | **0.0250** | 0.0297 | **0.0130** | 28.2 | - | [f101086388](https://download.pytorch.org/models/maskrcnn/e2e_faster_rcnn_fbnet_600.pth)
[R-50-C4](configs/e2e_mask_rcnn_R_50_C4_1x.yaml) (reference) | Mask | 800 | 1x | 1 | 5.8 | 0.452 | 22.6 | 0.0918 | **0.0848** | 0.0844 | - | 35.2 | 31.0 | f35858791
[fbnet_xirb16d](configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask_600.yaml) | Mask | 600 | 0.5x | 16 | 13.4 | 1.1732 | 29 | 0.0386 | **0.0319** | 0.0356 | - | 30.7 | 26.9 | [f101086394](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_fbnet_xirb16d_dsmask.pth)
[fbnet_default](configs/e2e_mask_rcnn_fbnet_600.yaml) | Mask | 600 | 0.5x | 16 | 13.0 | 0.9036 | 23.0 | 0.0327 | **0.0269** | 0.0385 | - | 29.0 | 26.1 | [f101086385](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_fbnet_600.pth)
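As a reading aid for the table above, the per-image latencies can be converted into throughput and a relative speedup. This is only an illustrative sketch: the helper names are hypothetical (not part of `maskrcnn_benchmark`), and the two numbers are copied from the `inference model batch=8 (s/im)` column.

```python
# Illustrative helpers (hypothetical names, not part of maskrcnn_benchmark).

def images_per_second(seconds_per_image):
    """Convert a latency in s/im into a throughput in im/s."""
    return 1.0 / seconds_per_image


def speedup(reference_s_per_im, candidate_s_per_im):
    """How many times faster the candidate is than the reference."""
    return reference_s_per_im / candidate_s_per_im


# Values copied from the "inference model batch=8 (s/im)" column above.
r50_c4_fast = 0.0793
fbnet_chamv1a_fast = 0.0260

chamv1a_speedup = speedup(r50_c4_fast, fbnet_chamv1a_fast)
chamv1a_throughput = images_per_second(fbnet_chamv1a_fast)

print("fbnet_chamv1a vs R-50-C4: {:.2f}x faster".format(chamv1a_speedup))
print("fbnet_chamv1a throughput: {:.1f} im/s".format(chamv1a_throughput))
```

The reference rows use a resolution of 800 while the FBNet rows use 600, so such a ratio reflects both the backbone and the input size.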
## Comparison with Detectron and mmdetection
......
......@@ -15,7 +15,7 @@ MODEL:
     PRE_NMS_TOP_N_TRAIN: 6000
     PRE_NMS_TOP_N_TEST: 6000
     POST_NMS_TOP_N_TRAIN: 2000
-    POST_NMS_TOP_N_TEST: 1000
+    POST_NMS_TOP_N_TEST: 100
     RPN_HEAD: FBNet.rpn_head
   ROI_HEADS:
     BATCH_SIZE_PER_IMAGE: 512
......
......@@ -15,7 +15,7 @@ MODEL:
     PRE_NMS_TOP_N_TRAIN: 6000
     PRE_NMS_TOP_N_TEST: 6000
     POST_NMS_TOP_N_TRAIN: 2000
-    POST_NMS_TOP_N_TEST: 1000
+    POST_NMS_TOP_N_TEST: 200
     RPN_HEAD: FBNet.rpn_head
   ROI_HEADS:
     BATCH_SIZE_PER_IMAGE: 256
......
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    CONV_BODY: FBNet
  FBNET:
    ARCH: "cham_v1a"
    BN_TYPE: "bn"
    WIDTH_DIVISOR: 8
    DW_CONV_SKIP_BN: True
    DW_CONV_SKIP_RELU: True
  RPN:
    ANCHOR_SIZES: (32, 64, 128, 256, 512)
    ANCHOR_STRIDE: (16, )
    BATCH_SIZE_PER_IMAGE: 256
    PRE_NMS_TOP_N_TRAIN: 6000
    PRE_NMS_TOP_N_TEST: 6000
    POST_NMS_TOP_N_TRAIN: 2000
    POST_NMS_TOP_N_TEST: 200
    RPN_HEAD: FBNet.rpn_head
  ROI_HEADS:
    BATCH_SIZE_PER_IMAGE: 128
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 6
    FEATURE_EXTRACTOR: FBNet.roi_head
    NUM_CLASSES: 81
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
SOLVER:
  BASE_LR: 0.045
  WARMUP_FACTOR: 0.1
  WEIGHT_DECAY: 0.0001
  STEPS: (90000, 120000)
  MAX_ITER: 135000
  IMS_PER_BATCH: 96 # for 8GPUs
# TEST:
#   IMS_PER_BATCH: 8
INPUT:
  MIN_SIZE_TRAIN: (600, )
  MAX_SIZE_TRAIN: 1000
  MIN_SIZE_TEST: 600
  MAX_SIZE_TEST: 1000
  PIXEL_MEAN: [103.53, 116.28, 123.675]
  PIXEL_STD: [57.375, 57.12, 58.395]
......@@ -8,7 +8,7 @@ MODEL:
     WIDTH_DIVISOR: 8
     DW_CONV_SKIP_BN: True
     DW_CONV_SKIP_RELU: True
-    DET_HEAD_LAST_SCALE: -1.0
+    DET_HEAD_LAST_SCALE: 0.0
   RPN:
     ANCHOR_SIZES: (16, 32, 64, 128, 256)
     ANCHOR_STRIDE: (16, )
......@@ -16,7 +16,7 @@ MODEL:
     PRE_NMS_TOP_N_TRAIN: 6000
     PRE_NMS_TOP_N_TEST: 6000
     POST_NMS_TOP_N_TRAIN: 2000
-    POST_NMS_TOP_N_TEST: 1000
+    POST_NMS_TOP_N_TEST: 100
     RPN_HEAD: FBNet.rpn_head
   ROI_HEADS:
     BATCH_SIZE_PER_IMAGE: 256
......
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    CONV_BODY: FBNet
  FBNET:
    ARCH: "default"
    BN_TYPE: "bn"
    WIDTH_DIVISOR: 8
    DW_CONV_SKIP_BN: True
    DW_CONV_SKIP_RELU: True
    DET_HEAD_LAST_SCALE: 0.0
  RPN:
    ANCHOR_SIZES: (32, 64, 128, 256, 512)
    ANCHOR_STRIDE: (16, )
    BATCH_SIZE_PER_IMAGE: 256
    PRE_NMS_TOP_N_TRAIN: 6000
    PRE_NMS_TOP_N_TEST: 6000
    POST_NMS_TOP_N_TRAIN: 2000
    POST_NMS_TOP_N_TEST: 200
    RPN_HEAD: FBNet.rpn_head
  ROI_HEADS:
    BATCH_SIZE_PER_IMAGE: 256
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 6
    FEATURE_EXTRACTOR: FBNet.roi_head
    NUM_CLASSES: 81
  ROI_MASK_HEAD:
    POOLER_RESOLUTION: 6
    FEATURE_EXTRACTOR: FBNet.roi_head_mask
    PREDICTOR: "MaskRCNNConv1x1Predictor"
    RESOLUTION: 12
    SHARE_BOX_FEATURE_EXTRACTOR: False
  MASK_ON: True
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
SOLVER:
  BASE_LR: 0.06
  WARMUP_FACTOR: 0.1
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 128 # for 8GPUs
# TEST:
#   IMS_PER_BATCH: 8
INPUT:
  MIN_SIZE_TRAIN: (600, )
  MAX_SIZE_TRAIN: 1000
  MIN_SIZE_TEST: 600
  MAX_SIZE_TEST: 1000
  PIXEL_MEAN: [103.53, 116.28, 123.675]
  PIXEL_STD: [57.375, 57.12, 58.395]
......@@ -16,7 +16,7 @@ MODEL:
     PRE_NMS_TOP_N_TRAIN: 6000
     PRE_NMS_TOP_N_TEST: 6000
     POST_NMS_TOP_N_TRAIN: 2000
-    POST_NMS_TOP_N_TEST: 1000
+    POST_NMS_TOP_N_TEST: 100
     RPN_HEAD: FBNet.rpn_head
   ROI_HEADS:
     BATCH_SIZE_PER_IMAGE: 512
......
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    CONV_BODY: FBNet
  FBNET:
    ARCH: "xirb16d_dsmask"
    BN_TYPE: "bn"
    WIDTH_DIVISOR: 8
    DW_CONV_SKIP_BN: True
    DW_CONV_SKIP_RELU: True
    DET_HEAD_LAST_SCALE: 0.0
  RPN:
    ANCHOR_SIZES: (32, 64, 128, 256, 512)
    ANCHOR_STRIDE: (16, )
    BATCH_SIZE_PER_IMAGE: 256
    PRE_NMS_TOP_N_TRAIN: 6000
    PRE_NMS_TOP_N_TEST: 6000
    POST_NMS_TOP_N_TRAIN: 2000
    POST_NMS_TOP_N_TEST: 200
    RPN_HEAD: FBNet.rpn_head
  ROI_HEADS:
    BATCH_SIZE_PER_IMAGE: 256
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 6
    FEATURE_EXTRACTOR: FBNet.roi_head
    NUM_CLASSES: 81
  ROI_MASK_HEAD:
    POOLER_RESOLUTION: 6
    FEATURE_EXTRACTOR: FBNet.roi_head_mask
    PREDICTOR: "MaskRCNNConv1x1Predictor"
    RESOLUTION: 12
    SHARE_BOX_FEATURE_EXTRACTOR: False
  MASK_ON: True
DATASETS:
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
SOLVER:
  BASE_LR: 0.06
  WARMUP_FACTOR: 0.1
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 128 # for 8GPUs
# TEST:
#   IMS_PER_BATCH: 8
INPUT:
  MIN_SIZE_TRAIN: (600, )
  MAX_SIZE_TRAIN: 1000
  MIN_SIZE_TEST: 600
  MAX_SIZE_TEST: 1000
  PIXEL_MEAN: [103.53, 116.28, 123.675]
  PIXEL_STD: [57.375, 57.12, 58.395]
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import time
import os
......@@ -11,17 +10,23 @@ from maskrcnn_benchmark.data.datasets.evaluation import evaluate
 from ..utils.comm import is_main_process, get_world_size
 from ..utils.comm import all_gather
 from ..utils.comm import synchronize
+from ..utils.timer import Timer, get_time_str


-def compute_on_dataset(model, data_loader, device):
+def compute_on_dataset(model, data_loader, device, timer=None):
     model.eval()
     results_dict = {}
     cpu_device = torch.device("cpu")
-    for i, batch in enumerate(tqdm(data_loader)):
+    for _, batch in enumerate(tqdm(data_loader)):
         images, targets, image_ids = batch
         images = images.to(device)
         with torch.no_grad():
+            if timer:
+                timer.tic()
             output = model(images)
+            if timer:
+                torch.cuda.synchronize()
+                timer.toc()
             output = [o.to(cpu_device) for o in output]
         results_dict.update(
             {img_id: result for img_id, result in zip(image_ids, output)}
......@@ -68,17 +73,27 @@ def inference(
     logger = logging.getLogger("maskrcnn_benchmark.inference")
     dataset = data_loader.dataset
     logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)))
-    start_time = time.time()
-    predictions = compute_on_dataset(model, data_loader, device)
+    total_timer = Timer()
+    inference_timer = Timer()
+    total_timer.tic()
+    predictions = compute_on_dataset(model, data_loader, device, inference_timer)
     # wait for all processes to complete before measuring the time
     synchronize()
-    total_time = time.time() - start_time
-    total_time_str = str(datetime.timedelta(seconds=total_time))
+    total_time = total_timer.toc()
+    total_time_str = get_time_str(total_time)
     logger.info(
-        "Total inference time: {} ({} s / img per device, on {} devices)".format(
+        "Total run time: {} ({} s / img per device, on {} devices)".format(
             total_time_str, total_time * num_devices / len(dataset), num_devices
         )
     )
+    total_infer_time = get_time_str(inference_timer.total_time)
+    logger.info(
+        "Model inference time: {} ({} s / img per device, on {} devices)".format(
+            total_infer_time,
+            inference_timer.total_time * num_devices / len(dataset),
+            num_devices,
+        )
+    )

     predictions = _accumulate_predictions_from_multiple_gpus(predictions)
     if not is_main_process():
......
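Both log messages above report seconds per image per device by multiplying the elapsed time by the device count and dividing by the dataset size. A minimal sketch of that arithmetic follows, assuming the dataset is split evenly across devices; the function name and the input numbers are hypothetical and are not measured results.

```python
def seconds_per_image_per_device(total_seconds, num_devices, num_images):
    """Mirror the expression `time * num_devices / len(dataset)` from the log calls.

    Assumes an even split, so each device handles
    num_images / num_devices images during total_seconds of wall time.
    """
    images_per_device = num_images / num_devices
    return total_seconds / images_per_device


# Hypothetical numbers for illustration only (not measured results).
value = seconds_per_image_per_device(
    total_seconds=100.0, num_devices=8, num_images=5000
)
print(value)
```

The same formula is applied to both timers; only the measured interval differs (the whole evaluation loop versus the accumulated `model(images)` calls).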
......@@ -199,13 +199,6 @@ class FBNetROIHead(nn.Module):
             ("last", last)
         ]))

-        # output_blob = builder.add_final_pool(
-        #     # model, output_blob, kernel_size=cfg.FAST_RCNN.ROI_XFORM_RESOLUTION)
-        #     model,
-        #     output_blob,
-        #     kernel_size=int(cfg.FAST_RCNN.ROI_XFORM_RESOLUTION / stride_init),
-        # )
-
         self.out_channels = builder.last_depth

     def forward(self, x, proposals):
......
......@@ -771,6 +771,9 @@ class FBNetBuilder(object):
             last_channel = int(self.last_depth * (-channel_scale))
         last_channel = self._get_divisible_width(last_channel)

+        if last_channel == 0:
+            return nn.Sequential()
+
         dim_in = self.last_depth
         ret = ConvBNRelu(
             dim_in,
......
......@@ -47,7 +47,7 @@ MODEL_ARCH = {
                 [[4, 160, 1, 1], [6, 160, 3, 1], [3, 80, 1, -2]],
             ],
             # [c, channel_scale]
-            "last": [1280, 0.0],
+            "last": [0, 0.0],
             "backbone": [0, 1, 2, 3],
             "rpn": [5],
             "bbox": [4],
......@@ -91,7 +91,7 @@ MODEL_ARCH = {
                 [[6, 128, 3, 1]],
             ],
             # [c, channel_scale]
-            "last": [1280, 0.0],
+            "last": [0, 0.0],
             "backbone": [0, 1, 2, 3],
             "rpn": [6],
             "bbox": [4],
......@@ -127,9 +127,92 @@ MODEL_ARCH = {
                 [[6, 160, 3, 1], [6, 320, 1, 1]],
             ],
             # [c, channel_scale]
-            "last": [1280, 0.0],
+            "last": [0, 0.0],
             "backbone": [0, 1, 2, 3],
             "bbox": [4],
         },
     },
 }
+
+MODEL_ARCH_CHAM = {
+    "cham_v1a": {
+        "block_op_type": [
+            # stage 0
+            ["ir_k3"],
+            # stage 1
+            ["ir_k7"] * 2,
+            # stage 2
+            ["ir_k3"] * 5,
+            # stage 3
+            ["ir_k5"] * 7 + ["ir_k3"] * 5,
+            # stage 4, bbox head
+            ["ir_k3"] * 5,
+            # stage 5, rpn
+            ["ir_k3"] * 3,
+        ],
+        "block_cfg": {
+            "first": [32, 2],
+            "stages": [
+                # [t, c, n, s]
+                # stage 0
+                [[1, 24, 1, 1]],
+                # stage 1
+                [[4, 48, 2, 2]],
+                # stage 2
+                [[7, 64, 5, 2]],
+                # stage 3
+                [[12, 56, 7, 2], [8, 88, 5, 1]],
+                # stage 4, bbox head
+                [[7, 152, 4, 2], [10, 104, 1, 1]],
+                # stage 5, rpn head
+                [[8, 88, 3, 1]],
+            ],
+            # [c, channel_scale]
+            "last": [0, 0.0],
+            "backbone": [0, 1, 2, 3],
+            "rpn": [5],
+            "bbox": [4],
+        },
+    },
+    "cham_v2": {
+        "block_op_type": [
+            # stage 0
+            ["ir_k3"],
+            # stage 1
+            ["ir_k5"] * 4,
+            # stage 2
+            ["ir_k7"] * 6,
+            # stage 3
+            ["ir_k5"] * 3 + ["ir_k3"] * 6,
+            # stage 4, bbox head
+            ["ir_k3"] * 7,
+            # stage 5, rpn
+            ["ir_k3"] * 1,
+        ],
+        "block_cfg": {
+            "first": [32, 2],
+            "stages": [
+                # [t, c, n, s]
+                # stage 0
+                [[1, 24, 1, 1]],
+                # stage 1
+                [[8, 32, 4, 2]],
+                # stage 2
+                [[5, 48, 6, 2]],
+                # stage 3
+                [[9, 56, 3, 2], [6, 56, 6, 1]],
+                # stage 4, bbox head
+                [[2, 160, 6, 2], [6, 112, 1, 1]],
+                # stage 5, rpn head
+                [[6, 56, 1, 1]],
+            ],
+            # [c, channel_scale]
+            "last": [0, 0.0],
+            "backbone": [0, 1, 2, 3],
+            "rpn": [5],
+            "bbox": [4],
+        },
+    },
+}
+
+add_archs(MODEL_ARCH_CHAM)
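One way to sanity-check an architecture entry like the ones above is to compare, per stage, the number of op types with the number of blocks the stage config asks for. The sketch below assumes that `n` in `[t, c, n, s]` is the block repeat count (as in the MobileNetV2 convention) and that each block consumes one entry of `block_op_type`; the helper names are hypothetical and the data is copied from the `cham_v1a` entry.

```python
# Copied from the "cham_v1a" entry above.
block_op_type = [
    ["ir_k3"],
    ["ir_k7"] * 2,
    ["ir_k3"] * 5,
    ["ir_k5"] * 7 + ["ir_k3"] * 5,
    ["ir_k3"] * 5,
    ["ir_k3"] * 3,
]
stages = [
    [[1, 24, 1, 1]],
    [[4, 48, 2, 2]],
    [[7, 64, 5, 2]],
    [[12, 56, 7, 2], [8, 88, 5, 1]],
    [[7, 152, 4, 2], [10, 104, 1, 1]],
    [[8, 88, 3, 1]],
]


def blocks_per_stage(stage_cfgs):
    """Sum the repeat count n over all [t, c, n, s] entries of each stage."""
    return [sum(n for _t, _c, n, _s in stage) for stage in stage_cfgs]


def ops_per_stage(op_types):
    """Count the op-type entries listed for each stage."""
    return [len(ops) for ops in op_types]


expected = blocks_per_stage(stages)
actual = ops_per_stage(block_op_type)
print(expected)
print(actual)
assert expected == actual, "op types and block counts disagree"
```

Under that assumption the two lists agree for `cham_v1a`; the same check can be run on `cham_v2` by swapping in its values.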
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import time
import datetime


class Timer(object):
    def __init__(self):
        self.reset()

    @property
    def average_time(self):
        return self.total_time / self.calls if self.calls > 0 else 0.0

    def tic(self):
        # using time.time instead of time.clock because time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        self.add(time.time() - self.start_time)
        if average:
            return self.average_time
        else:
            return self.diff

    def add(self, time_diff):
        self.diff = time_diff
        self.total_time += self.diff
        self.calls += 1

    def reset(self):
        self.total_time = 0.0
        self.calls = 0
        self.start_time = 0.0
        self.diff = 0.0

    def avg_time_str(self):
        time_str = str(datetime.timedelta(seconds=self.average_time))
        return time_str


def get_time_str(time_diff):
    time_str = str(datetime.timedelta(seconds=time_diff))
    return time_str
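The following sketch illustrates how the timer accumulates intervals and what `toc()` returns. It re-declares a trimmed copy of the class so the snippet runs standalone; in the repository one would presumably import it instead (the module path `maskrcnn_benchmark.utils.timer` is inferred from the relative import `..utils.timer` and is an assumption here).

```python
import time


class Timer(object):
    """Trimmed copy of the Timer above, repeated so the example runs standalone."""

    def __init__(self):
        self.total_time = 0.0
        self.calls = 0
        self.start_time = 0.0
        self.diff = 0.0

    @property
    def average_time(self):
        return self.total_time / self.calls if self.calls > 0 else 0.0

    def tic(self):
        self.start_time = time.time()

    def toc(self, average=True):
        self.add(time.time() - self.start_time)
        return self.average_time if average else self.diff

    def add(self, time_diff):
        self.diff = time_diff
        self.total_time += self.diff
        self.calls += 1


timer = Timer()
# Feed fixed durations through add() so the numbers are deterministic.
timer.add(0.25)
timer.add(0.75)
total_after_two = timer.total_time
avg_after_two = timer.average_time
print(timer.calls, total_after_two, avg_after_two, timer.diff)

# tic()/toc() measure wall-clock time; toc() returns the running average
# unless average=False is passed, in which case it returns the last interval.
timer.tic()
last = timer.toc(average=False)
```

Because `toc()` defaults to `average=True`, a timer that has been started and stopped exactly once returns that single interval, which is how `total_timer.toc()` yields the total run time in the inference changes.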
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import glob
import os
import copy
import torch
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.structures.image_list import to_image_list
import utils


CONFIG_FILES = [
    # bbox
    "e2e_faster_rcnn_R_50_C4_1x.yaml",
    "e2e_faster_rcnn_R_50_FPN_1x.yaml",
    "e2e_faster_rcnn_fbnet.yaml",

    # mask
    "e2e_mask_rcnn_R_50_C4_1x.yaml",
    "e2e_mask_rcnn_R_50_FPN_1x.yaml",
    "e2e_mask_rcnn_fbnet.yaml",

    # keypoints
    # TODO: fail to run for random model due to empty head input
    # "e2e_keypoint_rcnn_R_50_FPN_1x.yaml",

    # gn
    "gn_baselines/e2e_faster_rcnn_R_50_FPN_1x_gn.yaml",
    # TODO: fail to run for random model due to empty head input
    # "gn_baselines/e2e_mask_rcnn_R_50_FPN_Xconv1fc_1x_gn.yaml",

    # retinanet
    "retinanet/retinanet_R-50-FPN_1x.yaml",

    # rpn only
    "rpn_R_50_C4_1x.yaml",
    "rpn_R_50_FPN_1x.yaml",
]

EXCLUDED_FOLDERS = [
    "caffe2",
    "quick_schedules",
    "pascal_voc",
    "cityscapes",
]


TEST_CUDA = torch.cuda.is_available()


def get_config_files(file_list, exclude_folders):
    cfg_root_path = utils.get_config_root_path()
    if file_list is not None:
        files = [os.path.join(cfg_root_path, x) for x in file_list]
    else:
        files = glob.glob(
            os.path.join(cfg_root_path, "./**/*.yaml"), recursive=True)

    def _contains(path, exclude_dirs):
        return any(x in path for x in exclude_dirs)

    if exclude_folders is not None:
        files = [x for x in files if not _contains(x, exclude_folders)]

    return files


def create_model(cfg, device):
    cfg = copy.deepcopy(cfg)
    cfg.freeze()
    model = build_detection_model(cfg)
    model = model.to(device)
    return model


def create_random_input(cfg, device):
    ret = []
    for x in cfg.INPUT.MIN_SIZE_TRAIN:
        ret.append(torch.rand(3, x, int(x * 1.2)))
    ret = to_image_list(ret, cfg.DATALOADER.SIZE_DIVISIBILITY)
    ret = ret.to(device)
    return ret


def _test_build_detectors(self, device):
    ''' Make sure models build '''

    cfg_files = get_config_files(None, EXCLUDED_FOLDERS)
    self.assertGreater(len(cfg_files), 0)

    for cfg_file in cfg_files:
        with self.subTest(cfg_file=cfg_file):
            print('Testing {}...'.format(cfg_file))
            cfg = utils.load_config_from_file(cfg_file)
            create_model(cfg, device)


def _test_run_selected_detectors(self, cfg_files, device):
    ''' Make sure models build and run '''
    self.assertGreater(len(cfg_files), 0)

    for cfg_file in cfg_files:
        with self.subTest(cfg_file=cfg_file):
            print('Testing {}...'.format(cfg_file))
            cfg = utils.load_config_from_file(cfg_file)
            cfg.MODEL.RPN.POST_NMS_TOP_N_TEST = 10
            cfg.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 10
            model = create_model(cfg, device)
            inputs = create_random_input(cfg, device)
            model.eval()
            output = model(inputs)
            self.assertEqual(len(output), len(inputs.image_sizes))


class TestDetectors(unittest.TestCase):
    def test_build_detectors(self):
        ''' Make sure models build '''
        _test_build_detectors(self, "cpu")

    @unittest.skipIf(not TEST_CUDA, "no CUDA detected")
    def test_build_detectors_cuda(self):
        ''' Make sure models build on gpu'''
        _test_build_detectors(self, "cuda")

    def test_run_selected_detectors(self):
        ''' Make sure models build and run '''
        # run on selected models
        cfg_files = get_config_files(CONFIG_FILES, None)
        # cfg_files = get_config_files(None, EXCLUDED_FOLDERS)
        _test_run_selected_detectors(self, cfg_files, "cpu")

    @unittest.skipIf(not TEST_CUDA, "no CUDA detected")
    def test_run_selected_detectors_cuda(self):
        ''' Make sure models build and run on cuda '''
        # run on selected models
        cfg_files = get_config_files(CONFIG_FILES, None)
        # cfg_files = get_config_files(None, EXCLUDED_FOLDERS)
        _test_run_selected_detectors(self, cfg_files, "cuda")


if __name__ == "__main__":
    unittest.main()
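The exclusion logic in `get_config_files` is a plain substring test on the full path, which is worth keeping in mind when naming config files. The standalone sketch below isolates that filter; `filter_configs` is a hypothetical helper and the paths are made up for illustration.

```python
def _contains(path, exclude_dirs):
    # Same substring test as in get_config_files above.
    return any(x in path for x in exclude_dirs)


def filter_configs(files, exclude_folders):
    """Drop every path that contains one of the excluded names."""
    if exclude_folders is None:
        return list(files)
    return [x for x in files if not _contains(x, exclude_folders)]


# Hypothetical paths for illustration only.
paths = [
    "configs/e2e_faster_rcnn_R_50_C4_1x.yaml",
    "configs/caffe2/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
    "configs/quick_schedules/e2e_mask_rcnn_R_50_FPN_quick.yaml",
]
kept = filter_configs(paths, ["caffe2", "quick_schedules"])
print(kept)
```

Since the match is on substrings rather than path components, a file whose name merely contains an excluded word would also be skipped.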