Skip to content

Commit

Permalink
add test for fp16 and add warmup
Browse files (browse the repository at this point in the history)
(cherry picked from commit 5277872)
  • Loading branch information
Le Horizon authored and Le Horizon committed Aug 28, 2024
1 parent 4b9924f commit a8fa3d7
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions alf/utils/tensorrt_utils_test.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,6 +24,7 @@
from alf.data_structures import restart
from alf.algorithms.sac_algorithm import SacAlgorithm
from alf.utils.tensorrt_utils import (OnnxRuntimeEngine, TensorRTEngine,
get_tensorrt_engine_class,
compile_method, is_onnxruntime_available,
is_tensorrt_available)

Expand Down Expand Up @@ -154,28 +155,40 @@ def test_tensorrt_resnet50(self):
model.eval()
dummy_img = torch.randn(1, 3, 224, 224)

for _ in range(10):
eager_output = model(dummy_img)
start_time = time.time()
for _ in range(100):
eager_output = model(dummy_img)
print("Eager-mode predict step time: ",
(time.time() - start_time) / 100)

if is_tensorrt_available():
compile_method(model, 'forward')
model(dummy_img) # build engine
start_time = time.time()
for _ in range(100):
output = model(dummy_img)
print(f"TensorRT predict step time: ",
(time.time() - start_time) / 100)
self.assertTensorClose(eager_output, output, epsilon=0.01)
for fp16 in [True, False]:
model = models.resnet50(pretrained=True)
model.eval()
compile_method(model, 'forward',
partial(get_tensorrt_engine_class(), fp16=fp16))
model(dummy_img) # build engine
for _ in range(10):
output = model(dummy_img)
start_time = time.time()
for _ in range(100):
output = model(dummy_img)
fp_str = "16" if fp16 else "32"
print(f"TensorRT FP{fp_str} predict step time: ",
(time.time() - start_time) / 100)
eps = 0.03 if fp16 else 0.01
self.assertTensorClose(eager_output, output, epsilon=eps)

if is_onnxruntime_available():
model1 = models.resnet50(pretrained=True)
model1.eval()
# Use onnxruntime API
compile_method(model1, 'forward', OnnxRuntimeEngine)
model1(dummy_img) # build engine
for _ in range(10):
output = model1(dummy_img)
start_time = time.time()
for _ in range(100):
output = model1(dummy_img)
Expand Down

0 comments on commit a8fa3d7

Please sign in to comment.