diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index cb0a0ed432a64..3e1b03871e18d 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -14,6 +14,7 @@
 #include "core/providers/cuda/shared_inc/cuda_call.h"
 #include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
 #include "core/providers/cuda/gpu_data_transfer.h"
+#include "core/session/allocator_adapters.h"
 #include "cuda_runtime_api.h"
 #include "core/common/gsl.h"
 #include <unordered_map>
@@ -991,6 +992,12 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() {
     ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_)));
   }
   ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list);
+
+  if (alloc_ != nullptr) {
+    // This code is the same as OrtApis::ReleaseAllocator defined in allocator_adapters.cc.
+    // We can't get the api inside the destructor, so we duplicate the code here.
+    delete static_cast<OrtAllocatorImpl*>(alloc_);
+  }
 }
 
 bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const {
@@ -2213,15 +2220,18 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
       auto trt_context = trt_state->context->get();
       auto trt_profiles = trt_state->profiles;
       auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
-      OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0));
-      OrtAllocator* alloc;
-      Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc));
       int num_inputs = static_cast<int>(input_indexes.size());
       int num_outputs = static_cast<int>(output_indexes.size());
       bool engine_update = false;
       std::unordered_set<std::string> input_names;
       std::unordered_map<std::string, std::vector<int32_t>> tensor_shape_values;
 
+      OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0));
+      if (alloc_ == nullptr) {
+        Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_));
+      }
+      OrtAllocator* alloc = alloc_;
+
       void* cuda_stream;
       Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream));
       cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream);
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index 35e7a110c5ed7..56eda7ad83537 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -214,6 +214,10 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   bool detailed_build_log_ = false;
   bool cuda_graph_enable_ = false;
 
+  // The OrtAllocator object is obtained during EP compute time
+  // and should be kept for the lifetime of the TRT EP object.
+  OrtAllocator* alloc_ = nullptr;
+
   std::unique_ptr<CUDAGraph> cuda_graph_;  // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph pointer is enough (no need to maintain one CUDAGraph pointer per TRT subgraph)
   bool is_graph_captured_ = false;
   int regular_run_count_before_graph_capture_ = 0;