From 9ce72be41ba8db4629e4647624998f2345c37f91 Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Fri, 8 Jul 2022 18:35:44 +0800 Subject: [PATCH] Add --outputs_as_nchw option to transpose output to from nhwc to nchw (#1979) * add output_as_nchw Signed-off-by: Deyu Huang * fix node replace logic Signed-off-by: Deyu Huang * add tests for outputs as nchw Signed-off-by: Deyu Huang * add it into function and doc Signed-off-by: Deyu Huang * fix output_names_with_port range Signed-off-by: Deyu Huang * fix the input_as_nchw description Signed-off-by: Deyu Huang * change tests name Signed-off-by: Deyu Huang --- README.md | 26 +++++++++++-------- tests/backend_test_base.py | 20 ++++++++++++-- tests/test_backend.py | 13 +++++++++- tf2onnx/convert.py | 44 ++++++++++++++++++++----------- tf2onnx/tfonnx.py | 53 ++++++++++++++++++++++++++++---------- 5 files changed, 114 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index ad640e21e..bd4753e6b 100644 --- a/README.md +++ b/README.md @@ -292,8 +292,8 @@ import tf2onnx model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None, custom_rewriter=None, - inputs_as_nchw=None, extra_opset=None shape_override=None, - target=None, large_model=False, output_path=None) + inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, + shape_override=None, target=None, large_model=False, output_path=None) Args: model: the tf.keras model we want to convert @@ -307,7 +307,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -323,8 +324,8 @@ import tf2onnx model_proto, external_tensor_storage = tf2onnx.convert.from_function(function, input_signature=None, opset=None, custom_ops=None, - custom_op_handlers=None, custom_rewriter=None, - inputs_as_nchw=None, extra_opset=None, shape_override=None, + custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, + outputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, large_model=False, output_path=None) Args: @@ -339,7 +340,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_function(function, custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -354,7 +356,7 @@ import tf2onnx model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def, name=None, input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None, custom_rewriter=None, - inputs_as_nchw=None, extra_opset=None, + inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, large_model=False, output_path=None) @@ -369,7 +371,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def, custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -383,8 +386,8 @@ import tf2onnx model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path, input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None, - custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, - large_model=False, output_path=None): + custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, + shape_override=None, target=None, large_model=False, output_path=None): Args: tflite_path: the tflite model file full path @@ -396,7 +399,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path, runtime can still open the model. Type is a dictionary `{op name: domain}`. custom_op_handlers: dictionary of custom ops handlers custom_rewriter: list of custom graph rewriters - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow target: list of workarounds applied to help certain platforms diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py index f39c398b1..38cc52dcf 100644 --- a/tests/backend_test_base.py +++ b/tests/backend_test_base.py @@ -20,6 +20,7 @@ import onnx from common import get_test_config from tfjs_runner import run_tfjs +from tf2onnx import constants from tf2onnx import utils from tf2onnx.tfonnx import process_tf_graph from tf2onnx import optimizer @@ -366,6 +367,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit graph_def_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb") utils.save_protobuf(graph_def_path, graph_def) self.logger.debug("created file %s", graph_def_path) + tfl_process_args = process_args.copy() if test_tfjs: tfjs_path = self.convert_to_tfjs(graph_def_path, output_names_with_port) @@ -395,6 +397,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit g = optimizer.optimize_graph(g, catch_errors=False) actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model, use_custom_ops=use_custom_ops) + if 'outputs_as_nchw' in tfl_process_args: + for output_name in tfl_process_args['outputs_as_nchw']: + i = output_names_with_port.index(output_name) + actual[i] = np.transpose(actual[i], constants.NCHW_TO_NHWC) self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype) self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker) @@ -410,12 +416,14 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit if run_tfl_consistency_test: self.assert_results_equal(expected, tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype) - tfl_process_args = process_args.copy() if 'inputs_as_nchw' in tfl_process_args: nchw_inps_with_port = tfl_process_args['inputs_as_nchw'] tfl_process_args['inputs_as_nchw'] = [i.split(':')[0] for i in nchw_inps_with_port] input_names_without_port = [inp.split(':')[0] for inp in feed_dict.keys()] - + if 'outputs_as_nchw' in tfl_process_args: + nchw_outps_with_port = tfl_process_args['outputs_as_nchw'] + tfl_process_args['outputs_as_nchw'] = [i.split(':')[0] for i in nchw_outps_with_port] + output_names_with_port = [i.split(':')[0] for i in nchw_outps_with_port] g = process_tf_graph(None, opset=self.config.opset, input_names=input_names_without_port, output_names=tfl_outputs, @@ -427,6 +435,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()} onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite", use_custom_ops=use_custom_ops) + if 'outputs_as_nchw' in tfl_process_args: + for output_name in tfl_process_args['outputs_as_nchw']: + i = output_names_with_port.index(output_name) + onnx_tfl_res[i] = np.transpose(onnx_tfl_res[i], constants.NCHW_TO_NHWC) self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype) self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker) @@ -456,6 +468,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit g = optimizer.optimize_graph(g) onnx_tfjs_res = self.run_backend(g, None, onnx_feed_dict, large_model, postfix="_from_tfjs", use_custom_ops=use_custom_ops) + if 'outputs_as_nchw' in tfl_process_args: + for output_name in tfl_process_args['outputs_as_nchw']: + i = output_names_with_port.index(output_name) + onnx_tfjs_res[i] = np.transpose(onnx_tfjs_res[i], constants.NCHW_TO_NHWC) self.assert_results_equal(tfjs_res, onnx_tfjs_res, rtol, atol, mtol, check_value, check_shape, check_dtype=False) diff --git a/tests/test_backend.py b/tests/test_backend.py index 521978870..fe50e0591 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -712,7 +712,7 @@ def func(x): graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and check_op_count(g, "RandomUniformLike", 0))) - def test_conv2d_with_input_transpose(self): + def test_inputs_as_nchw_arg(self): x_shape = [2, 32, 32, 3] kernel_shape = [3, 3, 3, 3] x_val = make_xval(x_shape) @@ -725,6 +725,17 @@ def func(x): process_args={"inputs_as_nchw": [_INPUT]}, onnx_feed_dict={_INPUT: x_val_for_onnx}) + def test_outputs_as_nchw_arg(self): + x_shape = [2, 32, 32, 3] + kernel_shape = [3, 3, 3, 3] + x_val = make_xval(x_shape) + def func(x): + kernel = tf.constant(make_xval(kernel_shape), dtype=tf.float32, name='kernel') + conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME") + return tf.identity(conv, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, + process_args={"outputs_as_nchw": [_OUTPUT]}) + @skip_tflite("TFlite adds ops that obscure pattern") @check_tf_min_version("1.15") def test_conv1d_dilations_rewriter(self): diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index 0a1069496..32c28f0bc 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -86,11 +86,12 @@ def get_args(): # experimental parser.add_argument("--inputs-as-nchw", help="transpose inputs as from nhwc to nchw") + parser.add_argument("--outputs-as-nchw", help="transpose outputs as from nhwc to nchw") args = parser.parse_args() args.shape_override = None if args.input: - # for backward compativility + # for backward compatibility args.graphdef = args.input if args.graphdef or args.checkpoint: if not args.inputs or not args.outputs: @@ -112,6 +113,8 @@ def get_args(): args.rename_inputs = args.rename_inputs.split(",") if args.inputs_as_nchw: args.inputs_as_nchw = args.inputs_as_nchw.split(",") + if args.outputs_as_nchw: + args.outputs_as_nchw = args.outputs_as_nchw.split(",") if args.target: args.target = args.target.split(",") if args.signature_def: @@ -275,6 +278,7 @@ def main(): input_names=inputs, output_names=outputs, inputs_as_nchw=args.inputs_as_nchw, + outputs_as_nchw=args.outputs_as_nchw, large_model=args.large_model, tensors_to_rename=tensors_to_rename, ignore_default=args.ignore_default, @@ -356,8 +360,8 @@ def _is_legacy_keras_model(model): def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None, custom_rewriter=None, - inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, - large_model=False, output_path=None): + inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, + target=None, large_model=False, output_path=None): """from_keras for tf 1.15""" input_names = [t.name for t in model.inputs] output_names = [t.name for t in model.outputs] @@ -392,6 +396,7 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None, input_names=input_names, output_names=output_names, inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, large_model=large_model, tensors_to_rename=tensors_to_rename, initialized_tables=initialized_tables, @@ -401,7 +406,7 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None, def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None, - custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, + custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, large_model=False, output_path=None, optimizers=None): """Returns a ONNX model_proto for a tf.keras model. @@ -417,7 +422,8 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_ custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path optimizers: list (subset) of tf2onnx optimizers if applying all optimizers is not desired. @@ -427,7 +433,7 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_ """ if LooseVersion(tf.__version__) < "2.0": return _from_keras_tf1(model, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw, - extra_opset, shape_override, target, large_model, output_path) + outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path) old_out_names = _rename_duplicate_keras_model_names(model) from tensorflow.python.keras.saving import saving_utils as _saving_utils # pylint: disable=import-outside-toplevel @@ -500,6 +506,7 @@ def wrap_call(*args, training=False, **kwargs): input_names=input_names, output_names=output_names, inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, large_model=large_model, tensors_to_rename=tensors_to_rename, initialized_tables=initialized_tables, @@ -509,8 +516,8 @@ def wrap_call(*args, training=False, **kwargs): def from_function(function, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None, - custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, - large_model=False, output_path=None): + custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, + shape_override=None, target=None, large_model=False, output_path=None): """Returns a ONNX model_proto for a tf.function. Args: @@ -525,7 +532,8 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -564,6 +572,7 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c input_names=input_names, output_names=output_names, inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, large_model=large_model, tensors_to_rename=tensors_to_rename, initialized_tables=initialized_tables, @@ -573,8 +582,9 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c def from_graph_def(graph_def, name=None, input_names=None, output_names=None, opset=None, custom_ops=None, - custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, - shape_override=None, target=None, large_model=False, tensors_to_rename=None, output_path=None): + custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, + extra_opset=None, shape_override=None, target=None, large_model=False, + tensors_to_rename=None, output_path=None): """Returns a ONNX model_proto for a tensorflow graphdef. Args: @@ -591,7 +601,8 @@ def from_graph_def(graph_def, name=None, input_names=None, output_names=None, op custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -628,6 +639,7 @@ def from_graph_def(graph_def, name=None, input_names=None, output_names=None, op input_names=input_names, output_names=output_names, inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, large_model=large_model, tensors_to_rename=tensors_to_rename, initialized_tables=initialized_tables, @@ -637,8 +649,8 @@ def from_graph_def(graph_def, name=None, input_names=None, output_names=None, op def from_tflite(tflite_path, input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None, - custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, - large_model=False, output_path=None): + custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, + target=None, large_model=False, output_path=None): """Returns a ONNX model_proto for a tflite model file. Args: @@ -651,7 +663,8 @@ def from_tflite(tflite_path, input_names=None, output_names=None, opset=None, cu runtime can still open the model. Type is a dictionary `{op name: domain}`. custom_op_handlers: dictionary of custom ops handlers custom_rewriter: list of custom graph rewriters - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow target: list of workarounds applied to help certain platforms @@ -680,6 +693,7 @@ def from_tflite(tflite_path, input_names=None, output_names=None, opset=None, cu input_names=input_names, output_names=output_names, inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, large_model=large_model, tensors_to_rename=None, initialized_tables=None, diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index 1a351cfcb..c2c881e77 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -329,6 +329,29 @@ def transpose_inputs(ctx, inputs_as_nchw): ops.append(node) ctx.reset_nodes(ops) +def transpose_outputs(ctx, outputs_as_nchw): + """Insert a transpose from NHWC to NCHW on model output on users request.""" + ops = [] + for node in ctx.get_nodes(): + for output_name in node.output: + if output_name in outputs_as_nchw: + shape = ctx.get_shape(output_name) + if len(shape) != len(constants.NHWC_TO_NCHW): + logger.warning("transpose_output for %s: shape must be rank 4, ignored" % output_name) + ops.append(node) + continue + # insert transpose + op_name = utils.make_name(node.name) + transpose = ctx.insert_new_node_on_output("Transpose", node.input[0], name=op_name) + transpose.set_attr("perm", constants.NHWC_TO_NCHW) + ctx.copy_shape(node.output[0], transpose.output[0]) + ctx.set_shape(transpose.output[0], np.array(shape)[constants.NHWC_TO_NCHW]) + ctx.set_shape(output_name, np.array(shape)[constants.NHWC_TO_NCHW]) + ops.append(transpose) + ops.append(node) + continue + ops.append(node) + ctx.reset_nodes(ops) def topological_sort(g, continue_on_error): ops = g.get_nodes() @@ -376,7 +399,7 @@ def run_rewriters(g, funcs, continue_on_error): def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None, opset=None, custom_op_handlers=None, custom_rewriter=None, - extra_opset=None, shape_override=None, inputs_as_nchw=None, + extra_opset=None, shape_override=None, inputs_as_nchw=None, outputs_as_nchw=None, input_names=None, output_names=None, ignore_default=None, use_default=None, is_subgraph=False, const_node_values=None, tensors_to_rename=None, initialized_tables=None, tflite_path=None, dequantize=False, tfjs_path=None): @@ -391,7 +414,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw input_names: list of input node names in graph, input name format as node_name:port_id. Optional. output_names: list of output node names in graph, format is node_name:port_id. Optional for tflite. ignore_default: list of node names of PlaceholderWithDefault ops to change into Placeholder ops @@ -421,6 +445,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No clear_functions() if inputs_as_nchw is None: inputs_as_nchw = [] + if outputs_as_nchw is None: + outputs_as_nchw = [] is_tflite = False if tflite_path is not None: @@ -435,8 +461,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No for g in [main_g] + subgraphs: g.set_config(target, opset, extra_opset) - g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, - initialized_tables, tensors_to_rename, is_tflite, dequantize) + g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, + custom_rewriter, initialized_tables, tensors_to_rename, is_tflite, dequantize) return g @@ -476,24 +502,23 @@ def graphs_from_tf(tf_graph, input_names, output_names, shape_override=None, con return main_g, subgraphs -def process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, - initialized_tables, tensors_to_rename, is_tflite=False, dequantize=False): - +def process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, + custom_rewriter, initialized_tables, tensors_to_rename, is_tflite=False, dequantize=False): if tensors_to_rename is not None: main_g.rename_tensors(tensors_to_rename) inputs_as_nchw = [tensors_to_rename.get(t, t) for t in inputs_as_nchw] + outputs_as_nchw = [tensors_to_rename.get(t, t) for t in outputs_as_nchw] for g in subgraphs: - fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, - initialized_tables, is_tflite, dequantize) + fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, + custom_rewriter, initialized_tables, is_tflite, dequantize) set_function(fg.graph_name, fg) - g = process_parsed_graph(main_g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, - initialized_tables, is_tflite, - dequantize) + g = process_parsed_graph(main_g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, + custom_rewriter, initialized_tables, is_tflite, dequantize) return g -def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, +def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, custom_rewriter, initialized_tables, is_tflite=False, dequantize=False): op_cnt, attr_cnt = g.dump_node_statistics(include_attrs=True, include_subgraphs=False) @@ -549,6 +574,8 @@ def compat_handler(ctx, node, **kwargs): if inputs_as_nchw: transpose_inputs(g, inputs_as_nchw) + if outputs_as_nchw: + transpose_outputs(g, outputs_as_nchw) # pre-processing graph rewrites # bi-directional re-writer should be placed after single directional re-writer