Improve ZerosLike implementation and optimize for opset >= 9 (#2003)
* Improve ZerosLike implementation for opset >= 9

Signed-off-by: Deyu Huang <[email protected]>
Co-authored-by: Guenther Schmuelling <[email protected]>

* add a blank line

Signed-off-by: Deyu Huang <[email protected]>

Co-authored-by: Guenther Schmuelling <[email protected]>
Co-authored-by: Jay Zhang <[email protected]>
3 people authored Jul 27, 2022
1 parent 1c7d4ce commit d72b4d1
Showing 4 changed files with 115 additions and 1 deletion.
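For context: with this change, from opset 9 onward the converter lowers tf.zeros_like to a Shape node feeding a ConstantOfShape node (see tf2onnx/onnx_opset/generator.py below), and the constant-folding optimizer additionally folds Add, Sub, and Mul nodes whose inputs are both constant. The snippet below is a minimal sketch, not part of the commit, of exercising the new lowering end to end; it assumes TensorFlow and a tf2onnx build that includes this change are installed, and the function name and shapes are illustrative only.

import tensorflow as tf
import tf2onnx

def zeros_like_fn(x):
    # Mirrors the new test case: ZerosLike on a float32 tensor.
    return tf.zeros_like(x)

spec = (tf.TensorSpec([3, 16, 16], tf.float32),)
# Convert at opset 9 so the new ZerosLike.version_9 handler is exercised.
model_proto, _ = tf2onnx.convert.from_function(tf.function(zeros_like_fn),
                                               input_signature=spec, opset=9)
op_types = [node.op_type for node in model_proto.graph.node]
print(op_types)  # expected to include "ConstantOfShape" fed by a "Shape" node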
13 changes: 13 additions & 0 deletions tests/test_backend.py
@@ -3887,6 +3887,19 @@ def func(x, y):

        self._run_test_case(func, [_OUTPUT], {_INPUT: input_x > 0.5, _INPUT1: input_y})

    @check_opset_min_version(9, "ConstantOfShape")
    def test_zeros_like_opset9(self):
        input_x = np.random.random_sample([3, 16, 16]).astype(np.float32)
        input_y = np.array([16, 16, 3]).astype(np.int64)

        def func(x, y):
            z = tf.reshape(x, y)
            return tf.zeros_like(z, name=_TFOUTPUT)

        self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y})
        self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True,
                            graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False))

    @check_opset_min_version(9, "is_nan")
    def test_isnan(self):
        # only compatible with dtype `float32`
66 changes: 66 additions & 0 deletions tests/test_optimizers.py
@@ -2241,6 +2241,72 @@ def test_const_fold_cast_with_const(self):
        self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.int64)}, model_proto,
                             "Cast", 0)

    def test_const_fold_add(self):
        shape = (6, 6)
        const_tensor1 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
        const_tensor2 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
        node1 = helper.make_node("Constant", [], ["const1"], value=const_tensor1)
        node2 = helper.make_node("Constant", [], ["const2"], value=const_tensor2)
        node3 = helper.make_node("Add", ["const1", "const2"], ["add"])
        node4 = helper.make_node("Add", ["add", "X"], ["res"])

        graph = helper.make_graph(
            [node1, node2, node3, node4],
            "test_const_fold_add",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.float32)}, model_proto,
                             "Add", 1)

    def test_const_fold_sub(self):
        shape = (6, 6)
        const_tensor1 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
        const_tensor2 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
        node1 = helper.make_node("Constant", [], ["const1"], value=const_tensor1)
        node2 = helper.make_node("Constant", [], ["const2"], value=const_tensor2)
        node3 = helper.make_node("Sub", ["const1", "const2"], ["sub"])
        node4 = helper.make_node("Sub", ["sub", "X"], ["res"])

        graph = helper.make_graph(
            [node1, node2, node3, node4],
            "test_const_fold_sub",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.float32)}, model_proto,
                             "Sub", 1)

    def test_const_fold_mul(self):
        shape = (6, 6)
        const_tensor1 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
        const_tensor2 = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
                                           vals=np.random.randn(*shape).flatten().astype(np.float32))
        node1 = helper.make_node("Constant", [], ["const1"], value=const_tensor1)
        node2 = helper.make_node("Constant", [], ["const2"], value=const_tensor2)
        node3 = helper.make_node("Mul", ["const1", "const2"], ["mul"])
        node4 = helper.make_node("Mul", ["mul", "X"], ["res"])

        graph = helper.make_graph(
            [node1, node2, node3, node4],
            "test_const_fold_mul",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.float32)}, model_proto,
                             "Mul", 1)

    def test_const_fold_split(self):
        shape = (2, 6, 1)
        const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
13 changes: 12 additions & 1 deletion tf2onnx/onnx_opset/generator.py
@@ -8,7 +8,7 @@
import logging

import numpy as np
-from onnx import onnx_pb, numpy_helper
+from onnx import onnx_pb, numpy_helper, helper
from tf2onnx import utils
from tf2onnx.handler import tf_op
from tf2onnx.graph_builder import GraphBuilder
@@ -242,6 +242,17 @@ def version_1(cls, ctx, node, **kwargs):
                      name=node.name, outputs=node.output,
                      shapes=shapes, dtypes=dtypes)

    @classmethod
    def version_9(cls, ctx, node, **kwargs):
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        shape = ctx.make_node("Shape", node.input).output[0]
        zero_tensor = helper.make_tensor("value", dtypes[0], [1], vals=[0])
        ctx.make_node("ConstantOfShape", inputs=[shape],
                      attr={'value': zero_tensor},
                      name=node.name, outputs=node.output,
                      dtypes=dtypes)


@tf_op(["IteratorV2", "FIFOQueueV2"])
class Iterator:
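To make the emitted pattern concrete, here is a standalone sketch, not part of the commit, of the graph shape the version_9 handler above produces: Shape(X) feeding ConstantOfShape whose value attribute is a single zero, which yields a zero tensor with the runtime shape of X. It assumes only the onnx package is installed; names and dimensions are illustrative.

from onnx import TensorProto, checker, helper

# Shape(X) -> ConstantOfShape(value=0) reproduces ZerosLike at runtime.
shape_node = helper.make_node("Shape", ["X"], ["x_shape"])
zero_tensor = helper.make_tensor("value", TensorProto.FLOAT, [1], vals=[0.0])
zeros_node = helper.make_node("ConstantOfShape", ["x_shape"], ["Y"], value=zero_tensor)

graph = helper.make_graph(
    [shape_node, zeros_node],
    "zeros_like_opset9_sketch",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 16, 16])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 16, 16])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 9)])
checker.check_model(model)  # the two-node pattern is a valid opset-9 model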
24 changes: 24 additions & 0 deletions tf2onnx/optimizer/const_fold_optimizer.py
@@ -162,6 +162,30 @@ def _fold_unsqueeze(node, graph):
        const_val_after_unsqueeze = const_val.reshape(shape_out)
        return [const_val_after_unsqueeze]

    @staticmethod
    @_register_func("Mul")
    def _fold_mul(node, graph):
        const_val1 = node.inputs[0].get_tensor_value(as_list=False)
        const_val2 = node.inputs[1].get_tensor_value(as_list=False)
        const_val_after_mul = np.multiply(const_val1, const_val2)
        return [const_val_after_mul]

    @staticmethod
    @_register_func("Add")
    def _fold_add(node, graph):
        const_val1 = node.inputs[0].get_tensor_value(as_list=False)
        const_val2 = node.inputs[1].get_tensor_value(as_list=False)
        const_val_after_add = np.add(const_val1, const_val2)
        return [const_val_after_add]

    @staticmethod
    @_register_func("Sub")
    def _fold_sub(node, graph):
        const_val1 = node.inputs[0].get_tensor_value(as_list=False)
        const_val2 = node.inputs[1].get_tensor_value(as_list=False)
        const_val_after_sub = np.subtract(const_val1, const_val2)
        return [const_val_after_sub]

    @staticmethod
    @_register_func("Split")
    def _fold_split(node, graph):
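The new folding rules above share one idea: when both inputs of an Add, Sub, or Mul node are constants, the result can be precomputed with numpy and the runtime op dropped, which is what the tests in tests/test_optimizers.py verify by expecting only one remaining op of each kind. Below is a small sketch of that arithmetic, not part of the commit; fold_binary_const is an illustrative helper, not the optimizer's API.

import numpy as np

# Map ONNX op types to the numpy functions used by the new folding rules.
_FOLD_FUNCS = {"Add": np.add, "Sub": np.subtract, "Mul": np.multiply}

def fold_binary_const(op_type, const_val1, const_val2):
    """Precompute the constant result of a two-constant Add/Sub/Mul node."""
    return _FOLD_FUNCS[op_type](const_val1, const_val2)

a = np.random.randn(6, 6).astype(np.float32)
b = np.random.randn(6, 6).astype(np.float32)
assert np.allclose(fold_binary_const("Sub", a, b), a - b)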
