Commit 08fcf12f authored by Simon Layton, committed by Francisco Massa

Initial mixed-precision training (#196)

* Initial multi-precision training

Adds fp16 support via apex.amp
Also switches communication to apex.DistributedDataParallel

* Add Apex install to dockerfile

* Fixes from @fmassa review

Added support to tools/test_net.py
SOLVER.MIXED_PRECISION -> DTYPE \in {float32, float16}
A missing apex.amp installation now raises ImportError

* Remove extraneous apex DDP import

* Move to new amp API
parent bf043792
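Before the file-by-file diff, a condensed sketch of the amp recipe these hunks add (the builder, optimizer and data-loader helpers are the repository's existing ones; loop details are trimmed for illustration and are not verbatim from the commit):

# Condensed mixed-precision training step; O1 = mixed precision, O0 = pure fp32.
import torch
from apex import amp
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.solver import make_optimizer

device = torch.device(cfg.MODEL.DEVICE)
model = build_detection_model(cfg).to(device)
optimizer = make_optimizer(cfg, model)
opt_level = "O1" if cfg.DTYPE == "float16" else "O0"
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

for images, targets, _ in make_data_loader(cfg, is_train=True):
    loss_dict = model(images.to(device), [t.to(device) for t in targets])
    losses = sum(loss for loss in loss_dict.values())
    optimizer.zero_grad()
    with amp.scale_loss(losses, optimizer) as scaled_losses:  # no-op at O0
        scaled_losses.backward()
    optimizer.step()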
......@@ -38,6 +38,12 @@ git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
python setup.py build_ext install
# install apex
cd $INSTALL_DIR
git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cuda_ext --cpp_ext
# install PyTorch Detection
cd $INSTALL_DIR
git clone https://github.com/facebookresearch/maskrcnn-benchmark.git
......
......@@ -46,6 +46,11 @@ RUN git clone https://github.com/cocodataset/cocoapi.git \
&& cd cocoapi/PythonAPI \
&& python setup.py build_ext install
# install apex
RUN git clone https://github.com/NVIDIA/apex.git \
&& cd apex \
&& python setup.py install --cuda_ext --cpp_ext
# install PyTorch Detection
ARG FORCE_CUDA="1"
ENV FORCE_CUDA=${FORCE_CUDA}
......
......@@ -431,3 +431,13 @@ _C.TEST.DETECTIONS_PER_IMG = 100
_C.OUTPUT_DIR = "."
_C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py")
# ---------------------------------------------------------------------------- #
# Precision options
# ---------------------------------------------------------------------------- #
# Precision of input, allowable: (float32, float16)
_C.DTYPE = "float32"
# Enable verbosity in apex.amp
_C.AMP_VERBOSE = False
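As a usage note, the two new options follow the normal yacs override path; a minimal sketch (the config file name is illustrative):

# Flip the new precision knobs via yacs; train_net/test_net pass trailing CLI
# options through cfg.merge_from_list in exactly the same way.
from maskrcnn_benchmark.config import cfg

cfg.merge_from_file("configs/e2e_mask_rcnn_R_50_FPN_1x.yaml")  # illustrative config file
cfg.merge_from_list(["DTYPE", "float16", "AMP_VERBOSE", True])
assert cfg.DTYPE in ("float32", "float16")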
......@@ -9,6 +9,7 @@ import torch.distributed as dist
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from apex import amp
def reduce_loss_dict(loss_dict):
"""
......@@ -73,7 +74,10 @@ def do_train(
  meters.update(loss=losses_reduced, **loss_dict_reduced)
  optimizer.zero_grad()
- losses.backward()
+ # Note: If mixed precision is not used, this ends up doing nothing
+ # Otherwise apply loss scaling for mixed-precision recipe
+ with amp.scale_loss(losses, optimizer) as scaled_losses:
+     scaled_losses.backward()
  optimizer.step()
batch_time = time.time() - end
......
......@@ -17,6 +17,13 @@ class FrozenBatchNorm2d(nn.Module):
self.register_buffer("running_var", torch.ones(n))
def forward(self, x):
# Cast all fixed parameters to half() if necessary
if x.dtype == torch.float16:
self.weight = self.weight.half()
self.bias = self.bias.half()
self.running_mean = self.running_mean.half()
self.running_var = self.running_var.half()
scale = self.weight * self.running_var.rsqrt()
bias = self.bias - self.running_mean * scale
scale = scale.reshape(1, -1, 1, 1)
......
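To see the effect of the cast above, a small sanity check (a sketch assuming a CUDA device, since fp16 math is generally unavailable on CPU):

# Half-precision input: buffers are cast to half inside forward, output stays fp16.
import torch
from maskrcnn_benchmark.layers import FrozenBatchNorm2d

bn = FrozenBatchNorm2d(16).cuda()
x = torch.randn(2, 16, 8, 8, device="cuda", dtype=torch.float16)
assert bn(x).dtype == torch.float16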
......@@ -2,6 +2,10 @@
  # from ._utils import _C
  from maskrcnn_benchmark import _C
- nms = _C.nms
+ from apex import amp
+ # Only valid with fp32 inputs - give AMP the hint
+ nms = amp.float_function(_C.nms)
  # nms.__doc__ = """
  # This function performs Non-maximum suppresion"""
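For reference, a minimal call of the wrapped op (box values are made up); the float_function hint tells amp to run it in fp32, so half inputs are cast before the CUDA kernel:

# nms(boxes, scores, threshold) returns indices of the boxes that survive suppression.
import torch
from maskrcnn_benchmark.layers import nms

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]], device="cuda")
scores = torch.tensor([0.9, 0.8], device="cuda")
keep = nms(boxes, scores, 0.5)  # the second box overlaps the first heavily and is dropped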
......@@ -7,6 +7,7 @@ from torch.nn.modules.utils import _pair
from maskrcnn_benchmark import _C
from apex import amp
class _ROIAlign(Function):
@staticmethod
......@@ -46,7 +47,6 @@ class _ROIAlign(Function):
roi_align = _ROIAlign.apply
class ROIAlign(nn.Module):
def __init__(self, output_size, spatial_scale, sampling_ratio):
super(ROIAlign, self).__init__()
......@@ -54,6 +54,7 @@ class ROIAlign(nn.Module):
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio
@amp.float_function
def forward(self, input, rois):
return roi_align(
input, rois, self.output_size, self.spatial_scale, self.sampling_ratio
......
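Likewise, an illustrative call of the patched layer (feature-map and roi values are made up); rois are (batch_index, x1, y1, x2, y2) in input-image coordinates, and the decorator gives amp the same fp32 hint as for nms:

# ROIAlign pools a fixed 7x7 window from the feature map for each roi.
import torch
from maskrcnn_benchmark.layers import ROIAlign

pool = ROIAlign(output_size=(7, 7), spatial_scale=0.25, sampling_ratio=2)
features = torch.randn(1, 256, 50, 50, device="cuda")
rois = torch.tensor([[0.0, 16.0, 16.0, 120.0, 160.0]], device="cuda")
out = pool(features, rois)  # shape: (1, 256, 7, 7)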
......@@ -7,6 +7,7 @@ from torch.nn.modules.utils import _pair
from maskrcnn_benchmark import _C
from apex import amp
class _ROIPool(Function):
@staticmethod
......@@ -52,6 +53,7 @@ class ROIPool(nn.Module):
self.output_size = output_size
self.spatial_scale = spatial_scale
@amp.float_function
def forward(self, input, rois):
return roi_pool(input, rois, self.output_size, self.spatial_scale)
......
......@@ -116,7 +116,7 @@ class Pooler(nn.Module):
  for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
      idx_in_level = torch.nonzero(levels == level).squeeze(1)
      rois_per_level = rois[idx_in_level]
-     result[idx_in_level] = pooler(per_level_feature, rois_per_level)
+     result[idx_in_level] = pooler(per_level_feature, rois_per_level).to(dtype)
  return result
......
......@@ -111,11 +111,16 @@ def expand_masks(mask, padding):
pad2 = 2 * padding
scale = float(M + pad2) / M
padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
padded_mask[:, :, padding:-padding, padding:-padding] = mask
return padded_mask, scale
def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
# Need to work on the CPU, where fp16 isn't supported - cast to float to avoid this
mask = mask.float()
box = box.float()
padded_mask, scale = expand_masks(mask[None], padding=padding)
mask = padded_mask[0, 0]
box = expand_boxes(box[None], scale)[0]
......
......@@ -17,6 +17,12 @@ from maskrcnn_benchmark.utils.comm import synchronize, get_rank
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
# Check if we can enable mixed-precision via apex.amp
try:
from apex import amp
except ImportError:
raise ImportError('Use APEX for mixed precision via apex.amp')
def main():
parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference")
......@@ -61,6 +67,10 @@ def main():
model = build_detection_model(cfg)
model.to(cfg.MODEL.DEVICE)
# Initialize mixed-precision if necessary
use_mixed_precision = cfg.DTYPE == 'float16'
amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)
output_dir = cfg.OUTPUT_DIR
checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
_ = checkpointer.load(cfg.MODEL.WEIGHT)
......
......@@ -25,6 +25,13 @@ from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
# See if we can use apex.DistributedDataParallel instead of the torch default,
# and enable mixed-precision via apex.amp
try:
from apex import amp
except ImportError:
raise ImportError('Use APEX for multi-precision via apex.amp')
def train(cfg, local_rank, distributed):
model = build_detection_model(cfg)
......@@ -34,6 +41,11 @@ def train(cfg, local_rank, distributed):
optimizer = make_optimizer(cfg, model)
scheduler = make_lr_scheduler(cfg, optimizer)
# Initialize mixed-precision training
use_mixed_precision = cfg.DTYPE == "float16"
amp_opt_level = 'O1' if use_mixed_precision else 'O0'
model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)
if distributed:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[local_rank], output_device=local_rank,
......