Commit b4d54657 authored by Csaba Botos's avatar Csaba Botos Committed by Francisco Massa

Support Binary Mask with transparent SementationMask interface (#473)

* support RLE and binary mask

* do not convert to numpy

* be consistent with Detectron

* delete wrong comment

* [WIP] add tests for segmentation_mask

* update tests

* minor change

* Refactored segmentation_mask.py

* Add unit test for segmentation_mask.py

* Add RLE support for BinaryMaskList

* PEP8 black formatting

* Minor patch

* Use internal  that handles 0 channels

* Fix polygon slicing
parent f917a555
......@@ -80,7 +80,7 @@ class COCODataset(torchvision.datasets.coco.CocoDetection):
target.add_field("labels", classes)
masks = [obj["segmentation"] for obj in anno]
masks = SegmentationMask(masks, img.size)
masks = SegmentationMask(masks, img.size, mode='poly')
target.add_field("masks", masks)
if anno and "keypoints" in anno[0]:
......
......@@ -27,17 +27,15 @@ def project_masks_on_boxes(segmentation_masks, proposals, discretization_size):
assert segmentation_masks.size == proposals.size, "{}, {}".format(
segmentation_masks, proposals
)
# TODO put the proposals on the CPU, as the representation for the
# masks is not efficient GPU-wise (possibly several small tensors for
# representing a single instance mask)
# FIXME: CPU computation bottleneck, this should be parallelized
proposals = proposals.bbox.to(torch.device("cpu"))
for segmentation_mask, proposal in zip(segmentation_masks, proposals):
# crop the masks, resize them to the desired resolution and
# then convert them to the tensor representation,
# instead of the list representation that was used
# then convert them to the tensor representation.
cropped_mask = segmentation_mask.crop(proposal)
scaled_mask = cropped_mask.resize((M, M))
mask = scaled_mask.convert(mode="mask")
mask = scaled_mask.get_mask_tensor()
masks.append(mask)
if len(masks) == 0:
return torch.empty(0, dtype=torch.float32, device=device)
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import cv2
import torch
import numpy as np
from maskrcnn_benchmark.layers.misc import interpolate
import pycocotools.mask as mask_utils
......@@ -8,63 +11,207 @@ FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1
class Mask(object):
""" ABSTRACT
Segmentations come in either:
1) Binary masks
2) Polygons
Binary masks can be represented in a contiguous array
and operations can be carried out more efficiently,
therefore BinaryMaskList handles them together.
Polygons are handled separately for each instance,
by PolygonInstance and instances are handled by
PolygonList.
SegmentationList is supposed to represent both,
therefore it wraps the functions of BinaryMaskList
and PolygonList to make it transparent.
"""
class BinaryMaskList(object):
"""
This class handles binary masks for all objects in the image
"""
def __init__(self, masks, size):
"""
This class is unfinished and not meant for use yet
It is supposed to contain the mask for an object as
a 2d tensor
Arguments:
masks: Either torch.tensor of [num_instances, H, W]
or list of torch.tensors of [H, W] with num_instances elems,
or RLE (Run Length Encoding) - interpreted as list of dicts,
or BinaryMaskList.
size: absolute image size, width first
After initialization, a hard copy will be made, to leave the
initializing source data intact.
"""
def __init__(self, masks, size, mode):
if isinstance(masks, torch.Tensor):
# The raw data representation is passed as argument
masks = masks.clone()
elif isinstance(masks, (list, tuple)):
if isinstance(masks[0], torch.Tensor):
masks = torch.stack(masks, dim=2).clone()
elif isinstance(masks[0], dict) and "count" in masks[0]:
# RLE interpretation
masks = mask_utils
else:
RuntimeError(
"Type of `masks[0]` could not be interpreted: %s" % type(masks)
)
elif isinstance(masks, BinaryMaskList):
# just hard copy the BinaryMaskList instance's underlying data
masks = masks.masks.clone()
else:
RuntimeError(
"Type of `masks` argument could not be interpreted:%s" % tpye(masks)
)
if len(masks.shape) == 2:
# if only a single instance mask is passed
masks = masks[None]
assert len(masks.shape) == 3
assert masks.shape[1] == size[1], "%s != %s" % (masks.shape[1], size[1])
assert masks.shape[2] == size[0], "%s != %s" % (masks.shape[2], size[0])
self.masks = masks
self.size = size
self.mode = mode
self.size = tuple(size)
def transpose(self, method):
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
raise NotImplementedError(
"Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
dim = 1 if method == FLIP_TOP_BOTTOM else 2
flipped_masks = self.masks.flip(dim)
return BinaryMaskList(flipped_masks, self.size)
def crop(self, box):
assert isinstance(box, (list, tuple, torch.Tensor)), str(type(box))
# box is assumed to be xyxy
current_width, current_height = self.size
xmin, ymin, xmax, ymax = [round(float(b)) for b in box]
assert xmin <= xmax and ymin <= ymax, str(box)
xmin = min(max(xmin, 0), current_width - 1)
ymin = min(max(ymin, 0), current_height - 1)
xmax = min(max(xmax, 0), current_width)
ymax = min(max(ymax, 0), current_height)
xmax = max(xmax, xmin + 1)
ymax = max(ymax, ymin + 1)
width, height = xmax - xmin, ymax - ymin
cropped_masks = self.masks[:, ymin:ymax, xmin:xmax]
cropped_size = width, height
return BinaryMaskList(cropped_masks, cropped_size)
def resize(self, size):
try:
iter(size)
except TypeError:
assert isinstance(size, (int, float))
size = size, size
width, height = map(int, size)
assert width > 0
assert height > 0
# Height comes first here!
resized_masks = torch.nn.functional.interpolate(
input=self.masks[None].float(),
size=(height, width),
mode="bilinear",
align_corners=False,
)[0].type_as(self.masks)
resized_size = width, height
return BinaryMaskList(resized_masks, resized_size)
def convert_to_polygon(self):
contours = self._findContours()
return PolygonList(contours, self.size)
def to(self, *args, **kwargs):
return self
def _findContours(self):
contours = []
masks = self.masks.detach().numpy()
for mask in masks:
mask = cv2.UMat(mask)
contour, hierarchy = cv2.findContours(
mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_L1
)
width, height = self.size
if method == FLIP_LEFT_RIGHT:
dim = width
idx = 2
elif method == FLIP_TOP_BOTTOM:
dim = height
idx = 1
reshaped_contour = []
for entity in contour:
assert len(entity.shape) == 3
assert entity.shape[1] == 1, "Hierarchical contours are not allowed"
reshaped_contour.append(entity.reshape(-1).tolist())
contours.append(reshaped_contour)
return contours
flip_idx = list(range(dim)[::-1])
flipped_masks = self.masks.index_select(dim, flip_idx)
return Mask(flipped_masks, self.size, self.mode)
def __len__(self):
return len(self.masks)
def crop(self, box):
w, h = box[2] - box[0], box[3] - box[1]
def __getitem__(self, index):
# Probably it can cause some overhead
# but preserves consistency
masks = self.masks[index].clone()
return BinaryMaskList(masks, self.size)
cropped_masks = self.masks[:, box[1] : box[3], box[0] : box[2]]
return Mask(cropped_masks, size=(w, h), mode=self.mode)
def __iter__(self):
return iter(self.masks)
def resize(self, size, *args, **kwargs):
pass
def __repr__(self):
s = self.__class__.__name__ + "("
s += "num_instances={}, ".format(len(self.masks))
s += "image_width={}, ".format(self.size[0])
s += "image_height={})".format(self.size[1])
return s
class Polygons(object):
class PolygonInstance(object):
"""
This class holds a set of polygons that represents a single instance
of an object mask. The object can be represented as a set of
polygons
"""
def __init__(self, polygons, size, mode):
# assert isinstance(polygons, list), '{}'.format(polygons)
if isinstance(polygons, list):
polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons]
elif isinstance(polygons, Polygons):
polygons = polygons.polygons
def __init__(self, polygons, size):
"""
Arguments:
a list of lists of numbers.
The first level refers to all the polygons that compose the
object, and the second level to the polygon coordinates.
"""
if isinstance(polygons, (list, tuple)):
valid_polygons = []
for p in polygons:
p = torch.as_tensor(p, dtype=torch.float32)
if len(p) >= 6: # 3 * 2 coordinates
valid_polygons.append(p)
polygons = valid_polygons
elif isinstance(polygons, PolygonInstance):
polygons = polygons.polygons.copy()
else:
RuntimeError(
"Type of argument `polygons` is not allowed:%s" % (type(polygons))
)
""" This crashes the training way too many times...
for p in polygons:
assert p[::2].min() >= 0
assert p[::2].max() < size[0]
assert p[1::2].min() >= 0
assert p[1::2].max() , size[1]
"""
self.polygons = polygons
self.size = size
self.mode = mode
self.size = tuple(size)
def transpose(self, method):
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
......@@ -87,30 +234,49 @@ class Polygons(object):
p[idx::2] = dim - poly[idx::2] - TO_REMOVE
flipped_polygons.append(p)
return Polygons(flipped_polygons, size=self.size, mode=self.mode)
return PolygonInstance(flipped_polygons, size=self.size)
def crop(self, box):
w, h = box[2] - box[0], box[3] - box[1]
assert isinstance(box, (list, tuple, torch.Tensor)), str(type(box))
# box is assumed to be xyxy
current_width, current_height = self.size
xmin, ymin, xmax, ymax = map(float, box)
assert xmin <= xmax and ymin <= ymax, str(box)
xmin = min(max(xmin, 0), current_width - 1)
ymin = min(max(ymin, 0), current_height - 1)
xmax = min(max(xmax, 0), current_width)
ymax = min(max(ymax, 0), current_height)
# TODO chck if necessary
w = max(w, 1)
h = max(h, 1)
xmax = max(xmax, xmin + 1)
ymax = max(ymax, ymin + 1)
w, h = xmax - xmin, ymax - ymin
cropped_polygons = []
for poly in self.polygons:
p = poly.clone()
p[0::2] = p[0::2] - box[0] # .clamp(min=0, max=w)
p[1::2] = p[1::2] - box[1] # .clamp(min=0, max=h)
p[0::2] = p[0::2] - xmin # .clamp(min=0, max=w)
p[1::2] = p[1::2] - ymin # .clamp(min=0, max=h)
cropped_polygons.append(p)
return Polygons(cropped_polygons, size=(w, h), mode=self.mode)
return PolygonInstance(cropped_polygons, size=(w, h))
def resize(self, size):
try:
iter(size)
except TypeError:
assert isinstance(size, (int, float))
size = size, size
def resize(self, size, *args, **kwargs):
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
if ratios[0] == ratios[1]:
ratio = ratios[0]
scaled_polys = [p * ratio for p in self.polygons]
return Polygons(scaled_polys, size, mode=self.mode)
return PolygonInstance(scaled_polys, size)
ratio_w, ratio_h = ratios
scaled_polygons = []
......@@ -120,47 +286,82 @@ class Polygons(object):
p[1::2] *= ratio_h
scaled_polygons.append(p)
return Polygons(scaled_polygons, size=size, mode=self.mode)
return PolygonInstance(scaled_polygons, size=size)
def convert(self, mode):
def convert_to_binarymask(self):
width, height = self.size
if mode == "mask":
rles = mask_utils.frPyObjects(
[p.numpy() for p in self.polygons], height, width
)
# formatting for COCO PythonAPI
polygons = [p.numpy() for p in self.polygons]
rles = mask_utils.frPyObjects(polygons, height, width)
rle = mask_utils.merge(rles)
mask = mask_utils.decode(rle)
mask = torch.from_numpy(mask)
# TODO add squeeze?
return mask
def __len__(self):
return len(self.polygons)
def __repr__(self):
s = self.__class__.__name__ + "("
s += "num_polygons={}, ".format(len(self.polygons))
s += "num_groups={}, ".format(len(self.polygons))
s += "image_width={}, ".format(self.size[0])
s += "image_height={}, ".format(self.size[1])
s += "mode={})".format(self.mode)
return s
class SegmentationMask(object):
class PolygonList(object):
"""
This class stores the segmentations for all objects in the image
This class handles PolygonInstances for all objects in the image
"""
def __init__(self, polygons, size, mode=None):
def __init__(self, polygons, size):
"""
Arguments:
polygons: a list of list of lists of numbers. The first
polygons:
a list of list of lists of numbers. The first
level of the list correspond to individual instances,
the second level to all the polygons that compose the
object, and the third level to the polygon coordinates.
OR
a list of PolygonInstances.
OR
a PolygonList
size: absolute image size
"""
assert isinstance(polygons, list)
if isinstance(polygons, (list, tuple)):
if len(polygons) == 0:
polygons = [[[]]]
if isinstance(polygons[0], (list, tuple)):
assert isinstance(polygons[0][0], (list, tuple)), str(
type(polygons[0][0])
)
else:
assert isinstance(polygons[0], PolygonInstance), str(type(polygons[0]))
self.polygons = [Polygons(p, size, mode) for p in polygons]
self.size = size
self.mode = mode
elif isinstance(polygons, PolygonList):
size = polygons.size
polygons = polygons.polygons
else:
RuntimeError(
"Type of argument `polygons` is not allowed:%s" % (type(polygons))
)
assert isinstance(size, (list, tuple)), str(type(size))
self.polygons = []
for p in polygons:
p = PolygonInstance(p, size)
if len(p) > 0:
self.polygons.append(p)
self.size = tuple(size)
def transpose(self, method):
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
......@@ -168,30 +369,49 @@ class SegmentationMask(object):
"Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
)
flipped = []
flipped_polygons = []
for polygon in self.polygons:
flipped.append(polygon.transpose(method))
return SegmentationMask(flipped, size=self.size, mode=self.mode)
flipped_polygons.append(polygon.transpose(method))
return PolygonList(flipped_polygons, size=self.size)
def crop(self, box):
w, h = box[2] - box[0], box[3] - box[1]
cropped = []
cropped_polygons = []
for polygon in self.polygons:
cropped.append(polygon.crop(box))
return SegmentationMask(cropped, size=(w, h), mode=self.mode)
cropped_polygons.append(polygon.crop(box))
def resize(self, size, *args, **kwargs):
scaled = []
cropped_size = w, h
return PolygonList(cropped_polygons, cropped_size)
def resize(self, size):
resized_polygons = []
for polygon in self.polygons:
scaled.append(polygon.resize(size, *args, **kwargs))
return SegmentationMask(scaled, size=size, mode=self.mode)
resized_polygons.append(polygon.resize(size))
resized_size = size
return PolygonList(resized_polygons, resized_size)
def to(self, *args, **kwargs):
return self
def convert_to_binarymask(self):
if len(self) > 0:
masks = torch.stack([p.convert_to_binarymask() for p in self.polygons])
else:
size = self.size
masks = torch.empty([0, size[1], size[0]], dtype=torch.uint8)
return BinaryMaskList(masks, size=self.size)
def __len__(self):
return len(self.polygons)
def __getitem__(self, item):
if isinstance(item, (int, slice)):
if isinstance(item, int):
selected_polygons = [self.polygons[item]]
elif isinstance(item, slice):
selected_polygons = self.polygons[item]
else:
# advanced indexing on a single dimension
selected_polygons = []
......@@ -201,7 +421,7 @@ class SegmentationMask(object):
item = item.tolist()
for i in item:
selected_polygons.append(self.polygons[i])
return SegmentationMask(selected_polygons, size=self.size, mode=self.mode)
return PolygonList(selected_polygons, size=self.size)
def __iter__(self):
return iter(self.polygons)
......@@ -212,3 +432,103 @@ class SegmentationMask(object):
s += "image_width={}, ".format(self.size[0])
s += "image_height={})".format(self.size[1])
return s
class SegmentationMask(object):
"""
This class stores the segmentations for all objects in the image.
It wraps BinaryMaskList and PolygonList conveniently.
"""
def __init__(self, instances, size, mode="poly"):
"""
Arguments:
instances: two types
(1) polygon
(2) binary mask
size: (width, height)
mode: 'poly', 'mask'. if mode is 'mask', convert mask of any format to binary mask
"""
assert isinstance(size, (list, tuple))
assert len(size) == 2
if isinstance(size[0], torch.Tensor):
assert isinstance(size[1], torch.Tensor)
size = size[0].item(), size[1].item()
assert isinstance(size[0], (int, float))
assert isinstance(size[1], (int, float))
if mode == "poly":
self.instances = PolygonList(instances, size)
elif mode == "mask":
self.instances = BinaryMaskList(instances, size)
else:
raise NotImplementedError("Unknown mode: %s" % str(mode))
self.mode = mode
self.size = tuple(size)
def transpose(self, method):
flipped_instances = self.instances.transpose(method)
return SegmentationMask(flipped_instances, self.size, self.mode)
def crop(self, box):
cropped_instances = self.instances.crop(box)
cropped_size = cropped_instances.size
return SegmentationMask(cropped_instances, cropped_size, self.mode)
def resize(self, size, *args, **kwargs):
resized_instances = self.instances.resize(size)
resized_size = size
return SegmentationMask(resized_instances, resized_size, self.mode)
def to(self, *args, **kwargs):
return self
def convert(self, mode):
if mode == self.mode:
return self
if mode == "poly":
converted_instances = self.instances.convert_to_polygon()
elif mode == "mask":
converted_instances = self.instances.convert_to_binarymask()
else:
raise NotImplementedError("Unknown mode: %s" % str(mode))
return SegmentationMask(converted_instances, self.size, mode)
def get_mask_tensor(self):
instances = self.instances
if self.mode == "poly":
instances = instances.convert_to_binarymask()
# If there is only 1 instance
return instances.masks.squeeze(0)
def __len__(self):
return len(self.instances)
def __getitem__(self, item):
selected_instances = self.instances.__getitem__(item)
return SegmentationMask(selected_instances, self.size, self.mode)
def __iter__(self):
self.iter_idx = 0
return self
def __next__(self):
if self.iter_idx < self.__len__():
next_segmentation = self.__getitem__(self.iter_idx)
self.iter_idx += 1
return next_segmentation
raise StopIteration
def __repr__(self):
s = self.__class__.__name__ + "("
s += "num_instances={}, ".format(len(self.instances))
s += "image_width={}, ".format(self.size[0])
s += "image_height={}, ".format(self.size[1])
s += "mode={})".format(self.mode)
return s
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import torch
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
class TestSegmentationMask(unittest.TestCase):
def __init__(self, method_name='runTest'):
super(TestSegmentationMask, self).__init__(method_name)
poly = [[[423.0, 306.5, 406.5, 277.0, 400.0, 271.5, 389.5, 277.0,
387.5, 292.0, 384.5, 295.0, 374.5, 220.0, 378.5, 210.0,
391.0, 200.5, 404.0, 199.5, 414.0, 203.5, 425.5, 221.0,
438.5, 297.0, 423.0, 306.5],
[100, 100, 200, 100, 200, 200, 100, 200],
]]
width = 640
height = 480
size = width, height
self.P = SegmentationMask(poly, size, 'poly')
self.M = SegmentationMask(poly, size, 'poly').convert('mask')
def L1(self, A, B):
diff = A.get_mask_tensor() - B.get_mask_tensor()
diff = torch.sum(torch.abs(diff.float())).item()
return diff
def test_convert(self):
M_hat = self.M.convert('poly').convert('mask')
P_hat = self.P.convert('mask').convert('poly')
diff_mask = self.L1(self.M, M_hat)
diff_poly = self.L1(self.P, P_hat)
self.assertTrue(diff_mask == diff_poly)
self.assertTrue(diff_mask <= 8169.)
self.assertTrue(diff_poly <= 8169.)
def test_crop(self):
box = [400, 250, 500, 300] # xyxy
diff = self.L1(self.M.crop(box), self.P.crop(box))
self.assertTrue(diff <= 1.)
def test_resize(self):
new_size = 50, 25
M_hat = self.M.resize(new_size)
P_hat = self.P.resize(new_size)
diff = self.L1(M_hat, P_hat)
self.assertTrue(self.M.size == self.P.size)
self.assertTrue(M_hat.size == P_hat.size)
self.assertTrue(self.M.size != M_hat.size)
self.assertTrue(diff <= 255.)
def test_transpose(self):
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1
diff_hor = self.L1(self.M.transpose(FLIP_LEFT_RIGHT),
self.P.transpose(FLIP_LEFT_RIGHT))
diff_ver = self.L1(self.M.transpose(FLIP_TOP_BOTTOM),
self.P.transpose(FLIP_TOP_BOTTOM))
self.assertTrue(diff_hor <= 53250.)
self.assertTrue(diff_ver <= 42494.)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment