diff --git a/alf/utils/tensorrt_utils_test.py b/alf/utils/tensorrt_utils_test.py index 7c099e9b3..b788b0160 100644 --- a/alf/utils/tensorrt_utils_test.py +++ b/alf/utils/tensorrt_utils_test.py @@ -24,6 +24,7 @@ from alf.data_structures import restart from alf.algorithms.sac_algorithm import SacAlgorithm from alf.utils.tensorrt_utils import (OnnxRuntimeEngine, TensorRTEngine, + get_tensorrt_engine_class, compile_method, is_onnxruntime_available, is_tensorrt_available) @@ -154,6 +155,8 @@ def test_tensorrt_resnet50(self): model.eval() dummy_img = torch.randn(1, 3, 224, 224) + for _ in range(10): + eager_output = model(dummy_img) start_time = time.time() for _ in range(100): eager_output = model(dummy_img) @@ -161,14 +164,22 @@ def test_tensorrt_resnet50(self): (time.time() - start_time) / 100) if is_tensorrt_available(): - compile_method(model, 'forward') - model(dummy_img) # build engine - start_time = time.time() - for _ in range(100): - output = model(dummy_img) - print(f"TensorRT predict step time: ", - (time.time() - start_time) / 100) - self.assertTensorClose(eager_output, output, epsilon=0.01) + for fp16 in [True, False]: + model = models.resnet50(pretrained=True) + model.eval() + compile_method(model, 'forward', + partial(get_tensorrt_engine_class(), fp16=fp16)) + model(dummy_img) # build engine + for _ in range(10): + output = model(dummy_img) + start_time = time.time() + for _ in range(100): + output = model(dummy_img) + fp_str = "16" if fp16 else "32" + print(f"TensorRT FP{fp_str} predict step time: ", + (time.time() - start_time) / 100) + eps = 0.03 if fp16 else 0.01 + self.assertTensorClose(eager_output, output, epsilon=eps) if is_onnxruntime_available(): model1 = models.resnet50(pretrained=True) @@ -176,6 +187,8 @@ def test_tensorrt_resnet50(self): # Use onnxruntime API compile_method(model1, 'forward', OnnxRuntimeEngine) model1(dummy_img) # build engine + for _ in range(10): + output = model1(dummy_img) start_time = time.time() for _ in range(100): output = model1(dummy_img)