Commit b4d54657 authored by Csaba Botos's avatar Csaba Botos Committed by Francisco Massa

Support Binary Mask with transparent SementationMask interface (#473)

* support RLE and binary mask

* do not convert to numpy

* be consistent with Detectron

* delete wrong comment

* [WIP] add tests for segmentation_mask

* update tests

* minor change

* Refactored segmentation_mask.py

* Add unit test for segmentation_mask.py

* Add RLE support for BinaryMaskList

* PEP8 black formatting

* Minor patch

* Use internal  that handles 0 channels

* Fix polygon slicing
parent f917a555
...@@ -80,7 +80,7 @@ class COCODataset(torchvision.datasets.coco.CocoDetection): ...@@ -80,7 +80,7 @@ class COCODataset(torchvision.datasets.coco.CocoDetection):
target.add_field("labels", classes) target.add_field("labels", classes)
masks = [obj["segmentation"] for obj in anno] masks = [obj["segmentation"] for obj in anno]
masks = SegmentationMask(masks, img.size) masks = SegmentationMask(masks, img.size, mode='poly')
target.add_field("masks", masks) target.add_field("masks", masks)
if anno and "keypoints" in anno[0]: if anno and "keypoints" in anno[0]:
......
...@@ -27,17 +27,15 @@ def project_masks_on_boxes(segmentation_masks, proposals, discretization_size): ...@@ -27,17 +27,15 @@ def project_masks_on_boxes(segmentation_masks, proposals, discretization_size):
assert segmentation_masks.size == proposals.size, "{}, {}".format( assert segmentation_masks.size == proposals.size, "{}, {}".format(
segmentation_masks, proposals segmentation_masks, proposals
) )
# TODO put the proposals on the CPU, as the representation for the
# masks is not efficient GPU-wise (possibly several small tensors for # FIXME: CPU computation bottleneck, this should be parallelized
# representing a single instance mask)
proposals = proposals.bbox.to(torch.device("cpu")) proposals = proposals.bbox.to(torch.device("cpu"))
for segmentation_mask, proposal in zip(segmentation_masks, proposals): for segmentation_mask, proposal in zip(segmentation_masks, proposals):
# crop the masks, resize them to the desired resolution and # crop the masks, resize them to the desired resolution and
# then convert them to the tensor representation, # then convert them to the tensor representation.
# instead of the list representation that was used
cropped_mask = segmentation_mask.crop(proposal) cropped_mask = segmentation_mask.crop(proposal)
scaled_mask = cropped_mask.resize((M, M)) scaled_mask = cropped_mask.resize((M, M))
mask = scaled_mask.convert(mode="mask") mask = scaled_mask.get_mask_tensor()
masks.append(mask) masks.append(mask)
if len(masks) == 0: if len(masks) == 0:
return torch.empty(0, dtype=torch.float32, device=device) return torch.empty(0, dtype=torch.float32, device=device)
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import cv2
import torch import torch
import numpy as np
from maskrcnn_benchmark.layers.misc import interpolate
import pycocotools.mask as mask_utils import pycocotools.mask as mask_utils
...@@ -8,63 +11,207 @@ FLIP_LEFT_RIGHT = 0 ...@@ -8,63 +11,207 @@ FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1 FLIP_TOP_BOTTOM = 1
class Mask(object): """ ABSTRACT
Segmentations come in either:
1) Binary masks
2) Polygons
Binary masks can be represented in a contiguous array
and operations can be carried out more efficiently,
therefore BinaryMaskList handles them together.
Polygons are handled separately for each instance,
by PolygonInstance and instances are handled by
PolygonList.
SegmentationList is supposed to represent both,
therefore it wraps the functions of BinaryMaskList
and PolygonList to make it transparent.
"""
class BinaryMaskList(object):
"""
This class handles binary masks for all objects in the image
"""
def __init__(self, masks, size):
""" """
This class is unfinished and not meant for use yet Arguments:
It is supposed to contain the mask for an object as masks: Either torch.tensor of [num_instances, H, W]
a 2d tensor or list of torch.tensors of [H, W] with num_instances elems,
or RLE (Run Length Encoding) - interpreted as list of dicts,
or BinaryMaskList.
size: absolute image size, width first
After initialization, a hard copy will be made, to leave the
initializing source data intact.
""" """
def __init__(self, masks, size, mode): if isinstance(masks, torch.Tensor):
# The raw data representation is passed as argument
masks = masks.clone()
elif isinstance(masks, (list, tuple)):
if isinstance(masks[0], torch.Tensor):
masks = torch.stack(masks, dim=2).clone()
elif isinstance(masks[0], dict) and "count" in masks[0]:
# RLE interpretation
masks = mask_utils
else:
RuntimeError(
"Type of `masks[0]` could not be interpreted: %s" % type(masks)
)
elif isinstance(masks, BinaryMaskList):
# just hard copy the BinaryMaskList instance's underlying data
masks = masks.masks.clone()
else:
RuntimeError(
"Type of `masks` argument could not be interpreted:%s" % tpye(masks)
)
if len(masks.shape) == 2:
# if only a single instance mask is passed
masks = masks[None]
assert len(masks.shape) == 3
assert masks.shape[1] == size[1], "%s != %s" % (masks.shape[1], size[1])
assert masks.shape[2] == size[0], "%s != %s" % (masks.shape[2], size[0])
self.masks = masks self.masks = masks
self.size = size self.size = tuple(size)
self.mode = mode
def transpose(self, method): def transpose(self, method):
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): dim = 1 if method == FLIP_TOP_BOTTOM else 2
raise NotImplementedError( flipped_masks = self.masks.flip(dim)
"Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" return BinaryMaskList(flipped_masks, self.size)
def crop(self, box):
assert isinstance(box, (list, tuple, torch.Tensor)), str(type(box))
# box is assumed to be xyxy
current_width, current_height = self.size
xmin, ymin, xmax, ymax = [round(float(b)) for b in box]
assert xmin <= xmax and ymin <= ymax, str(box)
xmin = min(max(xmin, 0), current_width - 1)
ymin = min(max(ymin, 0), current_height - 1)
xmax = min(max(xmax, 0), current_width)
ymax = min(max(ymax, 0), current_height)
xmax = max(xmax, xmin + 1)
ymax = max(ymax, ymin + 1)
width, height = xmax - xmin, ymax - ymin
cropped_masks = self.masks[:, ymin:ymax, xmin:xmax]
cropped_size = width, height
return BinaryMaskList(cropped_masks, cropped_size)
def resize(self, size):
try:
iter(size)
except TypeError:
assert isinstance(size, (int, float))
size = size, size
width, height = map(int, size)
assert width > 0
assert height > 0
# Height comes first here!
resized_masks = torch.nn.functional.interpolate(
input=self.masks[None].float(),
size=(height, width),
mode="bilinear",
align_corners=False,
)[0].type_as(self.masks)
resized_size = width, height
return BinaryMaskList(resized_masks, resized_size)
def convert_to_polygon(self):
contours = self._findContours()
return PolygonList(contours, self.size)
def to(self, *args, **kwargs):
return self
def _findContours(self):
contours = []
masks = self.masks.detach().numpy()
for mask in masks:
mask = cv2.UMat(mask)
contour, hierarchy = cv2.findContours(
mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_L1
) )
width, height = self.size reshaped_contour = []
if method == FLIP_LEFT_RIGHT: for entity in contour:
dim = width assert len(entity.shape) == 3
idx = 2 assert entity.shape[1] == 1, "Hierarchical contours are not allowed"
elif method == FLIP_TOP_BOTTOM: reshaped_contour.append(entity.reshape(-1).tolist())
dim = height contours.append(reshaped_contour)
idx = 1 return contours
flip_idx = list(range(dim)[::-1]) def __len__(self):
flipped_masks = self.masks.index_select(dim, flip_idx) return len(self.masks)
return Mask(flipped_masks, self.size, self.mode)
def crop(self, box): def __getitem__(self, index):
w, h = box[2] - box[0], box[3] - box[1] # Probably it can cause some overhead
# but preserves consistency
masks = self.masks[index].clone()
return BinaryMaskList(masks, self.size)
cropped_masks = self.masks[:, box[1] : box[3], box[0] : box[2]] def __iter__(self):
return Mask(cropped_masks, size=(w, h), mode=self.mode) return iter(self.masks)
def resize(self, size, *args, **kwargs): def __repr__(self):
pass s = self.__class__.__name__ + "("
s += "num_instances={}, ".format(len(self.masks))
s += "image_width={}, ".format(self.size[0])
s += "image_height={})".format(self.size[1])
return s
class Polygons(object): class PolygonInstance(object):
""" """
This class holds a set of polygons that represents a single instance This class holds a set of polygons that represents a single instance
of an object mask. The object can be represented as a set of of an object mask. The object can be represented as a set of
polygons polygons
""" """
def __init__(self, polygons, size, mode): def __init__(self, polygons, size):
# assert isinstance(polygons, list), '{}'.format(polygons) """
if isinstance(polygons, list): Arguments:
polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons] a list of lists of numbers.
elif isinstance(polygons, Polygons): The first level refers to all the polygons that compose the
polygons = polygons.polygons object, and the second level to the polygon coordinates.
"""
if isinstance(polygons, (list, tuple)):
valid_polygons = []
for p in polygons:
p = torch.as_tensor(p, dtype=torch.float32)
if len(p) >= 6: # 3 * 2 coordinates
valid_polygons.append(p)
polygons = valid_polygons
elif isinstance(polygons, PolygonInstance):
polygons = polygons.polygons.copy()
else:
RuntimeError(
"Type of argument `polygons` is not allowed:%s" % (type(polygons))
)
""" This crashes the training way too many times...
for p in polygons:
assert p[::2].min() >= 0
assert p[::2].max() < size[0]
assert p[1::2].min() >= 0
assert p[1::2].max() , size[1]
"""
self.polygons = polygons self.polygons = polygons
self.size = size self.size = tuple(size)
self.mode = mode
def transpose(self, method): def transpose(self, method):
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
...@@ -87,30 +234,49 @@ class Polygons(object): ...@@ -87,30 +234,49 @@ class Polygons(object):
p[idx::2] = dim - poly[idx::2] - TO_REMOVE p[idx::2] = dim - poly[idx::2] - TO_REMOVE
flipped_polygons.append(p) flipped_polygons.append(p)
return Polygons(flipped_polygons, size=self.size, mode=self.mode) return PolygonInstance(flipped_polygons, size=self.size)
def crop(self, box): def crop(self, box):
w, h = box[2] - box[0], box[3] - box[1] assert isinstance(box, (list, tuple, torch.Tensor)), str(type(box))
# box is assumed to be xyxy
current_width, current_height = self.size
xmin, ymin, xmax, ymax = map(float, box)
assert xmin <= xmax and ymin <= ymax, str(box)
xmin = min(max(xmin, 0), current_width - 1)
ymin = min(max(ymin, 0), current_height - 1)
xmax = min(max(xmax, 0), current_width)
ymax = min(max(ymax, 0), current_height)
# TODO chck if necessary xmax = max(xmax, xmin + 1)
w = max(w, 1) ymax = max(ymax, ymin + 1)
h = max(h, 1)
w, h = xmax - xmin, ymax - ymin
cropped_polygons = [] cropped_polygons = []
for poly in self.polygons: for poly in self.polygons:
p = poly.clone() p = poly.clone()
p[0::2] = p[0::2] - box[0] # .clamp(min=0, max=w) p[0::2] = p[0::2] - xmin # .clamp(min=0, max=w)
p[1::2] = p[1::2] - box[1] # .clamp(min=0, max=h) p[1::2] = p[1::2] - ymin # .clamp(min=0, max=h)
cropped_polygons.append(p) cropped_polygons.append(p)
return Polygons(cropped_polygons, size=(w, h), mode=self.mode) return PolygonInstance(cropped_polygons, size=(w, h))
def resize(self, size):
try:
iter(size)
except TypeError:
assert isinstance(size, (int, float))
size = size, size
def resize(self, size, *args, **kwargs):
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
if ratios[0] == ratios[1]: if ratios[0] == ratios[1]:
ratio = ratios[0] ratio = ratios[0]
scaled_polys = [p * ratio for p in self.polygons] scaled_polys = [p * ratio for p in self.polygons]
return Polygons(scaled_polys, size, mode=self.mode) return PolygonInstance(scaled_polys, size)
ratio_w, ratio_h = ratios ratio_w, ratio_h = ratios
scaled_polygons = [] scaled_polygons = []
...@@ -120,47 +286,82 @@ class Polygons(object): ...@@ -120,47 +286,82 @@ class Polygons(object):
p[1::2] *= ratio_h p[1::2] *= ratio_h
scaled_polygons.append(p) scaled_polygons.append(p)
return Polygons(scaled_polygons, size=size, mode=self.mode) return PolygonInstance(scaled_polygons, size=size)
def convert(self, mode): def convert_to_binarymask(self):
width, height = self.size width, height = self.size
if mode == "mask": # formatting for COCO PythonAPI
rles = mask_utils.frPyObjects( polygons = [p.numpy() for p in self.polygons]
[p.numpy() for p in self.polygons], height, width rles = mask_utils.frPyObjects(polygons, height, width)
)
rle = mask_utils.merge(rles) rle = mask_utils.merge(rles)
mask = mask_utils.decode(rle) mask = mask_utils.decode(rle)
mask = torch.from_numpy(mask) mask = torch.from_numpy(mask)
# TODO add squeeze?
return mask return mask
def __len__(self):
return len(self.polygons)
def __repr__(self): def __repr__(self):
s = self.__class__.__name__ + "(" s = self.__class__.__name__ + "("
s += "num_polygons={}, ".format(len(self.polygons)) s += "num_groups={}, ".format(len(self.polygons))
s += "image_width={}, ".format(self.size[0]) s += "image_width={}, ".format(self.size[0])
s += "image_height={}, ".format(self.size[1]) s += "image_height={}, ".format(self.size[1])
s += "mode={})".format(self.mode)
return s return s
class SegmentationMask(object): class PolygonList(object):
""" """
This class stores the segmentations for all objects in the image This class handles PolygonInstances for all objects in the image
""" """
def __init__(self, polygons, size, mode=None): def __init__(self, polygons, size):
""" """
Arguments: Arguments:
polygons: a list of list of lists of numbers. The first polygons:
a list of list of lists of numbers. The first
level of the list correspond to individual instances, level of the list correspond to individual instances,
the second level to all the polygons that compose the the second level to all the polygons that compose the
object, and the third level to the polygon coordinates. object, and the third level to the polygon coordinates.
OR
a list of PolygonInstances.
OR
a PolygonList
size: absolute image size
""" """
assert isinstance(polygons, list) if isinstance(polygons, (list, tuple)):
if len(polygons) == 0:
polygons = [[[]]]
if isinstance(polygons[0], (list, tuple)):
assert isinstance(polygons[0][0], (list, tuple)), str(
type(polygons[0][0])
)
else:
assert isinstance(polygons[0], PolygonInstance), str(type(polygons[0]))
self.polygons = [Polygons(p, size, mode) for p in polygons] elif isinstance(polygons, PolygonList):
self.size = size size = polygons.size
self.mode = mode polygons = polygons.polygons
else:
RuntimeError(
"Type of argument `polygons` is not allowed:%s" % (type(polygons))
)
assert isinstance(size, (list, tuple)), str(type(size))
self.polygons = []
for p in polygons:
p = PolygonInstance(p, size)
if len(p) > 0:
self.polygons.append(p)
self.size = tuple(size)
def transpose(self, method): def transpose(self, method):
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM): if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
...@@ -168,30 +369,49 @@ class SegmentationMask(object): ...@@ -168,30 +369,49 @@ class SegmentationMask(object):
"Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented" "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
) )
flipped = [] flipped_polygons = []
for polygon in self.polygons: for polygon in self.polygons:
flipped.append(polygon.transpose(method)) flipped_polygons.append(polygon.transpose(method))
return SegmentationMask(flipped, size=self.size, mode=self.mode)
return PolygonList(flipped_polygons, size=self.size)
def crop(self, box): def crop(self, box):
w, h = box[2] - box[0], box[3] - box[1] w, h = box[2] - box[0], box[3] - box[1]
cropped = [] cropped_polygons = []
for polygon in self.polygons: for polygon in self.polygons:
cropped.append(polygon.crop(box)) cropped_polygons.append(polygon.crop(box))
return SegmentationMask(cropped, size=(w, h), mode=self.mode)
def resize(self, size, *args, **kwargs): cropped_size = w, h
scaled = [] return PolygonList(cropped_polygons, cropped_size)
def resize(self, size):
resized_polygons = []
for polygon in self.polygons: for polygon in self.polygons:
scaled.append(polygon.resize(size, *args, **kwargs)) resized_polygons.append(polygon.resize(size))
return SegmentationMask(scaled, size=size, mode=self.mode)
resized_size = size
return PolygonList(resized_polygons, resized_size)
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
return self return self
def convert_to_binarymask(self):
if len(self) > 0:
masks = torch.stack([p.convert_to_binarymask() for p in self.polygons])
else:
size = self.size
masks = torch.empty([0, size[1], size[0]], dtype=torch.uint8)
return BinaryMaskList(masks, size=self.size)
def __len__(self):
return len(self.polygons)
def __getitem__(self, item): def __getitem__(self, item):
if isinstance(item, (int, slice)): if isinstance(item, int):
selected_polygons = [self.polygons[item]] selected_polygons = [self.polygons[item]]
elif isinstance(item, slice):
selected_polygons = self.polygons[item]
else: else:
# advanced indexing on a single dimension # advanced indexing on a single dimension
selected_polygons = [] selected_polygons = []
...@@ -201,7 +421,7 @@ class SegmentationMask(object): ...@@ -201,7 +421,7 @@ class SegmentationMask(object):
item = item.tolist() item = item.tolist()
for i in item: for i in item:
selected_polygons.append(self.polygons[i]) selected_polygons.append(self.polygons[i])
return SegmentationMask(selected_polygons, size=self.size, mode=self.mode) return PolygonList(selected_polygons, size=self.size)
def __iter__(self): def __iter__(self):
return iter(self.polygons) return iter(self.polygons)
...@@ -212,3 +432,103 @@ class SegmentationMask(object): ...@@ -212,3 +432,103 @@ class SegmentationMask(object):
s += "image_width={}, ".format(self.size[0]) s += "image_width={}, ".format(self.size[0])
s += "image_height={})".format(self.size[1]) s += "image_height={})".format(self.size[1])
return s return s
class SegmentationMask(object):
"""
This class stores the segmentations for all objects in the image.
It wraps BinaryMaskList and PolygonList conveniently.
"""
def __init__(self, instances, size, mode="poly"):
"""
Arguments:
instances: two types
(1) polygon
(2) binary mask
size: (width, height)
mode: 'poly', 'mask'. if mode is 'mask', convert mask of any format to binary mask
"""
assert isinstance(size, (list, tuple))
assert len(size) == 2
if isinstance(size[0], torch.Tensor):
assert isinstance(size[1], torch.Tensor)
size = size[0].item(), size[1].item()
assert isinstance(size[0], (int, float))
assert isinstance(size[1], (int, float))
if mode == "poly":
self.instances = PolygonList(instances, size)
elif mode == "mask":
self.instances = BinaryMaskList(instances, size)
else:
raise NotImplementedError("Unknown mode: %s" % str(mode))
self.mode = mode
self.size = tuple(size)
def transpose(self, method):
flipped_instances = self.instances.transpose(method)
return SegmentationMask(flipped_instances, self.size, self.mode)
def crop(self, box):
cropped_instances = self.instances.crop(box)
cropped_size = cropped_instances.size
return SegmentationMask(cropped_instances, cropped_size, self.mode)
def resize(self, size, *args, **kwargs):
resized_instances = self.instances.resize(size)
resized_size = size
return SegmentationMask(resized_instances, resized_size, self.mode)
def to(self, *args, **kwargs):
return self
def convert(self, mode):
if mode == self.mode:
return self
if mode == "poly":
converted_instances = self.instances.convert_to_polygon()
elif mode == "mask":
converted_instances = self.instances.convert_to_binarymask()
else:
raise NotImplementedError("Unknown mode: %s" % str(mode))
return SegmentationMask(converted_instances, self.size, mode)
def get_mask_tensor(self):
instances = self.instances
if self.mode == "poly":
instances = instances.convert_to_binarymask()
# If there is only 1 instance
return instances.masks.squeeze(0)
def __len__(self):
return len(self.instances)
def __getitem__(self, item):
selected_instances = self.instances.__getitem__(item)
return SegmentationMask(selected_instances, self.size, self.mode)
def __iter__(self):
self.iter_idx = 0
return self
def __next__(self):
if self.iter_idx < self.__len__():
next_segmentation = self.__getitem__(self.iter_idx)
self.iter_idx += 1
return next_segmentation
raise StopIteration
def __repr__(self):
s = self.__class__.__name__ + "("
s += "num_instances={}, ".format(len(self.instances))
s += "image_width={}, ".format(self.size[0])
s += "image_height={}, ".format(self.size[1])
s += "mode={})".format(self.mode)
return s
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import torch
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
class TestSegmentationMask(unittest.TestCase):
def __init__(self, method_name='runTest'):
super(TestSegmentationMask, self).__init__(method_name)
poly = [[[423.0, 306.5, 406.5, 277.0, 400.0, 271.5, 389.5, 277.0,
387.5, 292.0, 384.5, 295.0, 374.5, 220.0, 378.5, 210.0,
391.0, 200.5, 404.0, 199.5, 414.0, 203.5, 425.5, 221.0,
438.5, 297.0, 423.0, 306.5],
[100, 100, 200, 100, 200, 200, 100, 200],
]]
width = 640
height = 480
size = width, height
self.P = SegmentationMask(poly, size, 'poly')
self.M = SegmentationMask(poly, size, 'poly').convert('mask')
def L1(self, A, B):
diff = A.get_mask_tensor() - B.get_mask_tensor()
diff = torch.sum(torch.abs(diff.float())).item()
return diff
def test_convert(self):
M_hat = self.M.convert('poly').convert('mask')
P_hat = self.P.convert('mask').convert('poly')
diff_mask = self.L1(self.M, M_hat)
diff_poly = self.L1(self.P, P_hat)
self.assertTrue(diff_mask == diff_poly)
self.assertTrue(diff_mask <= 8169.)
self.assertTrue(diff_poly <= 8169.)
def test_crop(self):
box = [400, 250, 500, 300] # xyxy
diff = self.L1(self.M.crop(box), self.P.crop(box))
self.assertTrue(diff <= 1.)
def test_resize(self):
new_size = 50, 25
M_hat = self.M.resize(new_size)
P_hat = self.P.resize(new_size)
diff = self.L1(M_hat, P_hat)
self.assertTrue(self.M.size == self.P.size)
self.assertTrue(M_hat.size == P_hat.size)
self.assertTrue(self.M.size != M_hat.size)
self.assertTrue(diff <= 255.)
def test_transpose(self):
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1
diff_hor = self.L1(self.M.transpose(FLIP_LEFT_RIGHT),
self.P.transpose(FLIP_LEFT_RIGHT))
diff_ver = self.L1(self.M.transpose(FLIP_TOP_BOTTOM),
self.P.transpose(FLIP_TOP_BOTTOM))
self.assertTrue(diff_hor <= 53250.)
self.assertTrue(diff_ver <= 42494.)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment