Unverified Commit ca9531b9 authored by Francisco Massa, committed by GitHub

Add an option to postprocess masks during inference (#180)

* Add an option to postprocess masks during inference

* Fix COCO evaluation to resize masks only if needed.

* Fix casting

* Fix minor issues in paste_mask_in_image

* Cast mask to uint8

* Make Masker batch compatible

* Remove warnings and stylistic changes
parent 13555fc3
@@ -194,6 +194,9 @@ _C.MODEL.ROI_MASK_HEAD.MLP_HEAD_DIM = 1024
_C.MODEL.ROI_MASK_HEAD.CONV_LAYERS = (256, 256, 256, 256)
_C.MODEL.ROI_MASK_HEAD.RESOLUTION = 14
_C.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR = True
# Whether or not resize and translate masks to the input image.
_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS = False
_C.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD = 0.5
# ---------------------------------------------------------------------------- #
# ResNe[X]t options (ResNets = {ResNet, ResNeXt}
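As a usage note, here is a minimal sketch (not part of the commit) of how these two new options could be switched on before building the model for an evaluation run; it assumes the yacs-style `cfg` exposed by `maskrcnn_benchmark.config`, and `merge_from_list` is plain yacs API:

    # Sketch only: enable full-image mask postprocessing via the config.
    from maskrcnn_benchmark.config import cfg

    cfg.merge_from_list([
        "MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS", True,
        "MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD", 0.5,
    ])
    # With POSTPROCESS_MASKS enabled, the mask head's post-processor is built
    # with a Masker, so predictions carry masks in full image coordinates
    # instead of fixed-resolution crops.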
......
@@ -87,8 +87,12 @@ def prepare_for_coco_segmentation(predictions, dataset):
image_height = dataset.coco.imgs[original_id]["height"]
prediction = prediction.resize((image_width, image_height))
masks = prediction.get_field("mask")
# t = time.time()
masks = masker(masks, prediction)
# Masker is necessary only if masks haven't been already resized.
if list(masks.shape[-2:]) != [image_height, image_width]:
masks = masker(masks.expand(1, -1, -1, -1, -1), prediction)
masks = masks[0]
# logger.info('Time mask: {}'.format(time.time() - t))
# prediction = prediction.convert('xywh')
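To make the shape handling in this hunk concrete, here is a small standalone illustration (not code from the commit; the 28x28 crop size is just an example) of why the `expand(1, -1, -1, -1, -1)` / `masks[0]` pair appears: the batch-compatible Masker introduced further down expects a leading per-image dimension, while `prepare_for_coco_segmentation` handles one image at a time:

    import torch

    masks = torch.rand(3, 1, 28, 28)           # 3 predicted mask crops for one image
    image_height, image_width = 480, 640
    # Masks are still at crop resolution, so the Masker is needed.
    print(list(masks.shape[-2:]) != [image_height, image_width])   # True
    batched = masks.expand(1, -1, -1, -1, -1)  # add the per-image batch dimension
    print(batched.shape)                       # torch.Size([1, 3, 1, 28, 28])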
@@ -426,6 +430,6 @@ def inference(
check_expected_results(results, expected_results, expected_results_sigma_tol)
if output_folder:
torch.save(results, os.path.join(output_folder, "coco_results.pth"))
return results, coco_results, predictions
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import numpy as np
import torch
from PIL import Image
from torch import nn
import torch.nn.functional as F
from maskrcnn_benchmark.structures.bounding_box import BoxList
@@ -44,12 +44,12 @@ class MaskPostProcessor(nn.Module):
index = torch.arange(num_masks, device=labels.device)
mask_prob = mask_prob[index, labels][:, None]
if self.masker:
mask_prob = self.masker(mask_prob, boxes)
boxes_per_image = [len(box) for box in boxes]
mask_prob = mask_prob.split(boxes_per_image, dim=0)
if self.masker:
mask_prob = self.masker(mask_prob, boxes)
results = []
for prob, box in zip(mask_prob, boxes):
bbox = BoxList(box.bbox, box.size, mode="xyxy")
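This hunk moves the masker call to after `mask_prob` has been split per image, which is what lets the batch-compatible Masker receive one entry per image. A toy illustration of that split (the numbers are arbitrary, not from the commit):

    import torch

    mask_prob = torch.rand(5, 1, 28, 28)   # 5 detections across a 2-image batch
    boxes_per_image = [2, 3]               # detections per image
    per_image = mask_prob.split(boxes_per_image, dim=0)
    print([tuple(p.shape) for p in per_image])   # [(2, 1, 28, 28), (3, 1, 28, 28)]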
@@ -119,7 +119,7 @@ def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
padded_mask, scale = expand_masks(mask[None], padding=padding)
mask = padded_mask[0, 0]
box = expand_boxes(box[None], scale)[0]
box = box.numpy().astype(np.int32)
box = box.to(dtype=torch.int32)
TO_REMOVE = 1
w = box[2] - box[0] + TO_REMOVE
@@ -127,17 +127,20 @@
w = max(w, 1)
h = max(h, 1)
mask = Image.fromarray(mask.cpu().numpy())
mask = mask.resize((w, h), resample=Image.BILINEAR)
mask = np.array(mask, copy=False)
# Set shape to [batchxCxHxW]
mask = mask.expand((1, 1, -1, -1))
# Resize mask
mask = mask.to(torch.float32)
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = mask[0][0]
if thresh >= 0:
mask = np.array(mask > thresh, dtype=np.uint8)
mask = torch.from_numpy(mask)
mask = mask > thresh
else:
# for visualization and debugging, we also
# allow it to return an unmodified mask
mask = torch.from_numpy(mask * 255).to(torch.uint8)
mask = (mask * 255).to(torch.uint8)
im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
x_0 = max(box[0], 0)
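A small self-contained check (an illustration under assumed sizes, not code from the commit) of what the new `F.interpolate` path does in place of the old PIL round-trip: expand the mask crop to a 4D `[N, C, H, W]` tensor, resize it to the box size, then binarize:

    import torch
    import torch.nn.functional as F

    mask = torch.rand(28, 28)              # soft mask for one box (size is an example)
    h, w = 50, 40                          # box height and width in image pixels
    mask = mask.expand((1, 1, -1, -1)).to(torch.float32)
    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
    mask = mask[0][0]
    mask = (mask > 0.5).to(torch.uint8)    # same role as the `thresh` branch above
    print(mask.shape)                      # torch.Size([50, 40])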
@@ -175,15 +178,27 @@ class Masker(object):
return res
def __call__(self, masks, boxes):
# TODO do this properly
if isinstance(boxes, BoxList):
boxes = [boxes]
assert len(boxes) == 1, "Only single image batch supported"
result = self.forward_single_image(masks, boxes[0])
return result
# Make some sanity check
assert len(boxes) == len(masks), "Masks and boxes should have the same length."
# TODO: Is this JIT compatible?
# If not we should make it compatible.
results = []
for mask, box in zip(masks, boxes):
assert mask.shape[0] == len(box), "Number of objects should be the same."
result = self.forward_single_image(mask, box)
results.append(result)
return results
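As a rough usage sketch of the now batch-compatible `Masker` (the module path and sizes below are assumptions, not part of the diff): masks come in as one `[num_objs, 1, M, M]` tensor per image, paired with a list of `BoxList`s, and come back pasted at full image resolution:

    import torch
    from maskrcnn_benchmark.structures.bounding_box import BoxList
    from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker

    boxes = BoxList(torch.tensor([[10.0, 10.0, 60.0, 80.0]]), (640, 480), mode="xyxy")
    masks = torch.rand(1, 1, 1, 28, 28)    # [num_images, num_objs, 1, M, M]
    masker = Masker(threshold=0.5, padding=1)
    results = masker(masks, [boxes])       # one pasted-mask tensor per image
    print(results[0].shape)                # torch.Size([1, 1, 480, 640])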
def make_roi_mask_post_processor(cfg):
masker = None
if cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS:
mask_threshold = cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD
masker = Masker(threshold=mask_threshold, padding=1)
else:
masker = None
mask_post_processor = MaskPostProcessor(masker)
return mask_post_processor
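Finally, a hedged sketch of the factory's behaviour (the attribute access on `cfg` and the import path are assumptions): with the flag left at its default of `False` the post-processor keeps returning fixed-resolution mask probabilities, and with it enabled a `Masker` is attached:

    from maskrcnn_benchmark.config import cfg
    from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import (
        make_roi_mask_post_processor,
    )

    post_processor = make_roi_mask_post_processor(cfg)
    print(post_processor.masker is None)   # True with the default config

    cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS = True
    post_processor = make_roi_mask_post_processor(cfg)
    print(post_processor.masker is None)   # False: masks are pasted at inference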