Commit 757d77c2 authored by Peizhao Zhang, committed by Facebook Github Bot

Added a script to convert pkl model to pb.

Summary:
Added a script to convert pkl model to pb.
- Supported Faster R-CNN detection model.
- Supported converting to CPU and GPU model.
- Fused AffineChannel to Conv.
- Verified model after conversion.
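
Example invocation (illustrative only; the flag names come from the new convert_pkl_to_pb.py script below, the paths are placeholders, and the .pkl weights are picked up from the config by test_engine.initialize_model_from_cfg()):

  python convert_pkl_to_pb.py --cfg faster_rcnn.yaml --device cpu --out_dir /tmp/detectron_pb --test_img test.jpg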

Reviewed By: rbgirshick

Differential Revision: D6783633

fbshipit-source-id: 19706d3074a4a784a2161a695d7c534e014ebb3f
parent dd6c6615
@@ -128,9 +128,10 @@ class DetectionModelHelper(cnn.CNNModelHelper):
             (extracted from rpn_cls_probs; see above).
         """
         name = 'GenerateProposalsOp:' + ','.join([str(b) for b in blobs_in])
+        # spatial_scale passed to the Python op is only used in convert_pkl_to_pb
         self.net.Python(
             GenerateProposalsOp(anchors, spatial_scale, self.train).forward
-        )(blobs_in, blobs_out, name=name)
+        )(blobs_in, blobs_out, name=name, spatial_scale=spatial_scale)
         return blobs_out

     def GenerateProposalLabels(self, blobs_in):
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
'''Helper functions for model conversion to pb'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from functools import wraps
import copy
import numpy as np
from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2
class OpFilter(object):
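    ''' Matches a Caffe2 OperatorDef against optional criteria (op type,
        exact inputs/outputs, a required input/output name, or a custom
        condition); set reverse=True to invert the result of check(). '''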
def __init__(self, **kwargs):
self.type = None
self.type_in = None
self.inputs = None
self.outputs = None
self.input_has = None
self.output_has = None
self.cond = None
self.reverse = False
assert all([x in self.__dict__ for x in kwargs])
self.__dict__.update(kwargs)
def check(self, op):
ret = self.reverse
if self.type and op.type != self.type:
return ret
if self.type_in and op.type not in self.type_in:
return ret
if self.inputs and set(op.input) != set(self.inputs):
return ret
if self.outputs and set(op.output) != set(self.outputs):
return ret
if self.input_has and self.input_has not in op.input:
return ret
if self.output_has and self.output_has not in op.output:
return ret
if self.cond is not None and not self.cond:
return ret
return not ret
def filter_op(op, **kwargs):
''' Returns true if passed all checks '''
return OpFilter(**kwargs).check(op)
def op_filter(**filter_args):
    ''' Decorator: call the wrapped function only if the op passes the
        filter, otherwise return None '''
def actual_decorator(f):
@wraps(f)
def wrapper(op, **params):
if not filter_op(op, **filter_args):
return None
return f(op, **params)
return wrapper
return actual_decorator
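# Example (hypothetical converter): the wrapped function is only invoked for
# ops that pass the filter; see convert_op_in_ops() below for how its return
# value is interpreted.
#   @op_filter(type="StopGradient")
#   def remove_stop_gradient(op):
#       return []  # an empty list drops the op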
def op_func_chain(convert_func_list):
''' Run funcs one by one until func return is not None '''
assert isinstance(convert_func_list, list)
def _chain(op):
for x in convert_func_list:
ret = x(op)
if ret is not None:
return ret
return None
return _chain
def convert_op_in_ops(ops_ref, func_or_list):
func = func_or_list
if isinstance(func_or_list, list):
func = op_func_chain(func_or_list)
ops = [op for op in ops_ref]
converted_ops = []
for op in ops:
new_ops = func(op)
if new_ops is not None and not isinstance(new_ops, list):
new_ops = [new_ops]
converted_ops.extend(new_ops if new_ops is not None else [op])
del ops_ref[:]
    # ops_ref may be of type RepeatedCompositeFieldContainer
# which does not have append()
ops_ref.extend(converted_ops)
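# Note on the converter contract used above: func(op) returns None to keep the
# op unchanged, an empty list to drop it, or one or more replacement ops.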
def convert_op_in_proto(proto, func_or_list):
convert_op_in_ops(proto.op, func_or_list)
def get_op_arg(op, arg_name):
for x in op.arg:
if x.name == arg_name:
return x
return None
def get_op_arg_valf(op, arg_name, default_val):
arg = get_op_arg(op, arg_name)
return arg.f if arg is not None else default_val
def update_mobile_engines(net):
for op in net.op:
if op.type == "Conv":
op.engine = "NNPACK"
if op.type == "ConvTranspose":
op.engine = "BLOCK"
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
from itertools import tee
a, b = tee(iterable)
next(b, None)
return zip(a, b)
def blob_uses(net, blob):
u = []
for i, op in enumerate(net.op):
if blob in op.input or blob in op.control_input:
u.append(i)
return u
def fuse_first_affine(net, params, removed_tensors):
net = copy.deepcopy(net)
params = copy.deepcopy(params)
for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
if next_.input[0] != current.output[0]:
continue
if current.type not in ("Conv", "ConvTranspose") \
or next_.type != "AffineChannel":
continue
if current.output[0] != next_.output[0] and \
len(blob_uses(net, current.output[0])) != 1:
# Can't fuse if more than one user unless AffineChannel is inplace
continue
# else, can fuse
conv = current
affine = next_
fused_conv = copy.deepcopy(conv)
fused_conv.output[0] = affine.output[0]
conv_weight = params[conv.input[1]]
conv_has_bias = len(conv.input) > 2
conv_bias = params[conv.input[2]] if conv_has_bias else 0
A = params[affine.input[1]]
B = params[affine.input[2]]
        # AffineChannel applies the per-channel affine transform
        #     X * A + B
        # where A is the scale and B is the bias (for a model converted from
        # BatchNorm, A = bn_scale / sqrt(running_var + eps) and
        # B = bias - running_mean * A).
        # This identity should hold if we have fused correctly:
        # np.testing.assert_array_equal(
        #     params[conv.output[0]] * A + B,
        #     params[affine.output[0]])
        # (A small numerical check of the fusion algebra appears after this
        # function.)
# Now, we have that the computation made is the following:
# ((X `conv` W) + b) * A + B
# Then, we can simply fuse this as follows:
# (X `conv` (W * A)) + b * A + B
# which is simply
# (X `conv` Q) + C
# where
# Q = W * A
# C = b * A + B
        # For ConvTranspose, from the view of convolutions as a
        # Toeplitz multiplication, we have W_ = W^T, so the weights
        # are laid out as (R, S, K, K) (vs (S, R, K, K) for a Conv),
        # so the weights broadcast slightly differently. Remember, our
        # AffineChannel scale A is of size (S,)
A_ = A.reshape(-1, 1, 1, 1) if conv.type == "Conv" else \
A.reshape(1, -1, 1, 1)
C = conv_bias * A + B
Q = conv_weight * A_
assert params[conv.input[1]].shape == Q.shape
params[conv.input[1]] = Q
if conv_has_bias:
assert params[conv.input[2]].shape == C.shape
params[conv.input[2]] = C
else:
# make af_bias to be bias of the conv layer
fused_conv.input.append(affine.input[2])
params[affine.input[2]] = B
new_ops = net.op[:i] + [fused_conv] + net.op[j + 1:]
del net.op[:]
if conv_has_bias:
del params[affine.input[2]]
removed_tensors.append(affine.input[2])
removed_tensors.append(affine.input[1])
del params[affine.input[1]]
net.op.extend(new_ops)
break
return net, params, removed_tensors
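# A quick numerical check of the fusion algebra above (illustrative only;
# uses a 1x1 conv so the convolution reduces to a matrix multiply):
#   X = np.random.rand(8, 3); W = np.random.rand(4, 3)
#   b = np.random.rand(4); A = np.random.rand(4); B = np.random.rand(4)
#   np.testing.assert_allclose((X.dot(W.T) + b) * A + B,
#                              X.dot((W * A[:, None]).T) + (b * A + B))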
def fuse_affine(net, params, ignore_failure):
# Run until we hit a fixed point
removed_tensors = []
while True:
(next_net, next_params, removed_tensors) = \
fuse_first_affine(net, params, removed_tensors)
if len(next_net.op) == len(net.op):
if (
any(op.type == "AffineChannel" for op in next_net.op) and
not ignore_failure
):
                raise Exception(
                    "Model contains AffineChannel op after fusion: "
                    "{}".format(next_net))
return (next_net, next_params, removed_tensors)
net, params, removed_tensors = (next_net, next_params, removed_tensors)
def fuse_net(fuse_func, net, blobs, ignore_failure=False):
is_core_net = isinstance(net, core.Net)
if is_core_net:
net = net.Proto()
net, params, removed_tensors = fuse_func(net, blobs, ignore_failure)
for rt in removed_tensors:
net.external_input.remove(rt)
if is_core_net:
net = core.Net(net)
return net, params
def fuse_net_affine(net, blobs):
return fuse_net(fuse_affine, net, blobs)
def add_tensor(net, name, blob):
    ''' Create an operator that produces the tensor 'blob' and append it to
        'net'; running the net puts the blob into the workspace.
        uint8 data is stored as an array of string with one element.
    '''
kTypeNameMapper = {
np.dtype('float32'): "GivenTensorFill",
np.dtype('int32'): "GivenTensorIntFill",
np.dtype('int64'): "GivenTensorInt64Fill",
np.dtype('uint8'): "GivenTensorStringFill",
}
shape = blob.shape
values = blob
# pass array of uint8 as a string to save storage
# storing uint8_t has a large overhead for now
if blob.dtype == np.dtype('uint8'):
shape = [1]
values = [str(blob.data)]
op = core.CreateOperator(
kTypeNameMapper[blob.dtype],
[], [name],
shape=shape,
values=values,
# arg=[
# putils.MakeArgument("shape", shape),
# putils.MakeArgument("values", values),
# ]
)
net.op.extend([op])
def gen_init_net_from_blobs(blobs, blobs_to_use=None, excluded_blobs=None):
''' Generate an initialization net based on a blob dict '''
ret = caffe2_pb2.NetDef()
if blobs_to_use is None:
blobs_to_use = {x for x in blobs}
else:
blobs_to_use = copy.deepcopy(blobs_to_use)
if excluded_blobs is not None:
blobs_to_use = [x for x in blobs_to_use if x not in excluded_blobs]
for name in blobs_to_use:
blob = blobs[name]
if isinstance(blob, str):
            print('Blob {} with type {} is not supported when generating the'
                  ' init net; skipped.'.format(name, type(blob)))
continue
add_tensor(ret, name, blob)
return ret
def get_ws_blobs(blob_names=None):
''' Get blobs in 'blob_names' in the default workspace,
get all blobs if blob_names is None '''
blobs = {}
if blob_names is None:
blob_names = workspace.Blobs()
blobs = {x: workspace.FetchBlob(x) for x in blob_names}
return blobs
def get_device_option_cpu():
device_option = core.DeviceOption(caffe2_pb2.CPU)
return device_option
def get_device_option_cuda(gpu_id=0):
device_option = caffe2_pb2.DeviceOption()
device_option.device_type = caffe2_pb2.CUDA
device_option.cuda_gpu_id = gpu_id
return device_option
def create_input_blobs_for_net(net_def):
for op in net_def.op:
for blob_in in op.input:
if not workspace.HasBlob(blob_in):
workspace.CreateBlob(blob_in)
def compare_model(model1_func, model2_func, test_image, check_blobs):
    ''' Compare two models by running them on the same test image and
        checking the blobs listed in check_blobs. Each model_func has the
        signature model_func(test_image, check_blobs) and returns a dict
        mapping blob names to values.
    '''
cb1, cb2 = check_blobs, check_blobs
if isinstance(check_blobs, dict):
cb1 = check_blobs.keys()
cb2 = check_blobs.values()
print('Running the first model...')
res1 = model1_func(test_image, check_blobs)
print('Running the second model...')
res2 = model2_func(test_image, check_blobs)
for idx in range(len(cb1)):
print('Checking {} -> {}...'.format(cb1[idx], cb2[idx]))
n1, n2 = cb1[idx], cb2[idx]
r1 = res1[n1] if n1 in res1 else None
r2 = res2[n2] if n2 in res2 else None
assert r1 is not None or r2 is None, \
"Blob {} in model1 is None".format(n1)
assert r2 is not None or r1 is None, \
"Blob {} in model2 is None".format(n2)
assert r1.shape == r2.shape, \
"Blob {} and {} shape mismatched: {} vs {}".format(
n1, n2, r1.shape, r2.shape)
np.testing.assert_array_almost_equal(
r1, r2, decimal=3,
err_msg='{} and {} not matched. Max diff: {}'.format(
n1, n2, np.amax(np.absolute(r1 - r2))))
return True
# Note: graph_name must not contain the word 'graph'
def save_graph(net, file_name, graph_name="net", op_only=True):
from caffe2.python import net_drawer
graph = None
ops = net.op
if not op_only:
graph = net_drawer.GetPydotGraph(
ops, graph_name,
rankdir="TB")
else:
graph = net_drawer.GetPydotGraphMinimal(
ops, graph_name,
rankdir="TB", minimal_dependency=True)
try:
graph.write_png(file_name)
except Exception as e:
print('Error when writing graph to image {}'.format(e))
#!/usr/bin/env python2
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Script to convert the model (.yaml and .pkl) trained by train_net to a
standard Caffe2 model in pb format (model.pb and model_init.pb). The converted
model is good for production usage, as it could run independently and efficiently
on CPU, GPU and mobile without depending on the detectron codebase.
Please see Caffe2 tutorial (
https://caffe2.ai/docs/tutorial-loading-pre-trained-models.html) for loading
the converted model, and run_model_pb() for running the model for inference.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import cv2 # NOQA (Must import before importing caffe2 due to bug in cv2)
import argparse
import copy
import pprint
import numpy as np
import os
import sys
import caffe2.python.utils as putils
from caffe2.python import core, workspace
from caffe2.proto import caffe2_pb2
from core.config import assert_and_infer_cfg
from core.config import cfg
from core.config import merge_cfg_from_file
from core.config import merge_cfg_from_list
from modeling import generate_anchors
import core.test_engine as test_engine
import utils.c2 as c2_utils
import utils.vis as vis_utils
import utils.logging
import utils.model_convert_utils as mutils
from utils.model_convert_utils import op_filter, convert_op_in_proto
c2_utils.import_contrib_ops()
c2_utils.import_detectron_ops()
logger = utils.logging.setup_logging(__name__)
def parse_args():
parser = argparse.ArgumentParser(
description='Convert a trained network to pb format'
)
parser.add_argument(
'--cfg', dest='cfg_file', help='optional config file', default=None,
type=str)
parser.add_argument(
'--net_name', dest='net_name', help='optional name for the net',
default="detectron", type=str)
parser.add_argument(
'--out_dir', dest='out_dir', help='output dir', default=None,
type=str)
parser.add_argument(
'--test_img', dest='test_img',
help='optional test image, used to verify the model conversion',
default=None,
type=str)
parser.add_argument(
'--fuse_af', dest='fuse_af', help='1 to fuse_af',
default=1,
type=int)
parser.add_argument(
'--device', dest='device',
help='Device to run the model on',
choices=['cpu', 'gpu'],
default='cpu',
type=str)
parser.add_argument(
'--net_execution_type', dest='net_execution_type',
help='caffe2 net execution type',
choices=['simple', 'dag'],
default='simple',
type=str)
parser.add_argument(
'--use_nnpack', dest='use_nnpack',
help='Use nnpack for conv',
default=1,
type=int)
parser.add_argument(
'opts', help='See lib/core/config.py for all options', default=None,
nargs=argparse.REMAINDER)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
ret = parser.parse_args()
ret.out_dir = os.path.abspath(ret.out_dir)
if ret.device == 'gpu' and ret.use_nnpack:
logger.warn('Should not use mobile engine for gpu model.')
ret.use_nnpack = 0
return ret
def unscope_name(name):
return c2_utils.UnscopeName(name)
def reset_names(names):
for i in range(0, len(names)):
names[i] = unscope_name(names[i])
def convert_gen_proposals(
op, blobs,
rpn_pre_nms_topN,
rpn_post_nms_topN,
rpn_nms_thres,
rpn_min_size,
):
print('Converting GenerateProposals Python -> C++:\n{}'.format(op))
assert op.name.startswith("GenerateProposalsOp"), "Not valid GenerateProposalsOp"
spatial_scale = mutils.get_op_arg_valf(op, "spatial_scale", None)
assert spatial_scale is not None
inputs = [x for x in op.input]
anchor_name = "anchor"
inputs.append(anchor_name)
blobs[anchor_name] = get_anchors(spatial_scale)
print('anchors {}'.format(blobs[anchor_name]))
ret = core.CreateOperator(
"GenerateProposals",
inputs,
list(op.output),
spatial_scale=spatial_scale,
pre_nms_topN=rpn_pre_nms_topN,
post_nms_topN=rpn_post_nms_topN,
nms_thres=rpn_nms_thres,
min_size=rpn_min_size,
correct_transform_coords=True,
)
return ret, anchor_name
def get_anchors(spatial_scale):
anchors = generate_anchors.generate_anchors(
stride=1. / spatial_scale,
sizes=cfg.RPN.SIZES,
aspect_ratios=cfg.RPN.ASPECT_RATIOS).astype(np.float32)
return anchors
def reset_blob_names(blobs):
ret = {unscope_name(x): blobs[x] for x in blobs}
blobs.clear()
blobs.update(ret)
def convert_net(args, net, blobs):
@op_filter()
def convert_op_name(op):
if args.device != 'gpu':
if op.engine != 'DEPTHWISE_3x3':
op.engine = ''
op.device_option.CopyFrom(caffe2_pb2.DeviceOption())
reset_names(op.input)
reset_names(op.output)
return [op]
@op_filter(type="Python", inputs=['rpn_cls_probs', 'rpn_bbox_pred', 'im_info'])
def convert_gen_proposal(op_in):
gen_proposals_op, ext_input = convert_gen_proposals(
op_in, blobs,
rpn_min_size=float(cfg.TEST.RPN_MIN_SIZE),
rpn_post_nms_topN=cfg.TEST.RPN_POST_NMS_TOP_N,
rpn_pre_nms_topN=cfg.TEST.RPN_PRE_NMS_TOP_N,
rpn_nms_thres=cfg.TEST.RPN_NMS_THRESH,
)
net.external_input.extend([ext_input])
return [gen_proposals_op]
@op_filter(input_has='rois')
def convert_rpn_rois(op):
for j in range(0, len(op.input)):
if op.input[j] == 'rois':
print('Converting op {} input name: rois -> rpn_rois:\n{}'.format(
op.type, op))
op.input[j] = 'rpn_rois'
return [op]
@op_filter(type_in=['StopGradient', 'Alias'])
def convert_remove_op(op):
print('Removing op {}:\n{}'.format(op.type, op))
return []
convert_op_in_proto(net, convert_op_name)
convert_op_in_proto(net, [
convert_gen_proposal, convert_rpn_rois, convert_remove_op
])
reset_names(net.external_input)
reset_names(net.external_output)
reset_blob_names(blobs)
def add_bbox_ops(args, net, blobs):
new_ops = []
new_external_outputs = []
# Operators for bboxes
op_box = core.CreateOperator(
"BBoxTransform",
['rpn_rois', 'bbox_pred', 'im_info'],
['pred_bbox'],
weights=cfg.MODEL.BBOX_REG_WEIGHTS,
apply_scale=False,
correct_transform_coords=True,
)
new_ops.extend([op_box])
blob_prob = 'cls_prob'
blob_box = 'pred_bbox'
op_nms = core.CreateOperator(
"BoxWithNMSLimit",
[blob_prob, blob_box],
['score_nms', 'bbox_nms', 'class_nms'],
arg=[
putils.MakeArgument("score_thresh", cfg.TEST.SCORE_THRESH),
putils.MakeArgument("nms", cfg.TEST.NMS),
putils.MakeArgument("detections_per_im", cfg.TEST.DETECTIONS_PER_IM),
putils.MakeArgument("soft_nms_enabled", cfg.TEST.SOFT_NMS.ENABLED),
putils.MakeArgument("soft_nms_method", cfg.TEST.SOFT_NMS.METHOD),
putils.MakeArgument("soft_nms_sigma", cfg.TEST.SOFT_NMS.SIGMA),
]
)
new_ops.extend([op_nms])
new_external_outputs.extend(['score_nms', 'bbox_nms', 'class_nms'])
net.Proto().op.extend(new_ops)
net.Proto().external_output.extend(new_external_outputs)
def convert_model_gpu(args, net, init_net):
assert args.device == 'gpu'
ret_net = copy.deepcopy(net)
ret_init_net = copy.deepcopy(init_net)
cdo_cuda = mutils.get_device_option_cuda()
cdo_cpu = mutils.get_device_option_cpu()
CPU_OPS = [
["GenerateProposals", None],
["BBoxTransform", None],
["BoxWithNMSLimit", None],
]
CPU_BLOBS = ["im_info", "anchor"]
@op_filter()
def convert_op_gpu(op):
for x in CPU_OPS:
if mutils.filter_op(op, type=x[0], inputs=x[1]):
return None
op.device_option.CopyFrom(cdo_cuda)
return [op]
@op_filter()
def convert_init_op_gpu(op):
if op.output[0] in CPU_BLOBS:
op.device_option.CopyFrom(cdo_cpu)
else:
op.device_option.CopyFrom(cdo_cuda)
return [op]
convert_op_in_proto(ret_init_net.Proto(), convert_init_op_gpu)
convert_op_in_proto(ret_net.Proto(), convert_op_gpu)
ret = core.InjectDeviceCopiesAmongNets([ret_init_net, ret_net])
return [ret[0][1], ret[0][0]]
def gen_init_net(net, blobs, empty_blobs):
blobs = copy.deepcopy(blobs)
for x in empty_blobs:
blobs[x] = np.array([], dtype=np.float32)
init_net = mutils.gen_init_net_from_blobs(
blobs, net.external_inputs)
init_net = core.Net(init_net)
return init_net
def _save_image_graphs(args, all_net, all_init_net):
print('Saving model graph...')
mutils.save_graph(
all_net.Proto(), os.path.join(args.out_dir, "model_def.png"),
op_only=False)
print('Model def image saved to {}.'.format(args.out_dir))
def _save_models(all_net, all_init_net, args):
print('Writing converted model to {}...'.format(args.out_dir))
fname = "model"
if not os.path.exists(args.out_dir):
os.makedirs(args.out_dir)
with open(os.path.join(args.out_dir, fname + '.pb'), 'w') as f:
f.write(all_net.Proto().SerializeToString())
with open(os.path.join(args.out_dir, fname + '.pbtxt'), 'w') as f:
f.write(str(all_net.Proto()))
with open(os.path.join(args.out_dir, fname + '_init.pb'), 'w') as f:
f.write(all_init_net.Proto().SerializeToString())
_save_image_graphs(args, all_net, all_init_net)
def load_model(args):
model = test_engine.initialize_model_from_cfg()
blobs = mutils.get_ws_blobs()
return model, blobs
def _get_result_blobs(check_blobs):
ret = {}
for x in check_blobs:
sn = core.ScopedName(x)
if workspace.HasBlob(sn):
ret[x] = workspace.FetchBlob(sn)
else:
ret[x] = None
return ret
def _sort_results(boxes, segms, keypoints, classes):
    if boxes is None:
        return boxes, segms, keypoints, classes
    # sort by detection score (the last column of boxes), descending
    indices = np.argsort(boxes[:, -1])[::-1]
    boxes = boxes[indices, :]
if segms is not None:
segms = [segms[x] for x in indices]
if keypoints is not None:
keypoints = [keypoints[x] for x in indices]
if classes is not None:
if isinstance(classes, list):
classes = [classes[x] for x in indices]
else:
classes = classes[indices]
return boxes, segms, keypoints, classes
def run_model_cfg(args, im, check_blobs):
workspace.ResetWorkspace()
model, _ = load_model(args)
with c2_utils.NamedCudaScope(0):
cls_boxes, cls_segms, cls_keyps = test_engine.im_detect_all(
model, im, None, None,
)
boxes, segms, keypoints, classes = vis_utils.convert_from_cls_format(
cls_boxes, cls_segms, cls_keyps)
    # sort the results based on score for comparison
boxes, segms, keypoints, classes = _sort_results(
boxes, segms, keypoints, classes)
# write final results back to workspace
def _ornone(res):
return np.array(res) if res is not None else np.array([], dtype=np.float32)
with c2_utils.NamedCudaScope(0):
workspace.FeedBlob(core.ScopedName('result_boxes'), _ornone(boxes))
workspace.FeedBlob(core.ScopedName('result_segms'), _ornone(segms))
workspace.FeedBlob(core.ScopedName('result_keypoints'), _ornone(keypoints))
workspace.FeedBlob(core.ScopedName('result_classids'), _ornone(classes))
# get result blobs
with c2_utils.NamedCudaScope(0):
ret = _get_result_blobs(check_blobs)
return ret
def _prepare_blobs(
im,
pixel_means,
target_size,
max_size,
):
''' Reference: blob.prep_im_for_blob() '''
im = im.astype(np.float32, copy=False)
im -= pixel_means
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
im_scale = float(target_size) / float(im_size_min)
if np.round(im_scale * im_size_max) > max_size:
im_scale = float(max_size) / float(im_size_max)
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale,
interpolation=cv2.INTER_LINEAR)
blob = np.zeros([1, im.shape[0], im.shape[1], 3], dtype=np.float32)
blob[0, :, :, :] = im
channel_swap = (0, 3, 1, 2) # swap channel to (k, c, h, w)
blob = blob.transpose(channel_swap)
blobs = {}
blobs['data'] = blob
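    # im_info holds [height, width, im_scale] of the resized image (blob is NCHW)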
blobs['im_info'] = np.array(
[[blob.shape[2], blob.shape[3], im_scale]],
dtype=np.float32
)
return blobs
def run_model_pb(args, net, init_net, im, check_blobs):
assert len(cfg.TEST.SCALES) == 1
workspace.ResetWorkspace()
workspace.RunNetOnce(init_net)
mutils.create_input_blobs_for_net(net.Proto())
workspace.CreateNet(net)
# input_blobs, _ = core_test._get_blobs(im, None)
input_blobs = _prepare_blobs(
im,
cfg.PIXEL_MEANS,
cfg.TEST.SCALES[0], cfg.TEST.MAX_SIZE
)
gpu_blobs = []
if args.device == 'gpu':
gpu_blobs = ['data']
for k, v in input_blobs.items():
workspace.FeedBlob(
core.ScopedName(k),
v,
mutils.get_device_option_cuda() if k in gpu_blobs else
mutils.get_device_option_cpu()
)
try:
workspace.RunNet(net.Proto().name)
scores = workspace.FetchBlob('score_nms')
classids = workspace.FetchBlob('class_nms')
boxes = workspace.FetchBlob('bbox_nms')
except Exception as e:
print('Running pb model failed.\n{}'.format(e))
# may not detect anything at all
R = 0
scores = np.zeros((R,), dtype=np.float32)
boxes = np.zeros((R, 4), dtype=np.float32)
classids = np.zeros((R,), dtype=np.float32)
boxes = np.column_stack((boxes, scores))
    # sort the results based on score for comparison
boxes, _, _, classids = _sort_results(
boxes, None, None, classids)
# write final result back to workspace
workspace.FeedBlob('result_boxes', boxes)
workspace.FeedBlob('result_classids', classids)
ret = _get_result_blobs(check_blobs)
return ret
def verify_model(args, model_pb, test_img_file):
check_blobs = [
"result_boxes", "result_classids", # result
]
print('Loading test file {}...'.format(test_img_file))
test_img = cv2.imread(test_img_file)
assert test_img is not None
def _run_cfg_func(im, blobs):
return run_model_cfg(args, im, check_blobs)
def _run_pb_func(im, blobs):
return run_model_pb(args, model_pb[0], model_pb[1], im, check_blobs)
print('Checking models...')
assert mutils.compare_model(
_run_cfg_func, _run_pb_func, test_img, check_blobs)
def main():
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
args = parse_args()
logger.info('Called with args:')
logger.info(args)
if args.cfg_file is not None:
merge_cfg_from_file(args.cfg_file)
if args.opts is not None:
merge_cfg_from_list(args.opts)
cfg.NUM_GPUS = 1
assert_and_infer_cfg()
    logger.info('Converting model with config:')
logger.info(pprint.pformat(cfg))
assert not cfg.MODEL.KEYPOINTS_ON, "Keypoint model not supported."
assert not cfg.MODEL.MASK_ON, "Mask model not supported."
assert not cfg.FPN.FPN_ON, "FPN not supported."
assert not cfg.RETINANET.RETINANET_ON, "RetinaNet model not supported."
# load model from cfg
model, blobs = load_model(args)
net = core.Net('')
net.Proto().op.extend(copy.deepcopy(model.net.Proto().op))
net.Proto().external_input.extend(
copy.deepcopy(model.net.Proto().external_input))
net.Proto().external_output.extend(
copy.deepcopy(model.net.Proto().external_output))
net.Proto().type = args.net_execution_type
net.Proto().num_workers = 1 if args.net_execution_type == 'simple' else 4
# Reset the device_option, change to unscope name and replace python operators
convert_net(args, net.Proto(), blobs)
# add operators for bbox
add_bbox_ops(args, net, blobs)
if args.fuse_af:
print('Fusing affine channel...')
net, blobs = mutils.fuse_net_affine(
net, blobs)
if args.use_nnpack:
mutils.update_mobile_engines(net.Proto())
# generate init net
empty_blobs = ['data', 'im_info']
init_net = gen_init_net(net, blobs, empty_blobs)
if args.device == 'gpu':
[net, init_net] = convert_model_gpu(args, net, init_net)
net.Proto().name = args.net_name
init_net.Proto().name = args.net_name + "_init"
if args.test_img is not None:
verify_model(args, [net, init_net], args.test_img)
_save_models(net, init_net, args)
if __name__ == '__main__':
main()