From 91c95927525bf46b0432a1744b4f440c58050998 Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Sat, 20 Apr 2024 18:39:11 +0800 Subject: [PATCH] scripts/vsmlrt.py: add `tf32` flag to the ort_cuda backend --- scripts/vsmlrt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/vsmlrt.py b/scripts/vsmlrt.py index 04a068c..4e87d4d 100644 --- a/scripts/vsmlrt.py +++ b/scripts/vsmlrt.py @@ -1,4 +1,4 @@ -__version__ = "3.20.8" +__version__ = "3.20.9" __all__ = [ "Backend", "BackendV2", @@ -101,6 +101,7 @@ class ORT_CUDA: fp16_blacklist_ops: typing.Optional[typing.Sequence[str]] = None prefer_nhwc: bool = False output_format: int = 0 # 0: fp32, 1: fp16 + tf32: bool = False # internal backend attributes supports_onnx_serialization: bool = True @@ -2057,6 +2058,7 @@ def _inference( if version >= (1, 18, 0): kwargs["prefer_nhwc"] = backend.prefer_nhwc kwargs["output_format"] = backend.output_format + kwargs["tf32"] = backend.tf32 clip = core.ort.Model( clips, network_path,