From b0396677a293a2573ec5d09c5eac794dd06b03f7 Mon Sep 17 00:00:00 2001 From: Ramana Radhakrishnan Date: Tue, 21 Apr 2020 00:08:06 +0100 Subject: [PATCH 1/2] Fix oversight in importing tf.compat.v1 as tf. --- tests/python/frontend/tensorflow/test_bn_dynamic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/tensorflow/test_bn_dynamic.py b/tests/python/frontend/tensorflow/test_bn_dynamic.py index 4be838e331ef..a2d69034a94a 100644 --- a/tests/python/frontend/tensorflow/test_bn_dynamic.py +++ b/tests/python/frontend/tensorflow/test_bn_dynamic.py @@ -22,7 +22,10 @@ """ import tvm import numpy as np -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf from tvm import relay from tensorflow.python.framework import graph_util From 8971135d7a1fb670d51261e9839f4bafc660060f Mon Sep 17 00:00:00 2001 From: Ramana Radhakrishnan Date: Tue, 21 Apr 2020 00:14:32 +0100 Subject: [PATCH 2/2] Actually disable test for lstm in TF2.1 Since the testing framework actually uses pytest, the version check needs to be moved. --- tests/python/frontend/tensorflow/test_forward.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index bc884bbbfa9b..93501f134d59 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1901,7 +1901,9 @@ def _get_tensorflow_output(): def test_forward_lstm(): '''test LSTM block cell''' - _test_lstm_cell(1, 2, 1, 0.5, 'float32') + if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'): + #in 2.0, tf.contrib.rnn.LSTMBlockCell is removed + _test_lstm_cell(1, 2, 1, 0.5, 'float32') ####################################################################### @@ -3308,9 +3310,7 @@ def test_forward_isfinite(): test_forward_ptb() # RNN - if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'): - #in 2.0, tf.contrib.rnn.LSTMBlockCell is removed - test_forward_lstm() + test_forward_lstm() # Elementwise test_forward_ceil()