Skip to content

Commit

Permalink
scripts/vsmlrt.py: fix rife v4.7-v4.9 model with v2 representation fo…
Browse files Browse the repository at this point in the history
…r TRT backend

#66 (comment)
  • Loading branch information
WolframRhodium committed Nov 3, 2023
1 parent 463b1d1 commit 96c6aa9
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions scripts/vsmlrt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.18.3"
__version__ = "3.18.4"

__all__ = [
"Backend", "BackendV2",
Expand Down Expand Up @@ -141,6 +141,7 @@ class TRT:
short_path: typing.Optional[bool] = None # True on Windows by default, False otherwise
bf16: bool = False
custom_env: typing.Dict[str, str] = field(default_factory=lambda: {})
custom_args: typing.List[str] = field(default_factory=lambda: [])

# internal backend attributes
supports_onnx_serialization: bool = False
Expand Down Expand Up @@ -876,7 +877,7 @@ def RIFEMerge(
multiple = int(multiple_frac.numerator)
scale = float(Fraction(scale))

if model >= 47 and (ensemble or scale != 1.0 or _implementation == 2):
if ensemble or scale != 1.0 or _implementation == 2:
raise ValueError("not supported")

network_path = os.path.join(
Expand Down Expand Up @@ -916,6 +917,19 @@ def RIFEMerge(
trt_opt_shapes=(tile_w, tile_h)
)

# https://github.com/AmusementClub/vs-mlrt/issues/66#issuecomment-1791986979
if _implementation == 2 and model in [47, 48, 49]:
backend.custom_args.extend([
"--precisionConstraints=obey",
"--layerPrecisions=" + (
"/Cast_2:fp32,/Cast_3:fp32,/Cast_5:fp32,/Cast_7:fp32,"
"/Reciprocal:fp32,/Reciprocal_1:fp32,"
"/Mul:fp32,/Mul_1:fp32,/Mul_8:fp32,/Mul_10:fp32,"
"/Sub_5:fp32,/Sub_6:fp32,"
"ONNXTRT_Broadcast_236:fp32,ONNXTRT_Broadcast_238:fp32,ONNXTRT_Broadcast_275:fp32"
)
])

if scale == 1.0:
return inference_with_fallback(
clips=clips, network_path=network_path,
Expand Down Expand Up @@ -1163,7 +1177,8 @@ def trtexec(
max_aux_streams: typing.Optional[int] = None,
short_path: typing.Optional[bool] = None,
bf16: bool = False,
custom_env: typing.Dict[str, str] = {}
custom_env: typing.Dict[str, str] = {},
custom_args: typing.List[str] = []
) -> str:

# tensort runtime version, e.g. 8401 => 8.4.1
Expand Down Expand Up @@ -1334,6 +1349,8 @@ def trtexec(
if bf16:
args.append("--bf16")

args.extend(custom_args)

if log:
env_key = "TRTEXEC_LOG_FILE"
prev_env_value = os.environ.get(env_key)
Expand Down Expand Up @@ -1569,7 +1586,8 @@ def _inference(
max_aux_streams=backend.max_aux_streams,
short_path=backend.short_path,
bf16=backend.bf16,
custom_env=backend.custom_env
custom_env=backend.custom_env,
custom_args=backend.custom_args
)
clip = core.trt.Model(
clips, engine_path,
Expand Down

0 comments on commit 96c6aa9

Please sign in to comment.