Commit 862347d5 authored by CoinCheung's avatar CoinCheung Committed by Francisco Massa

add color jitter augmentation (#680)

* add color jitter augmentation

* fix spelling
parent 1d6e9add
...@@ -54,6 +54,12 @@ _C.INPUT.PIXEL_STD = [1., 1., 1.] ...@@ -54,6 +54,12 @@ _C.INPUT.PIXEL_STD = [1., 1., 1.]
# Convert image to BGR format (for Caffe2 models), in range 0-255 # Convert image to BGR format (for Caffe2 models), in range 0-255
_C.INPUT.TO_BGR255 = True _C.INPUT.TO_BGR255 = True
# Image ColorJitter
_C.INPUT.BRIGHTNESS = 0.0
_C.INPUT.CONTRAST = 0.0
_C.INPUT.SATURATION = 0.0
_C.INPUT.HUE = 0.0
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Dataset # Dataset
......
...@@ -16,9 +16,16 @@ def build_transforms(cfg, is_train=True): ...@@ -16,9 +16,16 @@ def build_transforms(cfg, is_train=True):
normalize_transform = T.Normalize( normalize_transform = T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=to_bgr255 mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=to_bgr255
) )
color_jitter = T.ColorJitter(
brightness=cfg.INPUT.BRIGHTNESS,
contrast=cfg.INPUT.CONTRAST,
saturation=cfg.INPUT.SATURATION,
hue=cfg.INPUT.HUE,
)
transform = T.Compose( transform = T.Compose(
[ [
color_jitter,
T.Resize(min_size, max_size), T.Resize(min_size, max_size),
T.RandomHorizontalFlip(flip_prob), T.RandomHorizontalFlip(flip_prob),
T.ToTensor(), T.ToTensor(),
......
...@@ -72,6 +72,24 @@ class RandomHorizontalFlip(object): ...@@ -72,6 +72,24 @@ class RandomHorizontalFlip(object):
return image, target return image, target
class ColorJitter(object):
def __init__(self,
brightness=None,
contrast=None,
saturation=None,
hue=None,
):
self.color_jitter = torchvision.transforms.ColorJitter(
brightness=brightness,
contrast=contrast,
saturation=saturation,
hue=hue,)
def __call__(self, image, target):
image = self.color_jitter(image)
return image, target
class ToTensor(object): class ToTensor(object):
def __call__(self, image, target): def __call__(self, image, target):
return F.to_tensor(image), target return F.to_tensor(image), target
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment