Unverified commit ca9531b9, authored by Francisco Massa, committed by GitHub

Add an option to postprocess masks during inference (#180)

* Add an option to postprocess masks during inference

* Fix COCO evaluation to resize masks only if needed.

* Fix casting

* Fix minor issues in paste_mask_in_image

* Cast mask to uint8

* Make Masker batch compatible

* Remove warnings and stylistic changes
parent 13555fc3
@@ -194,6 +194,9 @@ _C.MODEL.ROI_MASK_HEAD.MLP_HEAD_DIM = 1024
 _C.MODEL.ROI_MASK_HEAD.CONV_LAYERS = (256, 256, 256, 256)
 _C.MODEL.ROI_MASK_HEAD.RESOLUTION = 14
 _C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True
+# Whether or not to resize and translate masks to the input image.
+_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS = False
+_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD = 0.5
 # ---------------------------------------------------------------------------- #
 # ResNe[X]t options (ResNets = {ResNet, ResNeXt}
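For context, a minimal sketch of how the two new flags could be enabled through the yacs config object that maskrcnn_benchmark exposes; merge_from_list is standard yacs usage, and the values shown are only illustrative, not part of this diff.

from maskrcnn_benchmark.config import cfg

# Equivalent YAML override, if merged via cfg.merge_from_file(...):
#   MODEL:
#     ROI_MASK_HEAD:
#       POSTPROCESS_MASKS: True
#       POSTPROCESS_MASKS_THRESHOLD: 0.5
cfg.merge_from_list([
    "MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS", True,
    "MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD", 0.5,
])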
@@ -87,8 +87,12 @@ def prepare_for_coco_segmentation(predictions, dataset):
         image_height = dataset.coco.imgs[original_id]["height"]
         prediction = prediction.resize((image_width, image_height))
         masks = prediction.get_field("mask")
         # t = time.time()
-        masks = masker(masks, prediction)
+        # Masker is necessary only if masks haven't already been resized.
+        if list(masks.shape[-2:]) != [image_height, image_width]:
+            masks = masker(masks.expand(1, -1, -1, -1, -1), prediction)
+            masks = masks[0]
         # logger.info('Time mask: {}'.format(time.time() - t))
         # prediction = prediction.convert('xywh')
@@ -426,6 +430,6 @@ def inference(
     check_expected_results(results, expected_results, expected_results_sigma_tol)
     if output_folder:
         torch.save(results, os.path.join(output_folder, "coco_results.pth"))
     return results, coco_results, predictions
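A small sketch of the shape handling above, with made-up sizes: the Masker now expects one entry per image, so a single image's [num_boxes, 1, M, M] mask tensor is wrapped in a leading batch dimension before the call and unwrapped afterwards.

import torch

num_boxes, mask_res = 3, 28
masks = torch.rand(num_boxes, 1, mask_res, mask_res)   # per-detection mask probabilities
image_height, image_width = 480, 640                   # illustrative input-image size

if list(masks.shape[-2:]) != [image_height, image_width]:
    batched = masks.expand(1, -1, -1, -1, -1)          # view as [1, num_boxes, 1, M, M]
    # masker(batched, prediction) would return a list with one tensor per image;
    # taking element 0 recovers the pasted masks for this single image:
    # masks = masker(batched, prediction)[0]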
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import numpy as np
 import torch
-from PIL import Image
 from torch import nn
+import torch.nn.functional as F
 from maskrcnn_benchmark.structures.bounding_box import BoxList
@@ -44,12 +44,12 @@ class MaskPostProcessor(nn.Module):
         index = torch.arange(num_masks, device=labels.device)
         mask_prob = mask_prob[index, labels][:, None]
-        if self.masker:
-            mask_prob = self.masker(mask_prob, boxes)
         boxes_per_image = [len(box) for box in boxes]
         mask_prob = mask_prob.split(boxes_per_image, dim=0)
+        if self.masker:
+            mask_prob = self.masker(mask_prob, boxes)
         results = []
         for prob, box in zip(mask_prob, boxes):
             bbox = BoxList(box.bbox, box.size, mode="xyxy")
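To illustrate why the masker call moves below the split: mask_prob now reaches the masker already chunked per image, matching the list of BoxLists. A toy sketch with made-up shapes follows; the masker call itself is left commented since it needs a constructed Masker and real boxes.

import torch

mask_prob = torch.rand(5, 1, 28, 28)   # all detections of the batch, concatenated
boxes_per_image = [2, 3]               # e.g. 2 boxes in image 0, 3 boxes in image 1

# Split into per-image chunks first...
mask_prob = mask_prob.split(boxes_per_image, dim=0)
# ...so the batch-compatible Masker receives a list of per-image tensors plus
# the matching list of BoxLists:
# mask_prob = masker(mask_prob, boxes)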
@@ -119,7 +119,7 @@ def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
     padded_mask, scale = expand_masks(mask[None], padding=padding)
     mask = padded_mask[0, 0]
     box = expand_boxes(box[None], scale)[0]
-    box = box.numpy().astype(np.int32)
+    box = box.to(dtype=torch.int32)
     TO_REMOVE = 1
     w = box[2] - box[0] + TO_REMOVE
@@ -127,17 +127,20 @@ def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
     w = max(w, 1)
     h = max(h, 1)
-    mask = Image.fromarray(mask.cpu().numpy())
-    mask = mask.resize((w, h), resample=Image.BILINEAR)
-    mask = np.array(mask, copy=False)
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = mask.to(torch.float32)
+    mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
+    mask = mask[0][0]
     if thresh >= 0:
-        mask = np.array(mask > thresh, dtype=np.uint8)
-        mask = torch.from_numpy(mask)
+        mask = mask > thresh
     else:
         # for visualization and debugging, we also
         # allow it to return an unmodified mask
-        mask = torch.from_numpy(mask * 255).to(torch.uint8)
+        mask = (mask * 255).to(torch.uint8)
     im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
     x_0 = max(box[0], 0)
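A standalone sketch of the pure-PyTorch resize that replaces the PIL round trip, using toy sizes; F.interpolate operates on [N, C, H, W] tensors, hence the expand to four dimensions and the unwrap afterwards.

import torch
import torch.nn.functional as F

mask = torch.rand(28, 28)       # a single padded mask, [H, W]
w, h = 17, 23                   # target box width/height in image pixels
thresh = 0.5

mask = mask.expand(1, 1, -1, -1).to(torch.float32)        # [1, 1, H, W]
mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
mask = mask[0, 0]                                          # back to [h, w]
binary_mask = (mask > thresh).to(torch.uint8)              # same thresholding as above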
@@ -175,15 +178,27 @@ class Masker(object):
         return res

     def __call__(self, masks, boxes):
-        # TODO do this properly
         if isinstance(boxes, BoxList):
             boxes = [boxes]
-        assert len(boxes) == 1, "Only single image batch supported"
-        result = self.forward_single_image(masks, boxes[0])
-        return result
+
+        # Make some sanity checks
+        assert len(boxes) == len(masks), "Masks and boxes should have the same length."
+        # TODO: Is this JIT compatible?
+        # If not, we should make it compatible.
+        results = []
+        for mask, box in zip(masks, boxes):
+            assert mask.shape[0] == len(box), "Number of objects should be the same."
+            result = self.forward_single_image(mask, box)
+            results.append(result)
+        return results
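A hedged usage sketch of the now batch-compatible Masker: one mask tensor and one BoxList per image. The import path below is an assumption (the mask head's inference module in maskrcnn_benchmark); adjust it if the layout differs, and the shapes and boxes are made up.

import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList
# Assumed location of Masker; adjust if the module layout differs.
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker

masker = Masker(threshold=0.5, padding=1)

# One [num_boxes, 1, M, M] mask tensor and one BoxList per image in the batch.
masks = [torch.rand(2, 1, 28, 28), torch.rand(1, 1, 28, 28)]
boxes = [
    BoxList(torch.tensor([[0.0, 0.0, 50.0, 60.0],
                          [10.0, 10.0, 40.0, 80.0]]), (100, 120), mode="xyxy"),
    BoxList(torch.tensor([[5.0, 5.0, 30.0, 30.0]]), (100, 120), mode="xyxy"),
]

results = masker(masks, boxes)   # list with one pasted-mask tensor per image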
 def make_roi_mask_post_processor(cfg):
-    masker = None
+    if cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS:
+        mask_threshold = cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD
+        masker = Masker(threshold=mask_threshold, padding=1)
+    else:
+        masker = None
     mask_post_processor = MaskPostProcessor(masker)
     return mask_post_processor
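Putting the pieces together, a minimal sketch of how the new flag selects between the two code paths; the factory's import path is assumed to match the same mask head inference module, and the config values are illustrative.

from maskrcnn_benchmark.config import cfg
# Assumed location of the factory; adjust if the module layout differs.
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import (
    make_roi_mask_post_processor,
)

cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS = True

# With POSTPROCESS_MASKS enabled the post-processor owns a Masker, so the masks
# it attaches to each BoxList are already pasted at input-image resolution.
post_processor = make_roi_mask_post_processor(cfg)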