Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][TENSORFLOW] Add support for unpack with dim 0 after tensorlist stack #8558

Merged
merged 2 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions python/tvm/relay/frontend/tensorflow2_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,21 @@ def _impl(inputs, attr, params, prelude):
stack_func = prelude.get_global_var("tensor_array_stack", dtype_str)
out = stack_func(inputs[0])
else:
static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape)
if "num_elements" in attr:
num_elements = attr["num_elements"]
static_tensor_array_ops = StaticTensorArrayOps(
prelude, dtype_str, input_ta_shape, num_elements
)
static_tensor_array_ops.register()
stack_func = prelude.get_global_var_static(
"tensor_array_stack", dtype_str, input_ta_shape
"tensor_array_stack", dtype_str, input_ta_shape, num_elements
)
out_tensor = stack_func(inputs[0])
out_shape = (Any(),) + input_ta_shape
out_shape = (
(num_elements,) + input_ta_shape
if num_elements and num_elements == 1
else (Any(),) + input_ta_shape
)
static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape)
static_tensor_array_ops.register()
get_data_func = prelude.get_global_var_static("tensor_get_data", dtype_str, out_shape)
Expand Down
80 changes: 59 additions & 21 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,33 @@ def get_tensor_array_shape(expr, dtype, prelude):
return None


def _get_name_static(canonical, dtype, shape):
"""Get name for static shape tensor array op corresponding
to the canonical name"""
def _get_name_static(canonical, dtype, shape, batch_dim=None):
"""Get name for static shape tensor array op

By design, static ADT tensor in TVM has type name in the format
of static_tensor_dim0_dim1_..._dimN_t
or static_tensor_batch1_dim0_dim1_..._dimN_t if tensorlist stack only have one item.

Parameters
----------
canonical : String
Tensor array op name

dtype : str
Data type.

shape : tuple of (int, Any) or None
Tensor array shape

batch_dim: None or int
1 if tensorlist stack only have one item.
None by default

Returns
-------
name : String
The tensor array op name
"""
dim_names = []
for dim in shape:
if isinstance(dim, Any):
Expand All @@ -89,26 +113,31 @@ def _get_name_static(canonical, dtype, shape):
shape_str = "scalar"
if canonical == "tensor_t":
return "static_tensor_{}_{}_t".format(dtype, shape_str)
return "{}_{}_{}".format(canonical, dtype, shape_str)
if batch_dim is None or canonical in ["tensor_constructor", "tensor_nil"]:
return "{}_{}_{}".format(canonical, dtype, shape_str)
if batch_dim != 1:
return "{}_{}_{}".format(canonical, dtype, shape_str)
return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str)


class StaticTensorArrayOps(object):
"""Contains tensor array related ops for fixed rank tensor array"""

def __init__(self, prelude, dtype, shape):
def __init__(self, prelude, dtype, shape, batch_dim=None):
"""Create tensor array ops registry"""
self.prelude = prelude
self.dtype = dtype
self.shape = shape
self.batch_dim = batch_dim
self.list, self.cons, self.nil = self.prelude.mod.get_type("List")

def get_name(self, canonical):
"""Get name corresponding to the canonical name"""
return _get_name_static(canonical, self.dtype, self.shape)
return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim)

def get_global_var(self, canonical):
"""Get global corresponding to the canonical name"""
return self.prelude.get_global_var_static(canonical, self.dtype, self.shape)
return self.prelude.get_global_var_static(canonical, self.dtype, self.shape, self.batch_dim)

def get_type(self, canonical):
"""Get type corresponding to the canonical name"""
Expand Down Expand Up @@ -262,9 +291,10 @@ def define_tensor_expand_dims(self):

# Note: we set the added axis to be Any() instead of 1 due to
# in stack op, we need to recursively concatenate.
new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape(
[
Any(),
new_axis,
]
+ list(self.shape)
)
Expand Down Expand Up @@ -573,20 +603,27 @@ def define_tensor_array_stack(self):
expand_dims_var = self.get_global_var("tensor_expand_dims")

# Register tensor_concatenate for output_shape
new_axis = Any() if not self.batch_dim or self.batch_dim != 1 else self.batch_dim
output_shape = [
Any(),
new_axis,
] + list(self.shape)

_, _, output_ops = self._get_adt_by_shape(output_shape)
output_ops.define_tensor_concatenate()
concat_var = output_ops.get_global_var("tensor_concatenate")

tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
tensors = self.prelude.foldl(
concat_var,
self.prelude.hd(tensor_array_expand_dims),
self.prelude.tl(tensor_array_expand_dims),
)
if self.batch_dim is not None and self.batch_dim == 1:
# only one element
tensors = self.prelude.id(
self.prelude.hd(tensor_array_expand_dims),
)
else:
tensors = self.prelude.foldl(
concat_var,
self.prelude.hd(tensor_array_expand_dims),
self.prelude.tl(tensor_array_expand_dims),
)

output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape)
self.prelude.mod[stack_var] = Function(
[tensor_array], tensors, output_tensor_type_var(), []
Expand All @@ -599,8 +636,9 @@ def define_tensor_array_gather(self):
helper_name = self.get_name("tensor_array_gather_helper")
helper_var = self._create_global_var(helper_name)

new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
output_shape = [
Any(),
new_axis,
] + list(self.shape)
output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape)
stack_var = self.get_global_var("tensor_array_stack")
Expand Down Expand Up @@ -668,7 +706,7 @@ def register(self):

def _get_adt_by_shape(self, shape):
"""Get ADT type and constructor with given shape."""
adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape)
adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape, self.batch_dim)
adt_ops.define_tensor_adt()
tensor_type_var = adt_ops.get_type("tensor_t")
tensor_constructor = adt_ops.get_ctor("tensor_constructor")
Expand Down Expand Up @@ -1482,13 +1520,13 @@ def get_tensor_ctor(self, canonical, dtype):
ty = self.get_type("tensor_t", dtype)
return self.get_ctor(ty.name_hint, canonical, dtype)

def get_name_static(self, canonical, dtype, shape):
def get_name_static(self, canonical, dtype, shape, batch_dim=None):
"""Get name corresponding to the canonical name"""
return _get_name_static(canonical, dtype, shape)
return _get_name_static(canonical, dtype, shape, batch_dim)

def get_global_var_static(self, canonical, dtype, shape):
def get_global_var_static(self, canonical, dtype, shape, batch_dim=None):
"""Get var corresponding to the canonical name"""
name = self.get_name_static(canonical, dtype, shape)
name = self.get_name_static(canonical, dtype, shape, batch_dim)
return self.mod.get_global_var(name)

def get_type_static(self, canonical, dtype, shape):
Expand Down
61 changes: 31 additions & 30 deletions tests/python/frontend/tensorflow2/test_functional_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,6 @@ def get_input(self):
in_tens[1] = np.zeros((3,), dtype="float32")
return in_tens

"""2D array as input"""

@tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)])
def func(self, x):
dtype = tf.float32
Expand Down Expand Up @@ -513,8 +511,6 @@ def get_input(self):
in_tens[1, :, :] = np.zeros((3, 4), dtype="float32")
return in_tens

"""2D array as input"""

@tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)])
def func(self, x):
dtype = tf.float32
Expand All @@ -531,18 +527,8 @@ def func(self, x):
run_model_graph(TensorList2D)
run_func_graph(TensorList2D, runtime="vm")

run_test(
(
3,
4,
)
)
run_test(
(
-1,
-1,
)
)
run_test((3, 4))
run_test((-1, -1))


def test_tensorlist_stack_2d():
Expand All @@ -553,8 +539,6 @@ def get_input(self):
in_tens[1, :, :] = np.zeros((3, 4), dtype="float32")
return in_tens

"""2D array as input"""

@tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)])
def func(self, x):
dtype = tf.float32
Expand All @@ -570,18 +554,35 @@ def func(self, x):
run_model_graph(TensorListStack2D)
run_func_graph(TensorListStack2D, runtime="vm")

run_test(
(
3,
4,
)
)
run_test(
(
-1,
-1,
)
)
run_test((3, 4))
run_test((-1, -1))


def test_tensorlist_stack_unpack():
def run_test(elem_shape):
class TensorListStack2D(tf.Module):
def get_input(self):
in_tens = np.ones((1, 3, 4), dtype="float32")
return in_tens

@tf.function(input_signature=[tf.TensorSpec(shape=(1, 3, 4), dtype=tf.float32)])
def func(self, x):
dtype = tf.float32
tl = tf.raw_ops.TensorListReserve(
element_shape=elem_shape, num_elements=1, element_dtype=dtype
)
tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :])
output = tf.raw_ops.TensorListStack(
input_handle=tl, element_shape=elem_shape, element_dtype=dtype, num_elements=1
)
output = tf.raw_ops.Unpack(value=output, num=1, axis=0)
return output

run_model_graph(TensorListStack2D)
run_func_graph(TensorListStack2D, runtime="vm")

run_test((3, 4))
run_test((-1, -1))


if __name__ == "__main__":
Expand Down