Commit 4db30576, authored by Ross Girshick and committed by Facebook Github Bot

Do not mutate cfg.TEST.SCALE, cfg.TEST.MAX_SIZE

Reviewed By: ir413

Differential Revision: D7148427

fbshipit-source-id: 25b755c75ddc59cff7a1dc5a5cddb139c44f7cbf
parent 70e20023
......@@ -43,7 +43,7 @@ TEST:
NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/35857389/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_2x.yaml.01_37_22.KSeq0b5q/output/train/coco_2014_train%3Acoco_2014_valminusminival/generalized_rcnn/model_final.pkl
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/35859007/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_2x.yaml.01_49_07.By8nQcCH/output/train/coco_2014_train:coco_2014_valminusminival/generalized_rcnn/model_final.pkl
# -- Test time augmentation example -- #
BBOX_AUG:
......
......@@ -47,7 +47,7 @@ TEST:
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/37651887/12_2017_baselines/keypoint_rcnn_R-50-FPN_s1x.yaml.20_01_40.FDjUQ7VX/output/train/keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival/generalized_rcnn/model_final.pkl
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/37651887/12_2017_baselines/keypoint_rcnn_R-50-FPN_s1x.yaml.20_01_40.FDjUQ7VX/output/train/keypoints_coco_2014_train:keypoints_coco_2014_valminusminival/generalized_rcnn/model_final.pkl
# -- Test time augmentation example -- #
BBOX_AUG:
......
......@@ -187,7 +187,7 @@ def im_proposals(model, im):
"""Generate RPN proposals on a single image."""
inputs = {}
inputs['data'], im_scale, inputs['im_info'] = \
blob_utils.get_image_blob_for_inference(im)
blob_utils.get_image_blob(im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
for k, v in inputs.items():
workspace.FeedBlob(core.ScopedName(k), v.astype(np.float32, copy=False))
workspace.RunNet(model.net.Proto().name)
......
......@@ -62,7 +62,9 @@ def im_detect_all(model, im, box_proposals, timers=None):
if cfg.TEST.BBOX_AUG.ENABLED:
scores, boxes, im_scale = im_detect_bbox_aug(model, im, box_proposals)
else:
scores, boxes, im_scale = im_detect_bbox(model, im, box_proposals)
scores, boxes, im_scale = im_detect_bbox(
model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes=box_proposals
)
timers['im_detect_bbox'].toc()
# score and boxes are from the whole image after score thresholding and nms
......@@ -106,15 +108,17 @@ def im_detect_all(model, im, box_proposals, timers=None):
return cls_boxes, cls_segms, cls_keyps
def im_conv_body_only(model, im):
def im_conv_body_only(model, im, target_scale, target_max_size):
"""Runs `model.conv_body_net` on the given image `im`."""
im_blob, im_scale, _im_info = blob_utils.get_image_blob_for_inference(im)
im_blob, im_scale, _im_info = blob_utils.get_image_blob(
im, target_scale, target_max_size
)
workspace.FeedBlob(core.ScopedName('data'), im_blob)
workspace.RunNet(model.conv_body_net.Proto().name)
return im_scale
def im_detect_bbox(model, im, boxes=None):
def im_detect_bbox(model, im, target_scale, target_max_size, boxes=None):
"""Bounding box object detection for an image with given box proposals.
Arguments:
......@@ -130,7 +134,7 @@ def im_detect_bbox(model, im, boxes=None):
im_scales (list): list of image scales used in the input blob (as
returned by _get_blobs and for use with im_detect_mask, etc.)
"""
inputs, im_scale = _get_blobs(im, boxes)
inputs, im_scale = _get_blobs(im, boxes, target_scale, target_max_size)
# When mapping from image ROIs to feature map ROIs, there's some aliasing
# (some distinct image ROIs get mapped to the same feature ROI).
......@@ -217,7 +221,11 @@ def im_detect_bbox_aug(model, im, box_proposals=None):
# Perform detection on the horizontally flipped image
if cfg.TEST.BBOX_AUG.H_FLIP:
scores_hf, boxes_hf, _ = im_detect_bbox_hflip(
model, im, box_proposals
model,
im,
cfg.TEST.SCALE,
cfg.TEST.MAX_SIZE,
box_proposals=box_proposals
)
add_preds_t(scores_hf, boxes_hf)
......@@ -251,7 +259,9 @@ def im_detect_bbox_aug(model, im, box_proposals=None):
# Compute detections for the original image (identity transform) last to
# ensure that the Caffe2 workspace is populated with blobs corresponding
# to the original image on return (postcondition of im_detect_bbox)
scores_i, boxes_i, im_scale_i = im_detect_bbox(model, im, box_proposals)
scores_i, boxes_i, im_scale_i = im_detect_bbox(
model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes=box_proposals
)
add_preds_t(scores_i, boxes_i)
# Combine the predicted scores
......@@ -281,7 +291,9 @@ def im_detect_bbox_aug(model, im, box_proposals=None):
return scores_c, boxes_c, im_scale_i
def im_detect_bbox_hflip(model, im, box_proposals=None):
def im_detect_bbox_hflip(
model, im, target_scale, target_max_size, box_proposals=None
):
"""Performs bbox detection on the horizontally flipped image.
Function signature is the same as for im_detect_bbox.
"""
......@@ -295,7 +307,7 @@ def im_detect_bbox_hflip(model, im, box_proposals=None):
box_proposals_hf = None
scores_hf, boxes_hf, im_scale = im_detect_bbox(
model, im_hf, box_proposals_hf
model, im_hf, target_scale, target_max_size, boxes=box_proposals_hf
)
# Invert the detections computed on the flipped image
......@@ -305,30 +317,19 @@ def im_detect_bbox_hflip(model, im, box_proposals=None):
def im_detect_bbox_scale(
model, im, scale, max_size, box_proposals=None, hflip=False
model, im, target_scale, target_max_size, box_proposals=None, hflip=False
):
"""Computes bbox detections at the given scale.
Returns predictions in the original image space.
"""
# Remember the original scale
orig_scale = cfg.TEST.SCALE
orig_max_size = cfg.TEST.MAX_SIZE
# Perform detection at the given scale
cfg.TEST.SCALE = scale
cfg.TEST.MAX_SIZE = max_size
if hflip:
scores_scl, boxes_scl, _ = im_detect_bbox_hflip(
model, im, box_proposals
model, im, target_scale, target_max_size, box_proposals=box_proposals
)
else:
scores_scl, boxes_scl, _ = im_detect_bbox(model, im, box_proposals)
# Restore the original scale
cfg.TEST.SCALE = orig_scale
cfg.TEST.MAX_SIZE = orig_max_size
scores_scl, boxes_scl, _ = im_detect_bbox(
model, im, target_scale, target_max_size, boxes=box_proposals
)
return scores_scl, boxes_scl
......@@ -348,10 +349,20 @@ def im_detect_bbox_aspect_ratio(
if hflip:
scores_ar, boxes_ar, _ = im_detect_bbox_hflip(
model, im_ar, box_proposals_ar
model,
im_ar,
cfg.TEST.SCALE,
cfg.TEST.MAX_SIZE,
box_proposals=box_proposals_ar
)
else:
scores_ar, boxes_ar, _ = im_detect_bbox(model, im_ar, box_proposals_ar)
scores_ar, boxes_ar, _ = im_detect_bbox(
model,
im_ar,
cfg.TEST.SCALE,
cfg.TEST.MAX_SIZE,
boxes=box_proposals_ar
)
# Invert the detected boxes
boxes_inv = box_utils.aspect_ratio(boxes_ar, 1.0 / aspect_ratio)
......@@ -420,13 +431,15 @@ def im_detect_mask_aug(model, im, boxes):
masks_ts = []
# Compute masks for the original image (identity transform)
im_scale_i = im_conv_body_only(model, im)
im_scale_i = im_conv_body_only(model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
masks_i = im_detect_mask(model, im_scale_i, boxes)
masks_ts.append(masks_i)
# Perform mask detection on the horizontally flipped image
if cfg.TEST.MASK_AUG.H_FLIP:
masks_hf = im_detect_mask_hflip(model, im, boxes)
masks_hf = im_detect_mask_hflip(
model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes
)
masks_ts.append(masks_hf)
# Compute detections at different scales
......@@ -473,7 +486,7 @@ def im_detect_mask_aug(model, im, boxes):
return masks_c
def im_detect_mask_hflip(model, im, boxes):
def im_detect_mask_hflip(model, im, target_scale, target_max_size, boxes):
"""Performs mask detection on the horizontally flipped image.
Function signature is the same as for im_detect_mask_aug.
"""
......@@ -481,7 +494,7 @@ def im_detect_mask_hflip(model, im, boxes):
im_hf = im[:, ::-1, :]
boxes_hf = box_utils.flip_boxes(boxes, im.shape[1])
im_scale = im_conv_body_only(model, im_hf)
im_scale = im_conv_body_only(model, im_hf, target_scale, target_max_size)
masks_hf = im_detect_mask(model, im_scale, boxes_hf)
# Invert the predicted soft masks
......@@ -490,27 +503,17 @@ def im_detect_mask_hflip(model, im, boxes):
return masks_inv
def im_detect_mask_scale(model, im, scale, max_size, boxes, hflip=False):
def im_detect_mask_scale(
model, im, target_scale, target_max_size, boxes, hflip=False
):
"""Computes masks at the given scale."""
# Remember the original scale
orig_scale = cfg.TEST.SCALE
orig_max_size = cfg.TEST.MAX_SIZE
# Perform mask detection at the given scale
cfg.TEST.SCALE = scale
cfg.TEST.MAX_SIZE = max_size
if hflip:
masks_scl = im_detect_mask_hflip(model, im, boxes)
masks_scl = im_detect_mask_hflip(
model, im, target_scale, target_max_size, boxes
)
else:
im_scale = im_conv_body_only(model, im)
im_scale = im_conv_body_only(model, im, target_scale, target_max_size)
masks_scl = im_detect_mask(model, im_scale, boxes)
# Restore the original scale
cfg.TEST.SCALE = orig_scale
cfg.TEST.MAX_SIZE = orig_max_size
return masks_scl
......@@ -522,9 +525,13 @@ def im_detect_mask_aspect_ratio(model, im, aspect_ratio, boxes, hflip=False):
boxes_ar = box_utils.aspect_ratio(boxes, aspect_ratio)
if hflip:
masks_ar = im_detect_mask_hflip(model, im_ar, boxes_ar)
masks_ar = im_detect_mask_hflip(
model, im_ar, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes_ar
)
else:
im_scale = im_conv_body_only(model, im_ar)
im_scale = im_conv_body_only(
model, im_ar, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE
)
masks_ar = im_detect_mask(model, im_scale, boxes_ar)
return masks_ar
......@@ -595,13 +602,15 @@ def im_detect_keypoints_aug(model, im, boxes):
us_ts.append(us_t)
# Compute the heatmaps for the original image (identity transform)
im_scale = im_conv_body_only(model, im)
im_scale = im_conv_body_only(model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
heatmaps_i = im_detect_keypoints(model, im_scale, boxes)
add_heatmaps_t(heatmaps_i)
# Perform keypoints detection on the horizontally flipped image
if cfg.TEST.KPS_AUG.H_FLIP:
heatmaps_hf = im_detect_keypoints_hflip(model, im, boxes)
heatmaps_hf = im_detect_keypoints_hflip(
model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes
)
add_heatmaps_t(heatmaps_hf)
# Compute detections at different scales
......@@ -656,7 +665,7 @@ def im_detect_keypoints_aug(model, im, boxes):
return heatmaps_c
def im_detect_keypoints_hflip(model, im, boxes):
def im_detect_keypoints_hflip(model, im, target_scale, target_max_size, boxes):
"""Computes keypoint predictions on the horizontally flipped image.
Function signature is the same as for im_detect_keypoints_aug.
"""
......@@ -664,7 +673,7 @@ def im_detect_keypoints_hflip(model, im, boxes):
im_hf = im[:, ::-1, :]
boxes_hf = box_utils.flip_boxes(boxes, im.shape[1])
im_scale = im_conv_body_only(model, im_hf)
im_scale = im_conv_body_only(model, im_hf, target_scale, target_max_size)
heatmaps_hf = im_detect_keypoints(model, im_scale, boxes_hf)
# Invert the predicted keypoints
......@@ -673,27 +682,17 @@ def im_detect_keypoints_hflip(model, im, boxes):
return heatmaps_inv
def im_detect_keypoints_scale(model, im, scale, max_size, boxes, hflip=False):
def im_detect_keypoints_scale(
model, im, target_scale, target_max_size, boxes, hflip=False
):
"""Computes keypoint predictions at the given scale."""
# Store the original scale
orig_scale = cfg.TEST.SCALE
orig_max_size = cfg.TEST.MAX_SIZE
# Perform detection at the given scale
cfg.TEST.SCALE = scale
cfg.TEST.MAX_SIZE = max_size
if hflip:
heatmaps_scl = im_detect_keypoints_hflip(model, im, boxes)
heatmaps_scl = im_detect_keypoints_hflip(
model, im, target_scale, target_max_size, boxes
)
else:
im_scale = im_conv_body_only(model, im)
im_scale = im_conv_body_only(model, im, target_scale, target_max_size)
heatmaps_scl = im_detect_keypoints(model, im_scale, boxes)
# Restore the original scale
cfg.TEST.SCALE = orig_scale
cfg.TEST.MAX_SIZE = orig_max_size
return heatmaps_scl
......@@ -707,9 +706,13 @@ def im_detect_keypoints_aspect_ratio(
boxes_ar = box_utils.aspect_ratio(boxes, aspect_ratio)
if hflip:
heatmaps_ar = im_detect_keypoints_hflip(model, im_ar, boxes_ar)
heatmaps_ar = im_detect_keypoints_hflip(
model, im_ar, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes_ar
)
else:
im_scale = im_conv_body_only(model, im_ar)
im_scale = im_conv_body_only(
model, im_ar, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE
)
heatmaps_ar = im_detect_keypoints(model, im_scale, boxes_ar)
return heatmaps_ar
......@@ -936,11 +939,11 @@ def _add_multilevel_rois_for_test(blobs, name):
)
def _get_blobs(im, rois):
def _get_blobs(im, rois, target_scale, target_max_size):
"""Convert an image and RoIs within that image into network inputs."""
blobs = {}
blobs['data'], im_scale, blobs['im_info'] = \
blob_utils.get_image_blob_for_inference(im)
blob_utils.get_image_blob(im, target_scale, target_max_size)
if rois is not None:
blobs['rois'] = _get_rois_blob(rois, im_scale)
return blobs, im_scale
......@@ -77,7 +77,7 @@ def im_detect_bbox(model, im, timers=None):
A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS)
inputs = {}
inputs['data'], im_scale, inputs['im_info'] = \
blob_utils.get_image_blob_for_inference(im)
blob_utils.get_image_blob(im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
cls_probs, box_preds = [], []
for lvl in range(k_min, k_max + 1):
suffix = 'fpn{}'.format(lvl)
......
......@@ -107,8 +107,8 @@ def _get_image_blob(roidb):
im, im_scale = blob_utils.prep_im_for_blob(
im, cfg.PIXEL_MEANS, target_size, cfg.TRAIN.MAX_SIZE
)
im_scales.append(im_scale[0])
processed_ims.append(im[0])
im_scales.append(im_scale)
processed_ims.append(im)
# Create a blob to hold the input images
blob = blob_utils.im_list_to_blob(processed_ims)
......
......@@ -37,8 +37,8 @@ from caffe2.proto import caffe2_pb2
from core.config import cfg
def get_image_blob_for_inference(im):
"""Converts an image into a network input.
def get_image_blob(im, target_scale, target_max_size):
"""Convert an image into a network input.
Arguments:
im (ndarray): a color image in BGR order
......@@ -49,7 +49,7 @@ def get_image_blob_for_inference(im):
im_info (ndarray)
"""
processed_im, im_scale = prep_im_for_blob(
im, cfg.PIXEL_MEANS, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE
im, cfg.PIXEL_MEANS, target_scale, target_max_size
)
blob = im_list_to_blob(processed_im)
# NOTE: this height and width may be larger than actual scaled input image
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment