Commit b5dcc0fe authored by Ilija Radosavovic's avatar Ilija Radosavovic Committed by Facebook Github Bot

Provide a helper for determing segm mask format

Reviewed By: rbgirshick

Differential Revision: D8836361

fbshipit-source-id: f7d8bad14e0d6309ae349616f2b837d7082e2d74
parent b2d24147
...@@ -43,6 +43,7 @@ from detectron.core.config import cfg ...@@ -43,6 +43,7 @@ from detectron.core.config import cfg
from detectron.utils.timer import Timer from detectron.utils.timer import Timer
import detectron.datasets.dataset_catalog as dataset_catalog import detectron.datasets.dataset_catalog as dataset_catalog
import detectron.utils.boxes as box_utils import detectron.utils.boxes as box_utils
import detectron.utils.segms as segm_utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -167,8 +168,8 @@ class JsonDataset(object): ...@@ -167,8 +168,8 @@ class JsonDataset(object):
width = entry['width'] width = entry['width']
height = entry['height'] height = entry['height']
for obj in objs: for obj in objs:
# crowd regions are RLE encoded and stored as dicts # crowd regions are RLE encoded
if isinstance(obj['segmentation'], list): if segm_utils.is_poly(obj['segmentation']):
# Valid polygons have >= 3 points, so require >= 6 coordinates # Valid polygons have >= 3 points, so require >= 6 coordinates
obj['segmentation'] = [ obj['segmentation'] = [
p for p in obj['segmentation'] if len(p) >= 6 p for p in obj['segmentation'] if len(p) >= 6
......
...@@ -31,6 +31,18 @@ import numpy as np ...@@ -31,6 +31,18 @@ import numpy as np
import pycocotools.mask as mask_util import pycocotools.mask as mask_util
# Type used for storing masks in polygon format
_POLY_TYPE = list
# Type used for storing masks in RLE format
_RLE_TYPE = dict
def is_poly(segm):
"""Determine if segm is a polygon. Valid segm expected (polygon or RLE)."""
assert isinstance(segm, (_POLY_TYPE, _RLE_TYPE)), \
'Invalid segm type: {}'.format(type(segm))
return isinstance(segm, _POLY_TYPE)
def flip_segms(segms, height, width): def flip_segms(segms, height, width):
"""Left/right flip each mask in a list of masks.""" """Left/right flip each mask in a list of masks."""
...@@ -51,12 +63,11 @@ def flip_segms(segms, height, width): ...@@ -51,12 +63,11 @@ def flip_segms(segms, height, width):
flipped_segms = [] flipped_segms = []
for segm in segms: for segm in segms:
if type(segm) == list: if is_poly(segm):
# Polygon format # Polygon format
flipped_segms.append([_flip_poly(poly, width) for poly in segm]) flipped_segms.append([_flip_poly(poly, width) for poly in segm])
else: else:
# RLE format # RLE format
assert type(segm) == dict
flipped_segms.append(_flip_rle(segm, height, width)) flipped_segms.append(_flip_rle(segm, height, width))
return flipped_segms return flipped_segms
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment