TRT Dynamic Reshape Fix (apache#7412)
* Dynamic Reshape

* Changes

* Add test cases

* Add test cases

* PR Comments

* CI Error

* EmptyCommitCIError

Co-authored-by: Ubuntu <[email protected]>
2 people authored and Lokiiiiii committed Mar 1, 2021
1 parent 92f8e3f commit ce36222
Showing 2 changed files with 107 additions and 7 deletions.
13 changes: 6 additions & 7 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -615,7 +615,6 @@ def layout_transform_annotate_fn(expr):  # pylint: disable=unused-variable
 @_register_external_dynamic_check_func("reshape")
 def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if reshape is supported by TensorRT."""
-
     attrs, args = expr.attrs, expr.args
     if args[0].checked_type.dtype != "float32":
         logger.info("Only float32 inputs are supported for TensorRT.")
@@ -629,23 +628,23 @@ def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
     if len(new_shape) == 0 or len(shape) == 0:
         logger.info("reshape: Can't reshape to or from scalar.")
         return False

     dynamic_reshape = any([isinstance(x, tvm.tir.expr.Any) for x in shape])

     if dynamic_reshape:
         # Make sure that the batch dim is unmodified.
         if int(new_shape[0]) < 0:
-            for shape_val, new_shape_val in enumerate(shape[1:], new_shape[1:]):
+            for shape_val, new_shape_val in zip(shape[1:], new_shape[1:]):
                 if not (
-                    isinstance(shape_val, int)
-                    and isinstance(new_shape_val, int)
+                    isinstance(shape_val, (int, tvm.tir.expr.IntImm))
+                    and isinstance(new_shape_val, (int, tvm.tir.expr.IntImm))
                     and int(shape_val) == int(new_shape_val)
                 ):
                     return False
         elif int(new_shape[0]) > 0:
             # Currently we only allow dim[0] to be Any, so this branch will always be False
             if not (
-                isinstance(shape[0], int)
-                and isinstance(new_shape[0], int)
+                isinstance(shape[0], (int, tvm.tir.expr.IntImm))
+                and isinstance(new_shape[0], (int, tvm.tir.expr.IntImm))
                 and int(shape[0]) == int(new_shape[0])
             ):
                 return False
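
The two changed call sites above are easy to gloss over, so here is a minimal sketch of why they matter (plain Python plus a Relay variable; the shapes are hypothetical examples): enumerate(a, b) treats its second argument as an integer start index rather than a second sequence, and Relay shape tuples hold tvm.tir.IntImm nodes rather than Python ints, so the old isinstance(x, int) checks rejected every static dimension.

import tvm
from tvm import relay

# 1) enumerate(iterable, start) expects an integer start, so the old call pattern
#    raises a TypeError instead of pairing the two shapes; zip() pairs them element-wise.
shape, new_shape = [3, 2, 3], [3, 2, 3]  # hypothetical trailing dims
try:
    list(enumerate(shape, new_shape))  # old code path
except TypeError as err:
    print("enumerate fails:", err)
print(list(zip(shape, new_shape)))  # [(3, 3), (2, 2), (3, 3)]

# 2) Relay shape dimensions are IntImm nodes, not Python ints, so the widened
#    isinstance check is needed to accept static dims.
x = relay.var("x", shape=(relay.Any(), 3, 2, 3), dtype="float32")
dim = x.type_annotation.shape[1]
print(isinstance(dim, int))                         # False
print(isinstance(dim, (int, tvm.tir.expr.IntImm)))  # True
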
101 changes: 101 additions & 0 deletions tests/python/contrib/test_tensorrt.py
@@ -27,6 +27,7 @@
 from tvm.contrib import graph_runtime, utils
 from tvm.runtime.vm import VirtualMachine
 from tvm.relay import Any, GlobalVar, transform
+from tvm.relay.expr_functor import ExprVisitor
 from typing import Dict, Tuple, Union
 from tvm.contrib.download import download
 from tvm.relay.op.contrib import tensorrt
@@ -631,6 +632,106 @@ def get_graph(x_shape, new_shape):
    run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6)))


class AreOpsOnGraph(ExprVisitor):
    """
    Visits the graph recursively and checks if it contains ops in the op_list.
    """

    def __init__(self, op_list):
        ExprVisitor.__init__(self)
        self.op_list = op_list
        self.on_graph = False

    def visit_call(self, call):
        if isinstance(call.op, tvm.tir.op.Op):
            if str(call.op) in self.op_list:
                self.on_graph = True

        return super().visit_call(call)

    def are_ops_on_graph(self, subgraph) -> bool:
        """
        Recursively visits the graph and checks if any op from op_list is on the graph.
        """
        self.visit(subgraph)
        return self.on_graph


def are_ops_on_trt(mod, op_list):
    for subgraph in mod.get_global_vars():
        name = subgraph.name_hint
        op_on_trt = False
        op_on_tvm = True
        if name == "main":
            op_on_tvm = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
        elif mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt":
            op_on_trt = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)
        else:
            op_on_tvm &= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body)

        if not op_on_trt and op_on_tvm:
            return False

    return True
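
For context, a minimal usage sketch of the helper above (assuming a TVM build with the TensorRT codegen enabled; the module below is illustrative, not part of the test file): partition_for_tensorrt leaves a "main" function plus one global function per offloaded subgraph, each tagged with a Compiler="tensorrt" attribute, and are_ops_on_trt walks all of them.

import tvm
from tvm import relay
from tvm.relay.op.contrib import tensorrt

x = relay.var("x", shape=(relay.Any(), 3, 2, 3), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, (-1, 3, 2, 3))))
mod, _ = tensorrt.partition_for_tensorrt(mod, params={})

# List which global functions were handed to the TensorRT codegen.
for gv in mod.get_global_vars():
    func = mod[gv.name_hint]
    if func.attrs and func.attrs["Compiler"] == "tensorrt":
        print(gv.name_hint, "-> offloaded to TensorRT")
    else:
        print(gv.name_hint, "-> stays on TVM")

# Mirrors the first case in test_dynamic_reshape below, so this is expected to be True.
print(are_ops_on_trt(mod, op_list=["reshape"]))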


def test_dynamic_reshape():
    if skip_codegen_test():
        return

    def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt):
        result_arr = [{} for _ in range(len(x_data_list))]
        for use_trt in [True, False]:
            x = relay.var("x", shape=x_shape, dtype="float32")
            out = relay.reshape(x, new_shape)
            f = relay.Function([x], out)
            mod = tvm.IRModule()
            mod["main"] = f
            if use_trt:
                mod, _ = tensorrt.partition_for_tensorrt(
                    mod, params={}, remove_no_mac_subgraphs=False
                )
                assert are_ops_on_trt(mod, op_list=["reshape"]) == should_offload_to_trt
            if not skip_runtime_test():
                with relay.build_config(opt_level=3):
                    relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")

                for i, x_data in enumerate(x_data_list):
                    result_arr[i][use_trt] = relay_exec.evaluate()(x_data)

        if not skip_runtime_test():
            for i in range(len(x_data_list)):
                assert_result_dict_holds(result_arr[i])

    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
    x_shape = (relay.Any(), 3, 2, 3)
    x_data_list = [
        np.ones([dim_value] + list(x_shape)[1:]).astype("float32") for dim_value in dim_values
    ]
    new_shape = (-1, 3, 2, 3)
    should_offload_to_trt = True
    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)

    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
    x_shape = (relay.Any(), 3, 2, 3)
    x_data_list = [
        np.ones([dim_value] + list(x_shape)[1:]).astype("float32") for dim_value in dim_values
    ]
    new_shape = (-1, 1, 2, 3)
    should_offload_to_trt = False
    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)

    dim_values = [1, 1, 0, 2, 3, 0, 1, 3, 2]
    x_shape = (1, relay.Any(), 2, 3)
    x_data_list = [
        np.ones(list(x_shape[:1]) + [dim_value] + list(x_shape)[2:]).astype("float32")
        for dim_value in dim_values
    ]
    new_shape = (1, -1, 2, 3)
    should_offload_to_trt = False
    test_run(x_data_list, x_shape, new_shape, should_offload_to_trt)


def test_transpose():
    def get_graph(x_shape, order):
        x = relay.var("x", shape=(x_shape), dtype="float32")
