Commit 4db30576 authored by Ross Girshick's avatar Ross Girshick Committed by Facebook Github Bot

Do not mutate cfg.TEST.SCALE, cfg.TEST.MAX_SIZE

Reviewed By: ir413

Differential Revision: D7148427

fbshipit-source-id: 25b755c75ddc59cff7a1dc5a5cddb139c44f7cbf
parent 70e20023
...@@ -43,7 +43,7 @@ TEST: ...@@ -43,7 +43,7 @@ TEST:
NMS: 0.5 NMS: 0.5
RPN_PRE_NMS_TOP_N: 1000 # Per FPN level RPN_PRE_NMS_TOP_N: 1000 # Per FPN level
RPN_POST_NMS_TOP_N: 1000 RPN_POST_NMS_TOP_N: 1000
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/35857389/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_2x.yaml.01_37_22.KSeq0b5q/output/train/coco_2014_train%3Acoco_2014_valminusminival/generalized_rcnn/model_final.pkl WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/35859007/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_2x.yaml.01_49_07.By8nQcCH/output/train/coco_2014_train:coco_2014_valminusminival/generalized_rcnn/model_final.pkl
# -- Test time augmentation example -- # # -- Test time augmentation example -- #
BBOX_AUG: BBOX_AUG:
......
...@@ -47,7 +47,7 @@ TEST: ...@@ -47,7 +47,7 @@ TEST:
SCALE: 800 SCALE: 800
MAX_SIZE: 1333 MAX_SIZE: 1333
NMS: 0.5 NMS: 0.5
WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/37651887/12_2017_baselines/keypoint_rcnn_R-50-FPN_s1x.yaml.20_01_40.FDjUQ7VX/output/train/keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival/generalized_rcnn/model_final.pkl WEIGHTS: https://s3-us-west-2.amazonaws.com/detectron/37651887/12_2017_baselines/keypoint_rcnn_R-50-FPN_s1x.yaml.20_01_40.FDjUQ7VX/output/train/keypoints_coco_2014_train:keypoints_coco_2014_valminusminival/generalized_rcnn/model_final.pkl
# -- Test time augmentation example -- # # -- Test time augmentation example -- #
BBOX_AUG: BBOX_AUG:
......
...@@ -187,7 +187,7 @@ def im_proposals(model, im): ...@@ -187,7 +187,7 @@ def im_proposals(model, im):
"""Generate RPN proposals on a single image.""" """Generate RPN proposals on a single image."""
inputs = {} inputs = {}
inputs['data'], im_scale, inputs['im_info'] = \ inputs['data'], im_scale, inputs['im_info'] = \
blob_utils.get_image_blob_for_inference(im) blob_utils.get_image_blob(im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
for k, v in inputs.items(): for k, v in inputs.items():
workspace.FeedBlob(core.ScopedName(k), v.astype(np.float32, copy=False)) workspace.FeedBlob(core.ScopedName(k), v.astype(np.float32, copy=False))
workspace.RunNet(model.net.Proto().name) workspace.RunNet(model.net.Proto().name)
......
...@@ -62,7 +62,9 @@ def im_detect_all(model, im, box_proposals, timers=None): ...@@ -62,7 +62,9 @@ def im_detect_all(model, im, box_proposals, timers=None):
if cfg.TEST.BBOX_AUG.ENABLED: if cfg.TEST.BBOX_AUG.ENABLED:
scores, boxes, im_scale = im_detect_bbox_aug(model, im, box_proposals) scores, boxes, im_scale = im_detect_bbox_aug(model, im, box_proposals)
else: else:
scores, boxes, im_scale = im_detect_bbox(model, im, box_proposals) scores, boxes, im_scale = im_detect_bbox(
model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes=box_proposals
)
timers['im_detect_bbox'].toc() timers['im_detect_bbox'].toc()
# score and boxes are from the whole image after score thresholding and nms # score and boxes are from the whole image after score thresholding and nms
...@@ -106,15 +108,17 @@ def im_detect_all(model, im, box_proposals, timers=None): ...@@ -106,15 +108,17 @@ def im_detect_all(model, im, box_proposals, timers=None):
return cls_boxes, cls_segms, cls_keyps return cls_boxes, cls_segms, cls_keyps
def im_conv_body_only(model, im): def im_conv_body_only(model, im, target_scale, target_max_size):
"""Runs `model.conv_body_net` on the given image `im`.""" """Runs `model.conv_body_net` on the given image `im`."""
im_blob, im_scale, _im_info = blob_utils.get_image_blob_for_inference(im) im_blob, im_scale, _im_info = blob_utils.get_image_blob(
im, target_scale, target_max_size
)
workspace.FeedBlob(core.ScopedName('data'), im_blob) workspace.FeedBlob(core.ScopedName('data'), im_blob)
workspace.RunNet(model.conv_body_net.Proto().name) workspace.RunNet(model.conv_body_net.Proto().name)
return im_scale return im_scale
def im_detect_bbox(model, im, boxes=None): def im_detect_bbox(model, im, target_scale, target_max_size, boxes=None):
"""Bounding box object detection for an image with given box proposals. """Bounding box object detection for an image with given box proposals.
Arguments: Arguments:
...@@ -130,7 +134,7 @@ def im_detect_bbox(model, im, boxes=None): ...@@ -130,7 +134,7 @@ def im_detect_bbox(model, im, boxes=None):
im_scales (list): list of image scales used in the input blob (as im_scales (list): list of image scales used in the input blob (as
returned by _get_blobs and for use with im_detect_mask, etc.) returned by _get_blobs and for use with im_detect_mask, etc.)
""" """
inputs, im_scale = _get_blobs(im, boxes) inputs, im_scale = _get_blobs(im, boxes, target_scale, target_max_size)
# When mapping from image ROIs to feature map ROIs, there's some aliasing # When mapping from image ROIs to feature map ROIs, there's some aliasing
# (some distinct image ROIs get mapped to the same feature ROI). # (some distinct image ROIs get mapped to the same feature ROI).
...@@ -217,7 +221,11 @@ def im_detect_bbox_aug(model, im, box_proposals=None): ...@@ -217,7 +221,11 @@ def im_detect_bbox_aug(model, im, box_proposals=None):
# Perform detection on the horizontally flipped image # Perform detection on the horizontally flipped image
if cfg.TEST.BBOX_AUG.H_FLIP: if cfg.TEST.BBOX_AUG.H_FLIP:
scores_hf, boxes_hf, _ = im_detect_bbox_hflip( scores_hf, boxes_hf, _ = im_detect_bbox_hflip(
model, im, box_proposals model,
im,
cfg.TEST.SCALE,
cfg.TEST.MAX_SIZE,
box_proposals=box_proposals
) )
add_preds_t(scores_hf, boxes_hf) add_preds_t(scores_hf, boxes_hf)
...@@ -251,7 +259,9 @@ def im_detect_bbox_aug(model, im, box_proposals=None): ...@@ -251,7 +259,9 @@ def im_detect_bbox_aug(model, im, box_proposals=None):
# Compute detections for the original image (identity transform) last to # Compute detections for the original image (identity transform) last to
# ensure that the Caffe2 workspace is populated with blobs corresponding # ensure that the Caffe2 workspace is populated with blobs corresponding
# to the original image on return (postcondition of im_detect_bbox) # to the original image on return (postcondition of im_detect_bbox)
scores_i, boxes_i, im_scale_i = im_detect_bbox(model, im, box_proposals) scores_i, boxes_i, im_scale_i = im_detect_bbox(
model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes=box_proposals
)
add_preds_t(scores_i, boxes_i) add_preds_t(scores_i, boxes_i)
# Combine the predicted scores # Combine the predicted scores
...@@ -281,7 +291,9 @@ def im_detect_bbox_aug(model, im, box_proposals=None): ...@@ -281,7 +291,9 @@ def im_detect_bbox_aug(model, im, box_proposals=None):
return scores_c, boxes_c, im_scale_i return scores_c, boxes_c, im_scale_i
def im_detect_bbox_hflip(model, im, box_proposals=None): def im_detect_bbox_hflip(
model, im, target_scale, target_max_size, box_proposals=None
):
"""Performs bbox detection on the horizontally flipped image. """Performs bbox detection on the horizontally flipped image.
Function signature is the same as for im_detect_bbox. Function signature is the same as for im_detect_bbox.
""" """
...@@ -295,7 +307,7 @@ def im_detect_bbox_hflip(model, im, box_proposals=None): ...@@ -295,7 +307,7 @@ def im_detect_bbox_hflip(model, im, box_proposals=None):
box_proposals_hf = None box_proposals_hf = None
scores_hf, boxes_hf, im_scale = im_detect_bbox( scores_hf, boxes_hf, im_scale = im_detect_bbox(
model, im_hf, box_proposals_hf model, im_hf, target_scale, target_max_size, boxes=box_proposals_hf
) )
# Invert the detections computed on the flipped image # Invert the detections computed on the flipped image
...@@ -305,30 +317,19 @@ def im_detect_bbox_hflip(model, im, box_proposals=None): ...@@ -305,30 +317,19 @@ def im_detect_bbox_hflip(model, im, box_proposals=None):
def im_detect_bbox_scale( def im_detect_bbox_scale(
model, im, scale, max_size, box_proposals=None, hflip=False model, im, target_scale, target_max_size, box_proposals=None, hflip=False
): ):
"""Computes bbox detections at the given scale. """Computes bbox detections at the given scale.
Returns predictions in the original image space. Returns predictions in the original image space.
""" """
# Remember the original scale
orig_scale = cfg.TEST.SCALE
orig_max_size = cfg.TEST.MAX_SIZE
# Perform detection at the given scale
cfg.TEST.SCALE = scale
cfg.TEST.MAX_SIZE = max_size
if hflip: if hflip:
scores_scl, boxes_scl, _ = im_detect_bbox_hflip( scores_scl, boxes_scl, _ = im_detect_bbox_hflip(
model, im, box_proposals model, im, target_scale, target_max_size, box_proposals=box_proposals
) )
else: else:
scores_scl, boxes_scl, _ = im_detect_bbox(model, im, box_proposals) scores_scl, boxes_scl, _ = im_detect_bbox(
model, im, target_scale, target_max_size, boxes=box_proposals
# Restore the original scale )
cfg.TEST.SCALE = orig_scale
cfg.TEST.MAX_SIZE = orig_max_size
return scores_scl, boxes_scl return scores_scl, boxes_scl
...@@ -348,10 +349,20 @@ def im_detect_bbox_aspect_ratio( ...@@ -348,10 +349,20 @@ def im_detect_bbox_aspect_ratio(
if hflip: if hflip:
scores_ar, boxes_ar, _ = im_detect_bbox_hflip( scores_ar, boxes_ar, _ = im_detect_bbox_hflip(
model, im_ar, box_proposals_ar model,
im_ar,
cfg.TEST.SCALE,
cfg.TEST.MAX_SIZE,
box_proposals=box_proposals_ar
) )
else: else:
scores_ar, boxes_ar, _ = im_detect_bbox(model, im_ar, box_proposals_ar) scores_ar, boxes_ar, _ = im_detect_bbox(
model,
im_ar,
cfg.TEST.SCALE,
cfg.TEST.MAX_SIZE,
boxes=box_proposals_ar
)
# Invert the detected boxes # Invert the detected boxes
boxes_inv = box_utils.aspect_ratio(boxes_ar, 1.0 / aspect_ratio) boxes_inv = box_utils.aspect_ratio(boxes_ar, 1.0 / aspect_ratio)
...@@ -420,13 +431,15 @@ def im_detect_mask_aug(model, im, boxes): ...@@ -420,13 +431,15 @@ def im_detect_mask_aug(model, im, boxes):
masks_ts = [] masks_ts = []
# Compute masks for the original image (identity transform) # Compute masks for the original image (identity transform)
im_scale_i = im_conv_body_only(model, im) im_scale_i = im_conv_body_only(model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
masks_i = im_detect_mask(model, im_scale_i, boxes) masks_i = im_detect_mask(model, im_scale_i, boxes)
masks_ts.append(masks_i) masks_ts.append(masks_i)
# Perform mask detection on the horizontally flipped image # Perform mask detection on the horizontally flipped image
if cfg.TEST.MASK_AUG.H_FLIP: if cfg.TEST.MASK_AUG.H_FLIP:
masks_hf = im_detect_mask_hflip(model, im, boxes) masks_hf = im_detect_mask_hflip(
model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes
)
masks_ts.append(masks_hf) masks_ts.append(masks_hf)
# Compute detections at different scales # Compute detections at different scales
...@@ -473,7 +486,7 @@ def im_detect_mask_aug(model, im, boxes): ...@@ -473,7 +486,7 @@ def im_detect_mask_aug(model, im, boxes):
return masks_c return masks_c
def im_detect_mask_hflip(model, im, boxes): def im_detect_mask_hflip(model, im, target_scale, target_max_size, boxes):
"""Performs mask detection on the horizontally flipped image. """Performs mask detection on the horizontally flipped image.
Function signature is the same as for im_detect_mask_aug. Function signature is the same as for im_detect_mask_aug.
""" """
...@@ -481,7 +494,7 @@ def im_detect_mask_hflip(model, im, boxes): ...@@ -481,7 +494,7 @@ def im_detect_mask_hflip(model, im, boxes):
im_hf = im[:, ::-1, :] im_hf = im[:, ::-1, :]
boxes_hf = box_utils.flip_boxes(boxes, im.shape[1]) boxes_hf = box_utils.flip_boxes(boxes, im.shape[1])
im_scale = im_conv_body_only(model, im_hf) im_scale = im_conv_body_only(model, im_hf, target_scale, target_max_size)
masks_hf = im_detect_mask(model, im_scale, boxes_hf) masks_hf = im_detect_mask(model, im_scale, boxes_hf)
# Invert the predicted soft masks # Invert the predicted soft masks
...@@ -490,27 +503,17 @@ def im_detect_mask_hflip(model, im, boxes): ...@@ -490,27 +503,17 @@ def im_detect_mask_hflip(model, im, boxes):
return masks_inv return masks_inv
def im_detect_mask_scale(model, im, scale, max_size, boxes, hflip=False): def im_detect_mask_scale(
model, im, target_scale, target_max_size, boxes, hflip=False
):
"""Computes masks at the given scale.""" """Computes masks at the given scale."""
# Remember the original scale
orig_scale = cfg.TEST.SCALE
orig_max_size = cfg.TEST.MAX_SIZE
# Perform mask detection at the given scale
cfg.TEST.SCALE = scale
cfg.TEST.MAX_SIZE = max_size
if hflip: if hflip:
masks_scl = im_detect_mask_hflip(model, im, boxes) masks_scl = im_detect_mask_hflip(
model, im, target_scale, target_max_size, boxes
)
else: else:
im_scale = im_conv_body_only(model, im) im_scale = im_conv_body_only(model, im, target_scale, target_max_size)
masks_scl = im_detect_mask(model, im_scale, boxes) masks_scl = im_detect_mask(model, im_scale, boxes)
# Restore the original scale
cfg.TEST.SCALE = orig_scale
cfg.TEST.MAX_SIZE = orig_max_size
return masks_scl return masks_scl
...@@ -522,9 +525,13 @@ def im_detect_mask_aspect_ratio(model, im, aspect_ratio, boxes, hflip=False): ...@@ -522,9 +525,13 @@ def im_detect_mask_aspect_ratio(model, im, aspect_ratio, boxes, hflip=False):
boxes_ar = box_utils.aspect_ratio(boxes, aspect_ratio) boxes_ar = box_utils.aspect_ratio(boxes, aspect_ratio)
if hflip: if hflip:
masks_ar = im_detect_mask_hflip(model, im_ar, boxes_ar) masks_ar = im_detect_mask_hflip(
model, im_ar, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes_ar
)
else: else:
im_scale = im_conv_body_only(model, im_ar) im_scale = im_conv_body_only(
model, im_ar, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE
)
masks_ar = im_detect_mask(model, im_scale, boxes_ar) masks_ar = im_detect_mask(model, im_scale, boxes_ar)
return masks_ar return masks_ar
...@@ -595,13 +602,15 @@ def im_detect_keypoints_aug(model, im, boxes): ...@@ -595,13 +602,15 @@ def im_detect_keypoints_aug(model, im, boxes):
us_ts.append(us_t) us_ts.append(us_t)
# Compute the heatmaps for the original image (identity transform) # Compute the heatmaps for the original image (identity transform)
im_scale = im_conv_body_only(model, im) im_scale = im_conv_body_only(model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
heatmaps_i = im_detect_keypoints(model, im_scale, boxes) heatmaps_i = im_detect_keypoints(model, im_scale, boxes)
add_heatmaps_t(heatmaps_i) add_heatmaps_t(heatmaps_i)
# Perform keypoints detection on the horizontally flipped image # Perform keypoints detection on the horizontally flipped image
if cfg.TEST.KPS_AUG.H_FLIP: if cfg.TEST.KPS_AUG.H_FLIP:
heatmaps_hf = im_detect_keypoints_hflip(model, im, boxes) heatmaps_hf = im_detect_keypoints_hflip(
model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes
)
add_heatmaps_t(heatmaps_hf) add_heatmaps_t(heatmaps_hf)
# Compute detections at different scales # Compute detections at different scales
...@@ -656,7 +665,7 @@ def im_detect_keypoints_aug(model, im, boxes): ...@@ -656,7 +665,7 @@ def im_detect_keypoints_aug(model, im, boxes):
return heatmaps_c return heatmaps_c
def im_detect_keypoints_hflip(model, im, boxes): def im_detect_keypoints_hflip(model, im, target_scale, target_max_size, boxes):
"""Computes keypoint predictions on the horizontally flipped image. """Computes keypoint predictions on the horizontally flipped image.
Function signature is the same as for im_detect_keypoints_aug. Function signature is the same as for im_detect_keypoints_aug.
""" """
...@@ -664,7 +673,7 @@ def im_detect_keypoints_hflip(model, im, boxes): ...@@ -664,7 +673,7 @@ def im_detect_keypoints_hflip(model, im, boxes):
im_hf = im[:, ::-1, :] im_hf = im[:, ::-1, :]
boxes_hf = box_utils.flip_boxes(boxes, im.shape[1]) boxes_hf = box_utils.flip_boxes(boxes, im.shape[1])
im_scale = im_conv_body_only(model, im_hf) im_scale = im_conv_body_only(model, im_hf, target_scale, target_max_size)
heatmaps_hf = im_detect_keypoints(model, im_scale, boxes_hf) heatmaps_hf = im_detect_keypoints(model, im_scale, boxes_hf)
# Invert the predicted keypoints # Invert the predicted keypoints
...@@ -673,27 +682,17 @@ def im_detect_keypoints_hflip(model, im, boxes): ...@@ -673,27 +682,17 @@ def im_detect_keypoints_hflip(model, im, boxes):
return heatmaps_inv return heatmaps_inv
def im_detect_keypoints_scale(model, im, scale, max_size, boxes, hflip=False): def im_detect_keypoints_scale(
model, im, target_scale, target_max_size, boxes, hflip=False
):
"""Computes keypoint predictions at the given scale.""" """Computes keypoint predictions at the given scale."""
# Store the original scale
orig_scale = cfg.TEST.SCALE
orig_max_size = cfg.TEST.MAX_SIZE
# Perform detection at the given scale
cfg.TEST.SCALE = scale
cfg.TEST.MAX_SIZE = max_size
if hflip: if hflip:
heatmaps_scl = im_detect_keypoints_hflip(model, im, boxes) heatmaps_scl = im_detect_keypoints_hflip(
model, im, target_scale, target_max_size, boxes
)
else: else:
im_scale = im_conv_body_only(model, im) im_scale = im_conv_body_only(model, im, target_scale, target_max_size)
heatmaps_scl = im_detect_keypoints(model, im_scale, boxes) heatmaps_scl = im_detect_keypoints(model, im_scale, boxes)
# Restore the original scale
cfg.TEST.SCALE = orig_scale
cfg.TEST.MAX_SIZE = orig_max_size
return heatmaps_scl return heatmaps_scl
...@@ -707,9 +706,13 @@ def im_detect_keypoints_aspect_ratio( ...@@ -707,9 +706,13 @@ def im_detect_keypoints_aspect_ratio(
boxes_ar = box_utils.aspect_ratio(boxes, aspect_ratio) boxes_ar = box_utils.aspect_ratio(boxes, aspect_ratio)
if hflip: if hflip:
heatmaps_ar = im_detect_keypoints_hflip(model, im_ar, boxes_ar) heatmaps_ar = im_detect_keypoints_hflip(
model, im_ar, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes_ar
)
else: else:
im_scale = im_conv_body_only(model, im_ar) im_scale = im_conv_body_only(
model, im_ar, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE
)
heatmaps_ar = im_detect_keypoints(model, im_scale, boxes_ar) heatmaps_ar = im_detect_keypoints(model, im_scale, boxes_ar)
return heatmaps_ar return heatmaps_ar
...@@ -936,11 +939,11 @@ def _add_multilevel_rois_for_test(blobs, name): ...@@ -936,11 +939,11 @@ def _add_multilevel_rois_for_test(blobs, name):
) )
def _get_blobs(im, rois): def _get_blobs(im, rois, target_scale, target_max_size):
"""Convert an image and RoIs within that image into network inputs.""" """Convert an image and RoIs within that image into network inputs."""
blobs = {} blobs = {}
blobs['data'], im_scale, blobs['im_info'] = \ blobs['data'], im_scale, blobs['im_info'] = \
blob_utils.get_image_blob_for_inference(im) blob_utils.get_image_blob(im, target_scale, target_max_size)
if rois is not None: if rois is not None:
blobs['rois'] = _get_rois_blob(rois, im_scale) blobs['rois'] = _get_rois_blob(rois, im_scale)
return blobs, im_scale return blobs, im_scale
...@@ -77,7 +77,7 @@ def im_detect_bbox(model, im, timers=None): ...@@ -77,7 +77,7 @@ def im_detect_bbox(model, im, timers=None):
A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS) A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS)
inputs = {} inputs = {}
inputs['data'], im_scale, inputs['im_info'] = \ inputs['data'], im_scale, inputs['im_info'] = \
blob_utils.get_image_blob_for_inference(im) blob_utils.get_image_blob(im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)
cls_probs, box_preds = [], [] cls_probs, box_preds = [], []
for lvl in range(k_min, k_max + 1): for lvl in range(k_min, k_max + 1):
suffix = 'fpn{}'.format(lvl) suffix = 'fpn{}'.format(lvl)
......
...@@ -107,8 +107,8 @@ def _get_image_blob(roidb): ...@@ -107,8 +107,8 @@ def _get_image_blob(roidb):
im, im_scale = blob_utils.prep_im_for_blob( im, im_scale = blob_utils.prep_im_for_blob(
im, cfg.PIXEL_MEANS, target_size, cfg.TRAIN.MAX_SIZE im, cfg.PIXEL_MEANS, target_size, cfg.TRAIN.MAX_SIZE
) )
im_scales.append(im_scale[0]) im_scales.append(im_scale)
processed_ims.append(im[0]) processed_ims.append(im)
# Create a blob to hold the input images # Create a blob to hold the input images
blob = blob_utils.im_list_to_blob(processed_ims) blob = blob_utils.im_list_to_blob(processed_ims)
......
...@@ -37,8 +37,8 @@ from caffe2.proto import caffe2_pb2 ...@@ -37,8 +37,8 @@ from caffe2.proto import caffe2_pb2
from core.config import cfg from core.config import cfg
def get_image_blob_for_inference(im): def get_image_blob(im, target_scale, target_max_size):
"""Converts an image into a network input. """Convert an image into a network input.
Arguments: Arguments:
im (ndarray): a color image in BGR order im (ndarray): a color image in BGR order
...@@ -49,7 +49,7 @@ def get_image_blob_for_inference(im): ...@@ -49,7 +49,7 @@ def get_image_blob_for_inference(im):
im_info (ndarray) im_info (ndarray)
""" """
processed_im, im_scale = prep_im_for_blob( processed_im, im_scale = prep_im_for_blob(
im, cfg.PIXEL_MEANS, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE im, cfg.PIXEL_MEANS, target_scale, target_max_size
) )
blob = im_list_to_blob(processed_im) blob = im_list_to_blob(processed_im)
# NOTE: this height and width may be larger than actual scaled input image # NOTE: this height and width may be larger than actual scaled input image
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment