From 9ea60374c9ee9ed08ecc4d7dd0e8119187ecee2e Mon Sep 17 00:00:00 2001
From: sven1977
Date: Wed, 28 Feb 2024 19:10:23 +0100
Subject: [PATCH 1/5] wip

Signed-off-by: sven1977
---
 rllib/examples/export/onnx_torch.py | 95 ++++++++++++++++++++---------
 1 file changed, 65 insertions(+), 30 deletions(-)

diff --git a/rllib/examples/export/onnx_torch.py b/rllib/examples/export/onnx_torch.py
index 94faec7b1cdc..b8eea6063d6b 100644
--- a/rllib/examples/export/onnx_torch.py
+++ b/rllib/examples/export/onnx_torch.py
@@ -7,67 +7,102 @@ import shutil
 import torch
 
+from ray.rllib.utils.torch_utils import convert_to_torch_tensor
+
+
+class ONNXCompatibleWrapper(torch.nn.Module):
+    def __init__(self, original_model):
+        super(ONNXCompatibleWrapper, self).__init__()
+        self.original_model = original_model
+
+    def forward(self, a, b0, b1, c):
+        # Convert the separate tensor inputs back into the list format
+        # expected by the original model's forward method.
+        b = [b0, b1]
+        ret = self.original_model({"obs": a}, b, c)
+        # results, state_out_0, state_out_1
+        return ret[0], ret[1][0], ret[1][1]
+
+
 if __name__ == "__main__":
     # Configure our PPO Algorithm.
     config = (
         ppo.PPOConfig()
         # ONNX is not supported by RLModule API yet.
         .experimental(_enable_new_api_stack=False)
+        .environment("CartPole-v1")
         .rollouts(num_rollout_workers=1)
         .framework("torch")
+        .training(model={"use_lstm": True})
     )
 
-    outdir = "export_torch"
-    if os.path.exists(outdir):
-        shutil.rmtree(outdir)
-
-    np.random.seed(1234)
+    B = 3
+    T = 5
+    LSTM_CELL = 256
 
-    # We will run inference with this test batch
-    test_data = {
-        "obs": np.random.uniform(0, 1.0, size=(10, 4)).astype(np.float32),
-        "state_ins": np.array([0.0], dtype=np.float32),
+    # Input data for a python inference forward call.
+    test_data_python = {
+        "obs": np.random.uniform(0, 1.0, size=(B*T, 4)).astype(np.float32),
+        "state_ins": [
+            np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32),
+            np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32)
+        ],
+        "seq_lens": np.array([T] * B, np.float32),
+    }
+    # Input data for the ONNX session.
+    test_data_onnx = {
+        "obs": test_data_python["obs"],
+        "state_in_0": test_data_python["state_ins"][0],
+        "state_in_1": test_data_python["state_ins"][1],
+        "seq_lens": test_data_python["seq_lens"],
     }
+    # Input data for compiling the ONNX model.
+    test_data_onnx_input = convert_to_torch_tensor(test_data_onnx)
 
     # Start Ray and initialize a PPO Algorithm.
     ray.init()
-    algo = config.build(env="CartPole-v1")
-
-    # You could train the model here
-    # algo.train()
+    algo = config.build()
 
     # Let's run inference on the torch model
     policy = algo.get_policy()
     result_pytorch, _ = policy.model(
         {
-            "obs": torch.tensor(test_data["obs"]),
-        }
+            "obs": torch.tensor(test_data_python["obs"]),
+        },
+        [
+            torch.tensor(test_data_python["state_ins"][0]),
+            torch.tensor(test_data_python["state_ins"][1]),
+        ],
+        torch.tensor(test_data_python["seq_lens"]),
    )
 
     # Evaluate tensor to fetch numpy array
     result_pytorch = result_pytorch.detach().numpy()
 
     # This line will export the model to ONNX.
-    policy.export_model(outdir, onnx=11)
-    # Equivalent to:
-    # algo.export_policy_model(outdir, onnx=11)
-
-    # Import ONNX model.
-    exported_model_file = os.path.join(outdir, "model.onnx")
-
+    onnx_compatible = ONNXCompatibleWrapper(policy.model)
+    exported_model_file = "model.onnx"
+    torch.onnx.export(
+        onnx_compatible,
+        (test_data_onnx_input["obs"], test_data_onnx_input["state_in_0"], test_data_onnx_input["state_in_1"], test_data_onnx_input["seq_lens"]),
+        exported_model_file,
+        export_params=True,
+        opset_version=11,
+        do_constant_folding=True,
+        input_names=["obs", "state_in_0", "state_in_1", "seq_lens"],
+        output_names=["output", "state_out_0", "state_out_1"],
+        dynamic_axes={
+            k: {0: "batch_size"}
+            for k in ["obs", "state_in_0", "state_in_1", "seq_lens"]
+        },
+    )
     # Start an inference session for the ONNX model
     session = onnxruntime.InferenceSession(exported_model_file, None)
-
-    # Pass the same test batch to the ONNX model
-    if Version(torch.__version__) < Version("1.9.0"):
-        # In torch < 1.9.0 the second input/output name gets mixed up
-        test_data["state_outs"] = test_data.pop("state_ins")
-
-    result_onnx = session.run(["output"], test_data)
+    result_onnx = session.run(["output"], test_data_onnx)
 
     # These results should be equal!
     print("PYTORCH", result_pytorch)
-    print("ONNX", result_onnx)
+    print("ONNX", result_onnx[0])
 
     assert np.allclose(
         result_pytorch, result_onnx

From e62ba191dca729809e0780775ce19518de361c86 Mon Sep 17 00:00:00 2001
From: sven1977
Date: Thu, 13 Jun 2024 20:37:09 +0200
Subject: [PATCH 2/5] lstm example working fine (with --use-lstm option)

Signed-off-by: sven1977
---
 rllib/examples/checkpoints/onnx_torch.py | 95 ++++++++++++++--------
 1 file changed, 62 insertions(+), 33 deletions(-)

diff --git a/rllib/examples/checkpoints/onnx_torch.py b/rllib/examples/checkpoints/onnx_torch.py
index 6340486957b9..945177172ab7 100644
--- a/rllib/examples/checkpoints/onnx_torch.py
+++ b/rllib/examples/checkpoints/onnx_torch.py
@@ -1,16 +1,20 @@
 # @OldAPIStack
-from packaging.version import Version
 import numpy as np
-import ray
-import ray.rllib.algorithms.ppo as ppo
 import onnxruntime
-import os
-import shutil
-import torch
+
+import ray
+import ray.rllib.algorithms.ppo as ppo
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.test_utils import add_rllib_example_script_args, check
 from ray.rllib.utils.torch_utils import convert_to_torch_tensor
 
+torch, _ = try_import_torch()
+
+parser = add_rllib_example_script_args()
+parser.add_argument("--use-lstm", action="store_true")
+parser.set_defaults(num_env_runners=1)
+
 
 class ONNXCompatibleWrapper(torch.nn.Module):
     def __init__(self, original_model):
@@ -27,15 +31,21 @@ def forward(self, a, b0, b1, c):
 
 
 if __name__ == "__main__":
+    args = parser.parse_args()
+
+    ray.init(local_mode=args.local_mode)
+
     # Configure our PPO Algorithm.
     config = (
         ppo.PPOConfig()
         # ONNX is not supported by RLModule API yet.
-        .api_stack(enable_rl_module_and_learner=False)
+        .api_stack(
+            enable_rl_module_and_learner=args.enable_new_api_stack,
+            enable_env_runner_and_connector_v2=args.enable_new_api_stack,
+        )
         .environment("CartPole-v1")
-        .env_runners(num_env_runners=1)
-        .framework("torch")
-        .training(model={"use_lstm": True})
+        .env_runners(num_env_runners=args.num_env_runners)
+        .training(model={"use_lstm": args.use_lstm})
     )
 
     B = 3
     T = 5
     LSTM_CELL = 256
 
     # Input data for a python inference forward call.
     test_data_python = {
         "obs": np.random.uniform(0, 1.0, size=(B*T, 4)).astype(np.float32),
-        "state_ins": [
-            np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32),
-            np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32)
-        ],
-        "seq_lens": np.array([T] * B, np.float32),
+        "state_ins": np.array([0.0], dtype=np.float32),
     }
     # Input data for the ONNX session.
     test_data_onnx = {
         "obs": test_data_python["obs"],
-        "state_in_0": test_data_python["state_ins"][0],
-        "state_in_1": test_data_python["state_ins"][1],
-        "seq_lens": test_data_python["seq_lens"],
     }
+
+    if args.use_lstm:
+        test_data_python.update({
+            "state_ins": [
+                np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32),
+                np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32)
+            ],
+            "seq_lens": np.array([T] * B, np.float32),
+        })
+        test_data_onnx.update({
+            "state_in_0": test_data_python["state_ins"][0],
+            "state_in_1": test_data_python["state_ins"][1],
+            "seq_lens": test_data_python["seq_lens"],
+        })
+
     # Input data for compiling the ONNX model.
     test_data_onnx_input = convert_to_torch_tensor(test_data_onnx)
 
-    # Start Ray and initialize a PPO Algorithm.
-    ray.init()
+    # Initialize a PPO Algorithm.
     algo = config.build()
 
     # You could train the model here
     # algo.train()
@@ -77,29 +94,43 @@ def forward(self, a, b0, b1, c):
         [
             torch.tensor(test_data_python["state_ins"][0]),
             torch.tensor(test_data_python["state_ins"][1]),
-        ],
-        torch.tensor(test_data_python["seq_lens"]),
+        ] if args.use_lstm else [],
+        torch.tensor(test_data_python["seq_lens"]) if args.use_lstm else None,
     )
 
     # Evaluate tensor to fetch numpy array
     result_pytorch = result_pytorch.detach().numpy()
 
     # This line will export the model to ONNX.
-    onnx_compatible = ONNXCompatibleWrapper(policy.model)
+    onnx_compatible = policy.model
+    if args.use_lstm:
+        onnx_compatible = ONNXCompatibleWrapper(onnx_compatible)
     exported_model_file = "model.onnx"
+    input_names = [
+        "obs",
+        "state_in_0",
+        "state_in_1",
+        "seq_lens",
+    ] if args.use_lstm else ["obs"]
+
     torch.onnx.export(
         onnx_compatible,
-        (test_data_onnx_input["obs"], test_data_onnx_input["state_in_0"], test_data_onnx_input["state_in_1"], test_data_onnx_input["seq_lens"]),
+        (
+            tuple(test_data_onnx_input[n] for n in input_names)
+            if args.use_lstm
+            else ({"obs": test_data_onnx_input["obs"]},)
+        ),
         exported_model_file,
         export_params=True,
         opset_version=11,
         do_constant_folding=True,
-        input_names=["obs", "state_in_0", "state_in_1", "seq_lens"],
-        output_names=["output", "state_out_0", "state_out_1"],
-        dynamic_axes={
-            k: {0: "batch_size"}
-            for k in ["obs", "state_in_0", "state_in_1", "seq_lens"]
-        },
+        input_names=input_names,
+        output_names=[
+            "output",
+            "state_out_0",
+            "state_out_1",
+        ] if args.use_lstm else ["output"],
+        dynamic_axes={k: {0: "batch_size"} for k in input_names},
     )
     # Start an inference session for the ONNX model
     session = onnxruntime.InferenceSession(exported_model_file, None)
@@ -109,7 +140,5 @@ def forward(self, a, b0, b1, c):
     print("PYTORCH", result_pytorch)
     print("ONNX", result_onnx[0])
 
-    assert np.allclose(
-        result_pytorch, result_onnx
-    ), "Model outputs are NOT equal. FAILED"
+    check(result_pytorch, result_onnx[0])
     print("Model outputs are equal. PASSED")
PASSED") From f23ae6e2e489051cf08d9360a9160eda5bbd6461 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 13 Jun 2024 20:44:06 +0200 Subject: [PATCH 3/5] wip Signed-off-by: sven1977 --- rllib/BUILD | 9 ++ rllib/examples/checkpoints/onnx_torch.py | 145 +++++------------- rllib/examples/checkpoints/onnx_torch_lstm.py | 130 ++++++++++++++++ 3 files changed, 178 insertions(+), 106 deletions(-) create mode 100644 rllib/examples/checkpoints/onnx_torch_lstm.py diff --git a/rllib/BUILD b/rllib/BUILD index 5cd99351b97c..cac84c2dfcc2 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2213,6 +2213,15 @@ py_test( srcs = ["examples/checkpoints/onnx_torch.py"], ) +#@OldAPIStack +py_test( + name = "examples/checkpoints/onnx_torch_lstm", + main = "examples/checkpoints/onnx_torch_lstm.py", + tags = ["team:rllib", "exclusive", "examples", "no_main"], + size = "small", + srcs = ["examples/checkpoints/onnx_torch_lstm.py"], +) + # subdirectory: connectors/ # .................................... # Framestacking examples only run in smoke-test mode (a few iters only). diff --git a/rllib/examples/checkpoints/onnx_torch.py b/rllib/examples/checkpoints/onnx_torch.py index 945177172ab7..c377a5c65663 100644 --- a/rllib/examples/checkpoints/onnx_torch.py +++ b/rllib/examples/checkpoints/onnx_torch.py @@ -1,86 +1,39 @@ # @OldAPIStack +from packaging.version import Version import numpy as np -import onnxruntime - import ray import ray.rllib.algorithms.ppo as ppo -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.test_utils import add_rllib_example_script_args, check -from ray.rllib.utils.torch_utils import convert_to_torch_tensor - -torch, _ = try_import_torch() - -parser = add_rllib_example_script_args() -parser.add_argument("--use-lstm", action="store_true") -parser.set_defaults(num_env_runners=1) - - -class ONNXCompatibleWrapper(torch.nn.Module): - def __init__(self, original_model): - super(ONNXCompatibleWrapper, self).__init__() - self.original_model = original_model - - def forward(self, a, b0, b1, c): - # Convert the separate tensor inputs back into the list format - # expected by the original model's forward method. - b = [b0, b1] - ret = self.original_model({"obs": a}, b, c) - # results, state_out_0, state_out_1 - return ret[0], ret[1][0], ret[1][1] - +import onnxruntime +import os +import shutil +import torch if __name__ == "__main__": - args = parser.parse_args() - - ray.init(local_mode=args.local_mode) - # Configure our PPO Algorithm. config = ( ppo.PPOConfig() # ONNX is not supported by RLModule API yet. - .api_stack( - enable_rl_module_and_learner=args.enable_new_api_stack, - enable_env_runner_and_connector_v2=args.enable_new_api_stack, - ) - .environment("CartPole-v1") - .env_runners(num_env_runners=args.num_env_runners) - .training(model={"use_lstm": args.use_lstm}) + .api_stack(enable_rl_module_and_learner=False) + .env_runners(num_env_runners=1) + .framework("torch") ) - B = 3 - T = 5 - LSTM_CELL = 256 + outdir = "export_torch" + if os.path.exists(outdir): + shutil.rmtree(outdir) + + np.random.seed(1234) - # Input data for a python inference forward call. - test_data_python = { - "obs": np.random.uniform(0, 1.0, size=(B*T, 4)).astype(np.float32), + # We will run inference with this test batch + test_data = { + "obs": np.random.uniform(0, 1.0, size=(10, 4)).astype(np.float32), "state_ins": np.array([0.0], dtype=np.float32), } - # Input data for the ONNX session. 
-    test_data_onnx = {
-        "obs": test_data_python["obs"],
-    }
-
-    if args.use_lstm:
-        test_data_python.update({
-            "state_ins": [
-                np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32),
-                np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32)
-            ],
-            "seq_lens": np.array([T] * B, np.float32),
-        })
-        test_data_onnx.update({
-            "state_in_0": test_data_python["state_ins"][0],
-            "state_in_1": test_data_python["state_ins"][1],
-            "seq_lens": test_data_python["seq_lens"],
-        })
-
-    # Input data for compiling the ONNX model.
-    test_data_onnx_input = convert_to_torch_tensor(test_data_onnx)
-
-    # Initialize a PPO Algorithm.
-    algo = config.build()
+    # Start Ray and initialize a PPO Algorithm.
+    ray.init()
+    algo = config.build(env="CartPole-v1")
 
     # You could train the model here
     # algo.train()
@@ -89,56 +42,36 @@ def forward(self, a, b0, b1, c):
     policy = algo.get_policy()
     result_pytorch, _ = policy.model(
         {
-            "obs": torch.tensor(test_data_python["obs"]),
-        },
-        [
-            torch.tensor(test_data_python["state_ins"][0]),
-            torch.tensor(test_data_python["state_ins"][1]),
-        ] if args.use_lstm else [],
-        torch.tensor(test_data_python["seq_lens"]) if args.use_lstm else None,
+            "obs": torch.tensor(test_data["obs"]),
+        }
     )
 
     # Evaluate tensor to fetch numpy array
     result_pytorch = result_pytorch.detach().numpy()
 
     # This line will export the model to ONNX.
-    onnx_compatible = policy.model
-    if args.use_lstm:
-        onnx_compatible = ONNXCompatibleWrapper(onnx_compatible)
-    exported_model_file = "model.onnx"
-    input_names = [
-        "obs",
-        "state_in_0",
-        "state_in_1",
-        "seq_lens",
-    ] if args.use_lstm else ["obs"]
-
-    torch.onnx.export(
-        onnx_compatible,
-        (
-            tuple(test_data_onnx_input[n] for n in input_names)
-            if args.use_lstm
-            else ({"obs": test_data_onnx_input["obs"]},)
-        ),
-        exported_model_file,
-        export_params=True,
-        opset_version=11,
-        do_constant_folding=True,
-        input_names=input_names,
-        output_names=[
-            "output",
-            "state_out_0",
-            "state_out_1",
-        ] if args.use_lstm else ["output"],
-        dynamic_axes={k: {0: "batch_size"} for k in input_names},
-    )
+    policy.export_model(outdir, onnx=11)
+    # Equivalent to:
+    # algo.export_policy_model(outdir, onnx=11)
+
+    # Import ONNX model.
+    exported_model_file = os.path.join(outdir, "model.onnx")
+
     # Start an inference session for the ONNX model
     session = onnxruntime.InferenceSession(exported_model_file, None)
-    result_onnx = session.run(["output"], test_data_onnx)
+
+    # Pass the same test batch to the ONNX model
+    if Version(torch.__version__) < Version("1.9.0"):
+        # In torch < 1.9.0 the second input/output name gets mixed up
+        test_data["state_outs"] = test_data.pop("state_ins")
+
+    result_onnx = session.run(["output"], test_data)
 
     # These results should be equal!
     print("PYTORCH", result_pytorch)
-    print("ONNX", result_onnx[0])
+    print("ONNX", result_onnx)
 
-    check(result_pytorch, result_onnx[0])
+    assert np.allclose(
+        result_pytorch, result_onnx
+    ), "Model outputs are NOT equal. FAILED"
     print("Model outputs are equal. PASSED")
PASSED") diff --git a/rllib/examples/checkpoints/onnx_torch_lstm.py b/rllib/examples/checkpoints/onnx_torch_lstm.py new file mode 100644 index 000000000000..85c0b556a58c --- /dev/null +++ b/rllib/examples/checkpoints/onnx_torch_lstm.py @@ -0,0 +1,130 @@ +# @OldAPIStack + +import numpy as np +import onnxruntime + +import ray +import ray.rllib.algorithms.ppo as ppo +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.test_utils import add_rllib_example_script_args, check +from ray.rllib.utils.torch_utils import convert_to_torch_tensor + +torch, _ = try_import_torch() + +parser = add_rllib_example_script_args() +parser.set_defaults(num_env_runners=1) + + +class ONNXCompatibleWrapper(torch.nn.Module): + def __init__(self, original_model): + super(ONNXCompatibleWrapper, self).__init__() + self.original_model = original_model + + def forward(self, a, b0, b1, c): + # Convert the separate tensor inputs back into the list format + # expected by the original model's forward method. + b = [b0, b1] + ret = self.original_model({"obs": a}, b, c) + # results, state_out_0, state_out_1 + return ret[0], ret[1][0], ret[1][1] + + +if __name__ == "__main__": + args = parser.parse_args() + + ray.init(local_mode=args.local_mode) + + # Configure our PPO Algorithm. + config = ( + ppo.PPOConfig() + # ONNX is not supported by RLModule API yet. + .api_stack( + enable_rl_module_and_learner=args.enable_new_api_stack, + enable_env_runner_and_connector_v2=args.enable_new_api_stack, + ) + .environment("CartPole-v1") + .env_runners(num_env_runners=args.num_env_runners) + .training(model={"use_lstm": True}) + ) + + B = 3 + T = 5 + LSTM_CELL = 256 + + # Input data for a python inference forward call. + test_data_python = { + "obs": np.random.uniform(0, 1.0, size=(B * T, 4)).astype(np.float32), + "state_ins": [ + np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32), + np.random.uniform(0, 1.0, size=(B, LSTM_CELL)).astype(np.float32), + ], + "seq_lens": np.array([T] * B, np.float32), + } + # Input data for the ONNX session. + test_data_onnx = { + "obs": test_data_python["obs"], + "state_in_0": test_data_python["state_ins"][0], + "state_in_1": test_data_python["state_ins"][1], + "seq_lens": test_data_python["seq_lens"], + } + + # Input data for compiling the ONNX model. + test_data_onnx_input = convert_to_torch_tensor(test_data_onnx) + + # Initialize a PPO Algorithm. + algo = config.build() + + # You could train the model here + # algo.train() + + # Let's run inference on the torch model + policy = algo.get_policy() + result_pytorch, _ = policy.model( + { + "obs": torch.tensor(test_data_python["obs"]), + }, + [ + torch.tensor(test_data_python["state_ins"][0]), + torch.tensor(test_data_python["state_ins"][1]), + ], + torch.tensor(test_data_python["seq_lens"]), + ) + + # Evaluate tensor to fetch numpy array + result_pytorch = result_pytorch.detach().numpy() + + # This line will export the model to ONNX. 
+    onnx_compatible = ONNXCompatibleWrapper(policy.model)
+    exported_model_file = "model.onnx"
+    input_names = [
+        "obs",
+        "state_in_0",
+        "state_in_1",
+        "seq_lens",
+    ]
+
+    torch.onnx.export(
+        onnx_compatible,
+        tuple(test_data_onnx_input[n] for n in input_names),
+        exported_model_file,
+        export_params=True,
+        opset_version=11,
+        do_constant_folding=True,
+        input_names=input_names,
+        output_names=[
+            "output",
+            "state_out_0",
+            "state_out_1",
+        ],
+        dynamic_axes={k: {0: "batch_size"} for k in input_names},
+    )
+    # Start an inference session for the ONNX model
+    session = onnxruntime.InferenceSession(exported_model_file, None)
+    result_onnx = session.run(["output"], test_data_onnx)
+
+    # These results should be equal!
+    print("PYTORCH", result_pytorch)
+    print("ONNX", result_onnx[0])
+
+    check(result_pytorch, result_onnx[0])
+    print("Model outputs are equal. PASSED")

From 305586bcc2424433f633639a5ad57278e5ecc30b Mon Sep 17 00:00:00 2001
From: sven1977
Date: Fri, 14 Jun 2024 13:15:21 +0200
Subject: [PATCH 4/5] wip

Signed-off-by: sven1977
---
 rllib/examples/checkpoints/onnx_torch_lstm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rllib/examples/checkpoints/onnx_torch_lstm.py b/rllib/examples/checkpoints/onnx_torch_lstm.py
index 85c0b556a58c..735289310273 100644
--- a/rllib/examples/checkpoints/onnx_torch_lstm.py
+++ b/rllib/examples/checkpoints/onnx_torch_lstm.py
@@ -118,7 +118,7 @@ def forward(self, a, b0, b1, c):
         ],
         dynamic_axes={k: {0: "batch_size"} for k in input_names},
     )
-    # Start an inference session for the ONNX model
+    # Start an inference session for the ONNX model.
     session = onnxruntime.InferenceSession(exported_model_file, None)
     result_onnx = session.run(["output"], test_data_onnx)
 

From fd5e479b02fe0a3e16b93075d40193329a0fe4ac Mon Sep 17 00:00:00 2001
From: sven1977
Date: Fri, 14 Jun 2024 14:18:59 +0200
Subject: [PATCH 5/5] wip

Signed-off-by: sven1977
---
 rllib/examples/checkpoints/onnx_torch_lstm.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/rllib/examples/checkpoints/onnx_torch_lstm.py b/rllib/examples/checkpoints/onnx_torch_lstm.py
index 735289310273..d95a282a3a30 100644
--- a/rllib/examples/checkpoints/onnx_torch_lstm.py
+++ b/rllib/examples/checkpoints/onnx_torch_lstm.py
@@ -32,6 +32,10 @@ def forward(self, a, b0, b1, c):
 if __name__ == "__main__":
     args = parser.parse_args()
 
+    assert (
+        not args.enable_new_api_stack
+    ), "Must NOT set --enable-new-api-stack when running this script!"
+
     ray.init(local_mode=args.local_mode)
 
     # Configure our PPO Algorithm.
@@ -93,7 +97,8 @@ def forward(self, a, b0, b1, c):
     # Evaluate tensor to fetch numpy array
     result_pytorch = result_pytorch.detach().numpy()
 
-    # This line will export the model to ONNX.
+    # Wrap the actual ModelV2 with the torch wrapper above to make this all work with
+    # LSTMs (extra `state` in- and outputs and `seq_lens` inputs).
     onnx_compatible = ONNXCompatibleWrapper(policy.model)
     exported_model_file = "model.onnx"
     input_names = [
@@ -103,6 +108,7 @@ def forward(self, a, b0, b1, c):
         "seq_lens",
     ]
 
+    # This line will export the model to ONNX.
    torch.onnx.export(
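
---

Usage sketch (not part of the patches above): once onnx_torch_lstm.py has
written model.onnx, the file can be consumed from a plain Python process with
only numpy and onnxruntime installed, no Ray or RLlib. The constants (B=3,
T=5, LSTM_CELL=256) and the input/output names are assumed to match the export
code above; feeding state_out_* back into state_in_* is how one would step the
exported recurrent model chunk by chunk.

    import numpy as np
    import onnxruntime

    # Assumed to match the constants used at export time above.
    B, T, LSTM_CELL = 3, 5, 256

    session = onnxruntime.InferenceSession("model.onnx", None)

    # Start from an all-zeros LSTM state (h and c).
    state_0 = np.zeros((B, LSTM_CELL), np.float32)
    state_1 = np.zeros((B, LSTM_CELL), np.float32)

    for _ in range(2):
        # Fetch the action logits AND both state outputs, then feed the
        # states back in for the next chunk of T timesteps.
        output, state_0, state_1 = session.run(
            ["output", "state_out_0", "state_out_1"],
            {
                "obs": np.random.uniform(0, 1.0, size=(B * T, 4)).astype(np.float32),
                "state_in_0": state_0,
                "state_in_1": state_1,
                "seq_lens": np.array([T] * B, np.float32),
            },
        )
        print(output.shape)  # -> (B * T, 2): action logits for CartPole-v1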