Commit 862347d5 authored by CoinCheung's avatar CoinCheung Committed by Francisco Massa

add color jitter augmentation (#680)

* add color jitter augmentation

* fix spelling
parent 1d6e9add
......@@ -54,6 +54,12 @@ _C.INPUT.PIXEL_STD = [1., 1., 1.]
# Convert image to BGR format (for Caffe2 models), in range 0-255
_C.INPUT.TO_BGR255 = True
# Image ColorJitter
_C.INPUT.BRIGHTNESS = 0.0
_C.INPUT.CONTRAST = 0.0
_C.INPUT.SATURATION = 0.0
_C.INPUT.HUE = 0.0
# -----------------------------------------------------------------------------
# Dataset
......
......@@ -16,9 +16,16 @@ def build_transforms(cfg, is_train=True):
normalize_transform = T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=to_bgr255
)
color_jitter = T.ColorJitter(
brightness=cfg.INPUT.BRIGHTNESS,
contrast=cfg.INPUT.CONTRAST,
saturation=cfg.INPUT.SATURATION,
hue=cfg.INPUT.HUE,
)
transform = T.Compose(
[
color_jitter,
T.Resize(min_size, max_size),
T.RandomHorizontalFlip(flip_prob),
T.ToTensor(),
......
......@@ -72,6 +72,24 @@ class RandomHorizontalFlip(object):
return image, target
class ColorJitter(object):
def __init__(self,
brightness=None,
contrast=None,
saturation=None,
hue=None,
):
self.color_jitter = torchvision.transforms.ColorJitter(
brightness=brightness,
contrast=contrast,
saturation=saturation,
hue=hue,)
def __call__(self, image, target):
image = self.color_jitter(image)
return image, target
class ToTensor(object):
def __call__(self, image, target):
return F.to_tensor(image), target
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment