diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index ac2ea9d0b1bb..d89737b5a4a7 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -897,6 +897,7 @@ def _impl(inputs, attr, params): disables=['momentum'])(inputs, attr) if need_cast: + out = _expr.TupleGetItem(out.astuple(), 0) out = _op.cast(out, dtype=attr['T'].name) return out return _impl