Add --outputs_as_nchw option to transpose outputs from nhwc to nchw (#1979)

* add output_as_nchw

Signed-off-by: Deyu Huang <[email protected]>

* fix node replace logic

Signed-off-by: Deyu Huang <[email protected]>

* add tests for outputs as nchw

Signed-off-by: Deyu Huang <[email protected]>

* add it into function and doc

Signed-off-by: Deyu Huang <[email protected]>

* fix output_names_with_port range

Signed-off-by: Deyu Huang <[email protected]>

* fix the input_as_nchw description

Signed-off-by: Deyu Huang <[email protected]>

* change tests name

Signed-off-by: Deyu Huang <[email protected]>
hwangdeyu authored Jul 8, 2022
1 parent fa0b6cf commit 9ce72be
Showing 5 changed files with 114 additions and 42 deletions.
26 changes: 15 additions & 11 deletions README.md
@@ -292,8 +292,8 @@ import tf2onnx
model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model,
input_signature=None, opset=None, custom_ops=None,
custom_op_handlers=None, custom_rewriter=None,
inputs_as_nchw=None, extra_opset=None shape_override=None,
target=None, large_model=False, output_path=None)
inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None,
shape_override=None, target=None, large_model=False, output_path=None)
Args:
model: the tf.keras model we want to convert
@@ -307,7 +307,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model,
custom_rewriter: list of custom graph rewriters
extra_opset: list of extra opset's, for example the opset's used by custom ops
shape_override: dict with inputs that override the shapes given by tensorflow
inputs_as_nchw: transpose inputs in list from nchw to nhwc
inputs_as_nchw: transpose inputs in list from nhwc to nchw
outputs_as_nchw: transpose outputs in list from nhwc to nchw
large_model: use the ONNX external tensor storage format
output_path: save model to output_path
@@ -323,8 +324,8 @@ import tf2onnx
model_proto, external_tensor_storage = tf2onnx.convert.from_function(function,
input_signature=None, opset=None, custom_ops=None,
custom_op_handlers=None, custom_rewriter=None,
inputs_as_nchw=None, extra_opset=None, shape_override=None,
custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None,
outputs_as_nchw=None, extra_opset=None, shape_override=None,
target=None, large_model=False, output_path=None)
Args:
@@ -339,7 +340,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_function(function,
custom_rewriter: list of custom graph rewriters
extra_opset: list of extra opset's, for example the opset's used by custom ops
shape_override: dict with inputs that override the shapes given by tensorflow
inputs_as_nchw: transpose inputs in list from nchw to nhwc
inputs_as_nchw: transpose inputs in list from nhwc to nchw
outputs_as_nchw: transpose outputs in list from nhwc to nchw
large_model: use the ONNX external tensor storage format
output_path: save model to output_path
@@ -354,7 +356,7 @@ import tf2onnx
model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def,
name=None, input_names=None, output_names=None, opset=None,
custom_ops=None, custom_op_handlers=None, custom_rewriter=None,
inputs_as_nchw=None, extra_opset=None,
inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None,
shape_override=None, target=None, large_model=False,
output_path=None)
@@ -369,7 +371,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def,
custom_rewriter: list of custom graph rewriters
extra_opset: list of extra opset's, for example the opset's used by custom ops
shape_override: dict with inputs that override the shapes given by tensorflow
inputs_as_nchw: transpose inputs in list from nchw to nhwc
inputs_as_nchw: transpose inputs in list from nhwc to nchw
outputs_as_nchw: transpose outputs in list from nhwc to nchw
large_model: use the ONNX external tensor storage format
output_path: save model to output_path
@@ -383,8 +386,8 @@ import tf2onnx
model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path,
input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None,
custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None,
large_model=False, output_path=None):
custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None,
shape_override=None, target=None, large_model=False, output_path=None):
Args:
tflite_path: the tflite model file full path
@@ -396,7 +399,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path,
runtime can still open the model. Type is a dictionary `{op name: domain}`.
custom_op_handlers: dictionary of custom ops handlers
custom_rewriter: list of custom graph rewriters
inputs_as_nchw: transpose inputs in list from nchw to nhwc
inputs_as_nchw: transpose inputs in list from nhwc to nchw
outputs_as_nchw: transpose outputs in list from nhwc to nchw
extra_opset: list of extra opset's, for example the opset's used by custom ops
shape_override: dict with inputs that override the shapes given by tensorflow
target: list of workarounds applied to help certain platforms
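For reference, a minimal usage sketch of the extended Python API (not part of this commit): it converts a tiny NHWC conv graph with both inputs_as_nchw and outputs_as_nchw. The tensor names "input:0"/"output:0", the kernel values, and the output file name are assumptions chosen for the example.

import numpy as np
import tensorflow as tf
import tf2onnx

# A tiny NHWC conv graph with known tensor names ("input", "output").
@tf.function
def conv(x):
    kernel = tf.constant(np.ones([3, 3, 3, 8], dtype=np.float32))
    y = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME")
    return tf.identity(y, name="output")

spec = tf.TensorSpec([1, 32, 32, 3], tf.float32, name="input")
graph_def = conv.get_concrete_function(spec).graph.as_graph_def()

# inputs_as_nchw / outputs_as_nchw take lists of tensor names; the converter
# inserts transposes at the graph boundaries so the resulting ONNX model
# consumes and produces NCHW tensors while the TF graph itself stays NHWC.
model_proto, _ = tf2onnx.convert.from_graph_def(
    graph_def,
    input_names=["input:0"],
    output_names=["output:0"],
    inputs_as_nchw=["input:0"],
    outputs_as_nchw=["output:0"],
    output_path="conv_nchw.onnx",  # assumed output file name
)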
20 changes: 18 additions & 2 deletions tests/backend_test_base.py
@@ -20,6 +20,7 @@
import onnx
from common import get_test_config
from tfjs_runner import run_tfjs
from tf2onnx import constants
from tf2onnx import utils
from tf2onnx.tfonnx import process_tf_graph
from tf2onnx import optimizer
@@ -366,6 +367,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
graph_def_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
utils.save_protobuf(graph_def_path, graph_def)
self.logger.debug("created file %s", graph_def_path)
tfl_process_args = process_args.copy()

if test_tfjs:
tfjs_path = self.convert_to_tfjs(graph_def_path, output_names_with_port)
@@ -395,6 +397,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
g = optimizer.optimize_graph(g, catch_errors=False)
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
use_custom_ops=use_custom_ops)
if 'outputs_as_nchw' in tfl_process_args:
for output_name in tfl_process_args['outputs_as_nchw']:
i = output_names_with_port.index(output_name)
actual[i] = np.transpose(actual[i], constants.NCHW_TO_NHWC)

self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
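The comparison above hinges on the two axis permutations taken from tf2onnx.constants. A self-contained numpy sketch of the round trip; the permutation values below are what those constants are expected to hold:

import numpy as np

# Assumed values of tf2onnx.constants.NHWC_TO_NCHW / NCHW_TO_NHWC.
NHWC_TO_NCHW = [0, 3, 1, 2]
NCHW_TO_NHWC = [0, 2, 3, 1]

nhwc = np.random.rand(2, 32, 32, 3).astype(np.float32)  # TensorFlow reference layout
nchw = np.transpose(nhwc, NHWC_TO_NCHW)                  # layout the ONNX model returns with outputs_as_nchw
assert nchw.shape == (2, 3, 32, 32)
# Transposing back with NCHW_TO_NHWC restores the original tensor, which is
# what the test does to actual[i] before assert_results_equal.
assert np.array_equal(np.transpose(nchw, NCHW_TO_NHWC), nhwc)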
@@ -410,12 +416,14 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
if run_tfl_consistency_test:
self.assert_results_equal(expected, tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)

tfl_process_args = process_args.copy()
if 'inputs_as_nchw' in tfl_process_args:
nchw_inps_with_port = tfl_process_args['inputs_as_nchw']
tfl_process_args['inputs_as_nchw'] = [i.split(':')[0] for i in nchw_inps_with_port]
input_names_without_port = [inp.split(':')[0] for inp in feed_dict.keys()]

if 'outputs_as_nchw' in tfl_process_args:
nchw_outps_with_port = tfl_process_args['outputs_as_nchw']
tfl_process_args['outputs_as_nchw'] = [i.split(':')[0] for i in nchw_outps_with_port]
output_names_with_port = [i.split(':')[0] for i in nchw_outps_with_port]
g = process_tf_graph(None, opset=self.config.opset,
input_names=input_names_without_port,
output_names=tfl_outputs,
@@ -427,6 +435,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
postfix="_from_tflite", use_custom_ops=use_custom_ops)
if 'outputs_as_nchw' in tfl_process_args:
for output_name in tfl_process_args['outputs_as_nchw']:
i = output_names_with_port.index(output_name)
onnx_tfl_res[i] = np.transpose(onnx_tfl_res[i], constants.NCHW_TO_NHWC)

self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
@@ -456,6 +468,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
g = optimizer.optimize_graph(g)
onnx_tfjs_res = self.run_backend(g, None, onnx_feed_dict, large_model,
postfix="_from_tfjs", use_custom_ops=use_custom_ops)
if 'outputs_as_nchw' in tfl_process_args:
for output_name in tfl_process_args['outputs_as_nchw']:
i = output_names_with_port.index(output_name)
onnx_tfjs_res[i] = np.transpose(onnx_tfjs_res[i], constants.NCHW_TO_NHWC)

self.assert_results_equal(tfjs_res, onnx_tfjs_res, rtol, atol, mtol, check_value, check_shape,
check_dtype=False)
13 changes: 12 additions & 1 deletion tests/test_backend.py
@@ -712,7 +712,7 @@ def func(x):
graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and
check_op_count(g, "RandomUniformLike", 0)))

def test_conv2d_with_input_transpose(self):
def test_inputs_as_nchw_arg(self):
x_shape = [2, 32, 32, 3]
kernel_shape = [3, 3, 3, 3]
x_val = make_xval(x_shape)
@@ -725,6 +725,17 @@ def func(x):
process_args={"inputs_as_nchw": [_INPUT]},
onnx_feed_dict={_INPUT: x_val_for_onnx})

def test_outputs_as_nchw_arg(self):
x_shape = [2, 32, 32, 3]
kernel_shape = [3, 3, 3, 3]
x_val = make_xval(x_shape)
def func(x):
kernel = tf.constant(make_xval(kernel_shape), dtype=tf.float32, name='kernel')
conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME")
return tf.identity(conv, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05,
process_args={"outputs_as_nchw": [_OUTPUT]})

@skip_tflite("TFlite adds ops that obscure pattern")
@check_tf_min_version("1.15")
def test_conv1d_dilations_rewriter(self):
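Outside the test harness, the effect of outputs_as_nchw can be spot-checked with onnxruntime. A hedged sketch, assuming the conv_nchw.onnx file produced by the from_graph_def example above and its "input:0" graph input name:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("conv_nchw.onnx", providers=["CPUExecutionProvider"])
x_nchw = np.random.rand(1, 3, 32, 32).astype(np.float32)  # inputs_as_nchw: the model now expects NCHW input
(y_nchw,) = sess.run(None, {"input:0": x_nchw})
print(y_nchw.shape)  # (1, 8, 32, 32) instead of the NHWC (1, 32, 32, 8)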
(diffs for the remaining two changed files are not shown)
