TensorRT execution provider SEGFAULT #7757

Open
pvavercak opened this issue May 19, 2021 · 21 comments
Labels
ep:TensorRT issues related to TensorRT execution provider

Comments

@pvavercak (Contributor)

Hi guys,
I'm experiencing an issue with the TensorRT execution provider on a Jetson Xavier with JetPack 4.4. Unfortunately, I can't share my model with you, but I was hoping some of you might have faced the same issue.

Describe the bug
Once the TensorRT execution provider is added to session options, loading a model fails with the following output.
InitSession Error: ! Exception during initialization: onnxruntime/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:756 SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const graph_build.Resolve().IsOK() was false.
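For context, a rough sketch of the session setup that triggers this (shown with the Python API for brevity; my application uses the equivalent C++ calls, and the model path is just a placeholder):

import onnxruntime as ort

# Placeholder path; the real model cannot be shared.
sess = ort.InferenceSession(
    "model.onnx",
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
)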

Stack trace:

#0  0x0000007f8b59844c in std::__atomic_base::compare_exchange_strong (__m2=std::memory_order_relaxed, __m1=std::memory_order_acquire, __i2=1, __i1=@0x7f7ba4ac94: 0, this=0x850)
    at /opt/gcc-linaro-7.3.1-2018.05-x86_64_aarch64-linux-gnu/aarch64-linux-gnu/include/c++/7.3.1/bits/atomic_base.h:477
#1  std::atomic_compare_exchange_strong_explicit (__a=0x850, __i1=0x7f7ba4ac94, __i2=1, __m1=std::memory_order_acquire, __m2=std::memory_order_relaxed)
    at /opt/gcc-linaro-7.3.1-2018.05-x86_64_aarch64-linux-gnu/aarch64-linux-gnu/include/c++/7.3.1/atomic:1125
#2  0x0000007f8b598818 in nsync::atm_cas_acq_u32 (p=0x850, o=0, n=1)
    at onnxruntime/cmake/external/nsync/platform/c++11/atomic.h:73
#3  0x0000007f8b598c14 in nsync::nsync_mu_lock (mu=0x850) at external/nsync/cpp/internal/mu.c:148
#4  0x0000007f8a4f19bc in onnxruntime::OrtMutex::lock (this=0x850)
    at onnxruntime/include/onnxruntime/core/platform/ort_mutex.h:119
#5  0x0000007f8a4f252c in std::lock_guard<onnxruntime::OrtMutex>::lock_guard (this=0x7f7ba4ad50, __m=...)
    at /opt/gcc-linaro-7.3.1-2018.05-x86_64_aarch64-linux-gnu/aarch64-linux-gnu/include/c++/7.3.1/bits/std_mutex.h:162
#6  0x0000007f8a5319bc in onnxruntime::InferenceSession::GetModelInputs (this=0x0)
    at onnxruntime/onnxruntime/core/session/inference_session.cc:1658
#7  0x0000007f8a4c1208 in <lambda(const onnxruntime::InferenceSession*)>::operator()(const onnxruntime::InferenceSession *) const (__closure=0x0, session=0x0)
    at onnxruntime/onnxruntime/core/session/onnxruntime_c_api.cc:896
#8  0x0000007f8a4c123c in <lambda(const onnxruntime::InferenceSession*)>::_FUN(const onnxruntime::InferenceSession *) ()
    at onnxruntime/onnxruntime/core/session/onnxruntime_c_api.cc:896
#9  0x0000007f8a4c1398 in GetNodeDefListCountHelper (sess=0x0, get_fn=0x7f8a4c1218 <<lambda(const onnxruntime::InferenceSession*)>::_FUN(const onnxruntime::InferenceSession *)>, out=0x7f7ba4b0f8)
    at onnxruntime/onnxruntime/core/session/onnxruntime_c_api.cc:903
#10 0x0000007f8a4c14dc in OrtApis::SessionGetInputCount (sess=0x0, out=0x7f7ba4b0f8)
    at onnxruntime/onnxruntime/core/session/onnxruntime_c_api.cc:912
#11 0x0000007fb4f37e2c in Ort::Session::GetInputCount (this=0x7f74002888)
    at onnxruntime/include/onnxruntime/core/session/onnxruntime_cxx_inline.h:534
...

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): ARMv8 Jetson Xavier with Ubuntu 18.04.5 LTS
  • ONNX Runtime installed from (source or binary): source
  • ONNX Runtime version: 1.7.1
  • GCC/Compiler version (if compiling from source): gcc-linaro-7.3.1-2018.05-x86_64_aarch64-linux-gnu
  • CUDA/cuDNN version: 10.2.89/8.0.0
  • GPU model and memory: -

Additional context
If I try other EPs (CUDA or CPU), the error disappears, so it is specific to TensorRT.

@oliviajain (Contributor)

Hi Patrik, just wanted to check: have you tried rebooting your Jetson?

@RyanUnderhill RyanUnderhill added the ep:TensorRT and type:support labels May 20, 2021
@pvavercak (Contributor, Author)

Hi Olivia,

Yes, I have, and unfortunately it didn't help.
Just a little note: I was using environment variables for TensorRT caching, e.g. ORT_TENSORRT_ENGINE_CACHE_ENABLE and ORT_TENSORRT_CACHE_PATH. After the reboot I turned them off just to make sure that my model is compiled for TensorRT from scratch.
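For reference, a rough sketch of how that caching configuration is toggled (variable names are the ones from the TensorRT EP documentation mentioned above; shown in Python, while my application sets them in its process environment):

import os

# TensorRT EP engine caching, controlled via environment variables in ORT 1.7:
os.environ["ORT_TENSORRT_ENGINE_CACHE_ENABLE"] = "0"   # disabled after the reboot
os.environ.pop("ORT_TENSORRT_CACHE_PATH", None)        # drop the cache directory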

I know this segfault is model-related, because on the same Jetson device, with the same environment, other models work properly. I'm just trying to find out whether there is a known issue with TensorRT and models that have a specific architecture or contain a particular layer. This issue mentions a similar problem, but I don't know if that is my case, since my Jetson Xavier has TensorRT 7.1.3 installed.
I tried the CPU and CUDA execution providers for my NN on the Jetson and it worked without any problems.

@oliviajain oliviajain added the platform:jetson label May 20, 2021
@pvavercak (Contributor, Author)

Hi Olivia,

I tested this behavior on an RTX 2080 as well, and the same issue occurred. The segfault was caused by my own inattention: I had put the Ort::Session constructor into a try-catch block, caught the exception and logged an error, but didn't exit the process. My code then kept executing and hit a line calling Ort::Session::GetInputCount(), which caused the SIGSEGV.
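For illustration, the anti-pattern looked roughly like this (a sketch in Python; the real code uses the C++ Ort::Session API, where the same mistake ends in a SIGSEGV rather than a raised exception):

import onnxruntime as ort

sess = None
try:
    # Initialization fails because the TensorRT EP rejects the graph...
    sess = ort.InferenceSession("model.onnx", providers=["TensorrtExecutionProvider"])
except Exception as exc:
    # ...but the error was only logged and execution continued.
    print(f"InitSession Error: {exc}")

# Using the never-initialized session afterwards is what actually crashed
# (GetInputCount() in C++; in Python this line would raise instead).
inputs = sess.get_inputs()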

However, one problem still bothers me, and it's the message I get from ONNX Runtime when I call the Ort::Session constructor:

onnxruntime/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:756 SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const graph_build.Resolve().IsOK() was false.

As I mentioned before, the same model works with the CUDA and CPU EPs, so TensorRT is the problem here.
I have a Reshape layer in the problematic network; its shape input contains -1, which is valid and means that dimension is inferred from the remaining dimensions, according to this doc.
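For reference, a minimal sketch of such a node (illustrative names built with the onnx helper API, not my actual graph):

import numpy as np
from onnx import helper, numpy_helper

# Target shape containing -1: that dimension is inferred from the rest.
shape = numpy_helper.from_array(np.array([1, -1], dtype=np.int64), name="shape")
reshape = helper.make_node("Reshape", inputs=["data", "shape"], outputs=["reshaped"])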

Do you have any suggestions? Does this question relate to onnxruntime?

@pvavercak (Contributor, Author)

Hello,

I just wanted to ask if you could give me some hints. I'm quite stuck on this issue right now.

Thanks a lot.

@jywu-msft (Member)

because Jetson only supports TensorRT 7.1.x, can you try this prior to building OnnxRuntime with the TensorRT EP?
git config --file=.gitmodules submodule.cmake/external/onnx-tensorrt.branch 7.1
git submodule update --remote cmake/external/onnx-tensorrt
Then build as normal.

@pvavercak (Contributor, Author)

Hi George,

As I mentioned above, the issue was also reproduced on an NVIDIA RTX 2080, so it is not specific to Jetson.

However, I did check out the onnx-tensorrt submodule at 7.1 and tried it on both the Jetson and the RTX. It did not help.

@jywu-msft (Member)

thanks for trying. just to confirm, on the RTX 2080, are you using TensorRT 7.1 or 7.2?
I know you can't share the model, but is there any way to provide a minimal repro? or provide some more details about the subgraph which is encountering this?
+@stevenlix, do you have any ideas about this Graph resolve error in GetSupportedList()?

@jywu-msft jywu-msft removed the platform:jetson label Jun 2, 2021
@pvavercak (Contributor, Author)

Both the system with the RTX (CentOS 7) and the Jetson have the same versions of the CUDA/cuDNN/TensorRT libraries:
cuda 10.2
cudnn 8.0.x
tensorrt 7.1.3

I can't promise anything, but I'll do my best to provide a minimal repro.

@pvavercak (Contributor, Author)

Hello again,

I've attached a smaller model which causes the same error as described above.

net_SSH_ter_1000.zip

Please let me know if you are able to reproduce the problem.

@pvavercak (Contributor, Author)

Hi everyone,

I just wanted to check whether you were able to reproduce the issue.

Thank you so much.

@jywu-msft (Member)

sorry, haven't been able to test this model yet. will do so.
most likely the issue you are running into is related to needing to run the symbolic shape inference script on the model. (there was an issue with error reporting in ort 1.7.1 that resulted in a segfault instead of an error message instructing users to run the script - that should be fixed in master now)
The model has a dynamically shaped input [N, 3, N, N]. TensorRT requires input shapes to be defined, so we need to run an offline script to propagate symbolic shapes through the graph.
see https://www.onnxruntime.ai/docs/reference/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs

e.g.
python3 symbolic_shape_infer.py --input net_SSH_ter_1000.onnx --output net_SSH_ter_1000_shape.onnx --auto_merge
if it succeeds, try running on the updated model.
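if you prefer to do it from code, something along these lines should be equivalent (a sketch, assuming the SymbolicShapeInference class exposed by that script is importable from your working directory):

import onnx
from symbolic_shape_infer import SymbolicShapeInference  # script from the ORT repo

model = onnx.load("net_SSH_ter_1000.onnx")
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, "net_SSH_ter_1000_shape.onnx")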

@jywu-msft (Member)

I confirmed that after building the TensorRT EP from the rel-1.8.0 branch and running symbolic_shape_infer.py on 'net_SSH_ter_1000.onnx' first, the model could be loaded and the session created successfully.

@pvavercak (Contributor, Author)

Thank you for the answer. Running the script on my target model did the trick and the TensorRT EP now works.
However, the "*-shape" model gives me slightly different results compared with the previous one.
I also noticed that net_SSH_ter_1000_shape.onnx gives me different results after changing the EP. For instance, if I switch the EP from TensorRT to CUDA, the results are better (or more accurate) than with TRT.

How is that possible?

@jywu-msft (Member)

outputs may vary due to differences in implementations across TensorRT and CUDA kernels.
can you quantify "slightly different results"?
Just to confirm, we're not talking about wrong results (a bug), right?

@pvavercak (Contributor, Author)

sorry for my inactivity.

I was able to find a workaround: removing the Reshape layers with -1 in their shape input fixed TensorRT inference.
However, I had to implement a small postprocessing step containing the reshape operation and the softmax function, since I was using these two operators together (a sketch follows below). With that in place, the results now stay the same.
As for the "slightly different results": the output vectors from my layers were inconsistent, and I didn't see any correlation between the shape-inferred models and the non-shaped ones.

@faxu faxu removed the type:support label Aug 18, 2021
@stale stale bot commented Apr 19, 2022

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@stale stale bot added the stale label Apr 19, 2022
@Sakura-Luna

I encountered this problem when running a UNet model on ORT-GPU 1.15. It works fine when using CUDAExecutionProvider.
TensorrtExecutionProvider::GetSupportedList graph_build.Resolve().IsOK() was false.

According to the documentation, I specified all dynamic input axis dimensions, but it doesn't work; below is the main code.

import onnxruntime as ort

sess_options = ort.SessionOptions()

sess_options.add_free_dimension_override_by_name("unet_sample_batch", 2)
sess_options.add_free_dimension_override_by_name("unet_sample_channels", 4)
sess_options.add_free_dimension_override_by_name("unet_sample_height", 64)
sess_options.add_free_dimension_override_by_name("unet_sample_width", 64)
sess_options.add_free_dimension_override_by_name("unet_time_batch", 1)
sess_options.add_free_dimension_override_by_name("unet_hidden_batch", 2)
sess_options.add_free_dimension_override_by_name("unet_hidden_sequence", 77)

trt_ep_options = {
    "trt_fp16_enable": True,
    "trt_engine_cache_enable": True,
    "trt_profile_min_shapes": "sample:2x4x64x64,timestep:1,encoder_hidden_states:2x77x768",
    "trt_profile_max_shapes": "sample:32x4x64x64,timestep:1,encoder_hidden_states:32x77x768",
    "trt_profile_opt_shapes": "sample:2x4x64x64,timestep:1,encoder_hidden_states:2x77x768",
}
providers = [('TensorrtExecutionProvider', trt_ep_options)]
sess = ort.InferenceSession(model, providers=providers, sess_options=sess_options)

@stale stale bot removed the stale label Jun 2, 2023
@jywu-msft (Member)

> I encountered this problem when running a UNet model on ORT-GPU 1.15. It works fine when using CUDAExecutionProvider. TensorrtExecutionProvider::GetSupportedList graph_build.Resolve().IsOK() was false. [...]

what is the error encountered?
btw, can you please file a new issue (with repro assets/instructions) so we can have someone look at it?
issue #7757 is from 2 years ago and probably unrelated to what you are seeing. thanks.

@Sakura-Luna commented Jun 3, 2023

TensorrtExecutionProvider::GetSupportedList graph_build.Resolve().IsOK() was false.

The error is

onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: 
    [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: 
    tensorrt_execution_provider.cc:1352 
    onnxruntime::TensorrtExecutionProvider::GetSupportedList graph_build.Resolve().IsOK() was false.

Although the environment is different, it looks like the error is the same. The operating environment is as follows:

Win10 x64, ONNX Runtime 1.15.0 Released Package, Python, TensorrtExecutionProvider, CUDA 11.8+Tensorrt 8.6

@Sakura-Luna

I tried symbolic_shape_infer.py --auto_merge --guess_output_rank and got the same error.

2023-06-04 01:08:26.8714279 [E:onnxruntime:Default, tensorrt_execution_provider.h:73 onnxruntime::TensorrtLogger::log] [2023-06-03 17:08:26 ERROR] 3: getPluginCreator could not find plugin: GroupNorm version: 1

2023-06-04 01:08:39.4348212 [E:onnxruntime:, inference_session.cc:1645 onnxruntime::InferenceSession::Initialize::<lambda_eb486adf513608dcd45c034ea7ffb8e8>::operator ()] Exception during initialization:

onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: tensorrt_execution_provider.cc:1352 onnxruntime::TensorrtExecutionProvider::GetSupportedList graph_build.Resolve().IsOK() was false.

@Sakura-Luna

Can the ORT TensorrtExecutionProvider build the UNet model at all? I haven't found any related documentation. If yes, is there a general process to follow?
