Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExecutionProvider API refactor - move allocator from EP level to SessionState level and indexed by OrtDevice #15833

Merged
merged 12 commits into from
Jun 20, 2023
2 changes: 2 additions & 0 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "core/session/onnxruntime_c_api.h"
#include "ortdevice.h"
#include "ortmemoryinfo.h"
#include <map>

// This configures the arena based allocator used by ORT
// See docs/C_API.md for details on what these mean and how to choose these values
Expand Down Expand Up @@ -210,6 +211,7 @@ class CPUAllocator : public IAllocator {
};

using AllocatorPtr = std::shared_ptr<IAllocator>;
using AllocatorMap = std::map<OrtDevice, AllocatorPtr>;

void* AllocatorDefaultAlloc(size_t size);
void AllocatorDefaultFree(void* p);
Expand Down
2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ class IExecutionProvider {
return DataLayout::NCHW;
}

virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/, std::map<OrtDevice, AllocatorPtr>&) const {}
virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/, AllocatorMap&) const {}

/** Does the EP support concurrent calls to InferenceSession::Run to execute the model.
*/
Expand Down
6 changes: 3 additions & 3 deletions include/onnxruntime/core/framework/op_kernel_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
const OrtValueNameIdxMap& mlvalue_name_idx_map,
const DataTransferManager& data_transfer_mgr,
const std::map<OrtDevice, AllocatorPtr>& allocators = {});
const AllocatorMap& allocators = {});

OpKernelInfo(const OpKernelInfo& other);

Expand All @@ -48,7 +48,7 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {

bool TryGetConstantInput(int input_index, const OrtValue** constant_input_value) const;

const std::map<OrtDevice, AllocatorPtr>& GetAllocators() const { return allocators_; }
const AllocatorMap& GetAllocators() const { return allocators_; }

private:
ORT_DISALLOW_MOVE(OpKernelInfo);
Expand All @@ -63,7 +63,7 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
const OrtValueNameIdxMap& ort_value_name_idx_map_;
const DataTransferManager& data_transfer_mgr_;
ProtoHelperNodeContext proto_helper_context_;
const std::map<OrtDevice, AllocatorPtr>& allocators_;
const AllocatorMap& allocators_;
};

} // namespace onnxruntime
6 changes: 3 additions & 3 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4300,9 +4300,9 @@ struct OrtApi {
/// @}

/** \brief Create and register Cuda Allocator
* \param[in] env OrtEnv instance
* \param[in] mem_info OrtMemoryInfo instance
*/
* \param[in] env OrtEnv instance
* \param[in] mem_info OrtMemoryInfo instance
*/
ORT_API2_STATUS(CreateAndRegisterCudaAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
jslhcl marked this conversation as resolved.
Show resolved Hide resolved
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);
};
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/framework/device_stream_collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct DummyStream : Stream {

class DeviceStreamCollectionImpl {
public:
DeviceStreamCollectionImpl(size_t num_streams, const std::map<OrtDevice, AllocatorPtr>& allocators, bool is_main_graph) : num_streams_(num_streams), allocators_(allocators), is_main_graph_(is_main_graph) {
DeviceStreamCollectionImpl(size_t num_streams, const AllocatorMap& allocators, bool is_main_graph) : num_streams_(num_streams), allocators_(allocators), is_main_graph_(is_main_graph) {
device_streams_.resize(num_streams, nullptr);
owned_streams_.reserve(num_streams);
root_stream_ = std::make_unique<DummyStream>(nullptr, root_stream_device_);
Expand Down Expand Up @@ -91,7 +91,7 @@ class DeviceStreamCollectionImpl {
std::vector<Stream*> device_streams_;
InlinedVector<std::unique_ptr<Stream>> owned_streams_;
// TODO(leca): review
const std::map<OrtDevice, AllocatorPtr>& allocators_;
const AllocatorMap& allocators_;
bool is_main_graph_ = false;
// This is used in ExecutionFrame when memory pattern is enabled, to allocate the peak size memory
// labelled this stream in the current thread, instead of the default stream which will be used in all the threads (thus caused thread safe issue)
Expand All @@ -100,8 +100,8 @@ class DeviceStreamCollectionImpl {
void ReleaseSingleStreamBuffers();
};

DeviceStreamCollection::DeviceStreamCollection(size_t num_streams, const std::map<OrtDevice, AllocatorPtr>& allocators, bool is_main_graph)
: impl_(std::make_unique<DeviceStreamCollectionImpl>(num_streams, allocators, is_main_graph)) {}
// Pimpl forwarding constructor: all state lives in DeviceStreamCollectionImpl.
// NOTE(review): `allocators` is held by reference inside the impl (see
// DeviceStreamCollectionImpl::allocators_), so the map must outlive this object.
DeviceStreamCollection::DeviceStreamCollection(size_t num_streams, const AllocatorMap& allocators, bool is_main_graph)
: impl_(std::make_unique<DeviceStreamCollectionImpl>(num_streams, allocators, is_main_graph)) {}

DeviceStreamCollection::~DeviceStreamCollection() {}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/device_stream_collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class DeviceStreamCollectionImpl;
// this collection may be cached and reused for future iterations.
class DeviceStreamCollection {
public:
DeviceStreamCollection(size_t num_streams, const std::map<OrtDevice, AllocatorPtr>& allocators, bool is_main_graph);
DeviceStreamCollection(size_t num_streams, const AllocatorMap& allocators, bool is_main_graph);
~DeviceStreamCollection();
// Add the device stream instance to given index.
// and set the current collection as the owner of the device stream.
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/framework/op_kernel_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const DataTransferManager& data_transfer_mgr,
const std::map<OrtDevice, AllocatorPtr>& allocators)
const AllocatorMap& allocators)
: OpNodeProtoHelper(&proto_helper_context_),
node_(node),
kernel_def_(kernel_def),
execution_provider_(&execution_provider),
constant_initialized_tensors_(constant_initialized_tensors),
ort_value_name_idx_map_(ort_value_name_idx_map),
data_transfer_mgr_(data_transfer_mgr),
proto_helper_context_(node), allocators_(allocators){}
proto_helper_context_(node),
allocators_(allocators) {}

OpKernelInfo::OpKernelInfo(const OpKernelInfo& other)
: OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.constant_initialized_tensors_,
Expand Down
11 changes: 9 additions & 2 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ SessionState::SessionState(Graph& graph,
profiling::Profiler& profiler,
const SessionOptions& sess_options,
PrepackedWeightsContainer* prepacked_weights_container,
std::shared_ptr<std::map<OrtDevice, AllocatorPtr>> parent_allocators)
AllocatorMap* parent_allocators)
: graph_(graph),
execution_providers_(execution_providers),
logger_(logger),
Expand All @@ -90,7 +90,8 @@ SessionState::SessionState(Graph& graph,
if (parent_allocators) {
allocators_ = parent_allocators;
} else {
allocators_ = std::make_shared<std::map<OrtDevice, AllocatorPtr>>();
allocators_unique_ptr_ = std::make_unique<AllocatorMap>();
allocators_ = allocators_unique_ptr_.get();
// The allocator registration rule:
// Each location (OrtDevice) will only have 1 allocator used for whole session.
// The EP which is registered first will have higher priority
Expand All @@ -113,6 +114,12 @@ AllocatorPtr SessionState::GetAllocator(const OrtDevice& device) const noexcept
return nullptr;
}

// Overrides this session's per-device allocators with the allocators registered
// on the environment (the shared-allocator scenario). An existing entry for a
// device is replaced; a device with no entry yet is inserted.
void SessionState::UpdateAllocatorsWithEnvAllocators(const std::vector<AllocatorPtr>& env_allocators) {
  for (const auto& shared_allocator : env_allocators) {
    allocators_->insert_or_assign(shared_allocator->Info().device, shared_allocator);
  }
}

void SessionState::CreateGraphInfo() {
graph_viewer_.emplace(graph_);
// use graph_viewer_ to initialize ort_value_name_idx_map_
Expand Down
13 changes: 8 additions & 5 deletions onnxruntime/core/framework/session_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class SessionState {
profiling::Profiler& profiler,
const SessionOptions& sess_options,
PrepackedWeightsContainer* prepacked_weights_container = nullptr,
std::shared_ptr<std::map<OrtDevice, AllocatorPtr>> parent_allocators = nullptr);
AllocatorMap* parent_allocators = nullptr);

~SessionState() {
for (auto& kvp : deleter_for_initialized_tensors_) {
Expand Down Expand Up @@ -133,9 +133,11 @@ class SessionState {
AllocatorPtr GetAllocator(const OrtDevice& device) const noexcept;

/*
* Get allocators. CANNOT be const member function as allocators_ will be changed after SessionState's initialization for shared allocator scenario (InferenceSession::UpdateSessionStateAllocatorsWithSharedAllocators())
*/
std::map<OrtDevice, AllocatorPtr>& GetAllocators() { return *allocators_; }
* Get allocators.
*/
const AllocatorMap& GetAllocators() { return *allocators_; }
jslhcl marked this conversation as resolved.
Show resolved Hide resolved

void UpdateAllocatorsWithEnvAllocators(const std::vector<AllocatorPtr>&);

const OrtValueNameIdxMap& GetOrtValueNameIdxMap() const noexcept { return ort_value_name_idx_map_; }

Expand Down Expand Up @@ -446,7 +448,8 @@ class SessionState {
// and as this isn't considered performance critical currently it's not worth the maintenance overhead of adding one.
// We do get an allocator from ExecutionFrame so this is looked up frequently, however there most likely aren't many
// entries in the map
std::shared_ptr<std::map<OrtDevice, AllocatorPtr>> allocators_;
std::unique_ptr<AllocatorMap> allocators_unique_ptr_;
AllocatorMap* allocators_;
pranavsharma marked this conversation as resolved.
Show resolved Hide resolved

OrtValueNameIdxMap ort_value_name_idx_map_;

Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/acl/acl_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ std::shared_ptr<KernelRegistry> GetAclKernelRegistry() {
} // namespace acl

ACLExecutionProvider::ACLExecutionProvider(const ACLExecutionProviderInfo&)
: IExecutionProvider{onnxruntime::kAclExecutionProvider} {}

: IExecutionProvider{onnxruntime::kAclExecutionProvider} {}

ACLExecutionProvider::~ACLExecutionProvider() {}

Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cann/cann_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1457,10 +1457,10 @@ std::vector<AllocatorPtr> CANNExecutionProvider::CreatePreferredAllocators() {
},
pinned_device.Id());

return std::vector<AllocatorPtr> { CreateAllocator(default_memory_info), CreateAllocator(pinned_memory_info) };
return std::vector<AllocatorPtr>{CreateAllocator(default_memory_info), CreateAllocator(pinned_memory_info)};
}

void CANNExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, std::map<OrtDevice, AllocatorPtr>&) const {
// Hooks the CANN (NPU) stream handles into the registry.
// The AllocatorMap parameter is intentionally unused: RegisterCannStreamHandles
// takes no allocator, unlike the CUDA/ROCm equivalents, which look up the
// OrtMemTypeCPU allocator from the map before registering.
void CANNExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap&) const {
RegisterCannStreamHandles(stream_handle_registry, OrtDevice::NPU);
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cann/cann_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class CANNExecutionProvider : public IExecutionProvider {
return CANNExecutionProviderInfo::ToProviderOptions(info_);
}

void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, std::map<OrtDevice, AllocatorPtr>&) const override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap&) const override;

OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;

Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2537,7 +2537,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
return result;
}

void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, std::map<OrtDevice, AllocatorPtr>& allocators) const {
void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const {
// This allocator must be the same to the allocator
// used in AllocateBufferOnCPUPinned.
auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)];
Expand Down Expand Up @@ -2570,10 +2570,10 @@ std::vector<AllocatorPtr> CUDAExecutionProvider::CreatePreferredAllocators() {
// correct to use the GPU device id, unless we wanted to share the pinned memory allocator across devices,
// at the risk the lifetime isn't managed correctly if one of those devices go away.
0);
return std::vector<AllocatorPtr> {
CreateCudaAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy,
info_.external_allocator_info, info_.default_memory_arena_cfg),
CreateAllocator(pinned_memory_info),
return std::vector<AllocatorPtr>{
CreateCudaAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy,
info_.external_allocator_info, info_.default_memory_arena_cfg),
CreateAllocator(pinned_memory_info),
};
}

Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;
#endif
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, std::map<OrtDevice, AllocatorPtr>& allocators) const override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;

Expand Down
7 changes: 4 additions & 3 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,10 @@ JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info)

std::vector<AllocatorPtr> JsExecutionProvider::CreatePreferredAllocators() {
AllocatorCreationInfo customAllocatorCreationInfo([&](int) {
return std::make_unique<js::JsCustomAllocator>();
}, 0, false); // TODO(leca): REVIEW: need JsCPUAllocator?
return std::vector<AllocatorPtr> {CreateAllocator(customAllocatorCreationInfo)};
return std::make_unique<js::JsCustomAllocator>();
},
0, false); // TODO(leca): REVIEW: need JsCPUAllocator?
return std::vector<AllocatorPtr>{CreateAllocator(customAllocatorCreationInfo)};
}

std::vector<std::unique_ptr<ComputeCapability>> JsExecutionProvider::GetCapability(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,13 @@ MIGraphXExecutionProvider::~MIGraphXExecutionProvider() {

std::vector<AllocatorPtr> MIGraphXExecutionProvider::CreatePreferredAllocators() {
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId device_id) { return CreateROCMAllocator(device_id, onnxruntime::CUDA); }, device_id_);
[](OrtDevice::DeviceId device_id) { return CreateROCMAllocator(device_id, onnxruntime::CUDA); }, device_id_);
AllocatorCreationInfo pinned_allocator_info(
[](OrtDevice::DeviceId device_id) {
return CreateROCMPinnedAllocator(device_id, onnxruntime::CUDA_PINNED);
}, 0);
return std::vector<AllocatorPtr> {CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)};
[](OrtDevice::DeviceId device_id) {
return CreateROCMPinnedAllocator(device_id, onnxruntime::CUDA_PINNED);
},
0);
return std::vector<AllocatorPtr>{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)};
}

std::unique_ptr<onnxruntime::IDataTransfer> MIGraphXExecutionProvider::GetDataTransfer() const {
Expand Down Expand Up @@ -1135,7 +1136,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
return Status::OK();
}

void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, std::map<OrtDevice, AllocatorPtr>& allocators) const {
// Registers the ROCm stream handles for GPU devices on behalf of the MIGraphX EP,
// passing along this EP's externally supplied stream and MIOpen/rocBLAS handles.
void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const {
// Allocator for OrtMemTypeCPU, looked up from the session-level allocator map
// (allocators are now owned by SessionState rather than the EP).
// NOTE(review): std::map::operator[] default-inserts a null AllocatorPtr if no
// allocator is registered for that device — confirm a CPU allocator always
// exists by the time this runs.
auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)];
RegisterRocmStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, false /*TODO:external_stream_*/, external_miopen_handle_, external_rocblas_handle_);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;

std::unique_ptr<IndexedSubGraph> GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph) const;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, std::map<OrtDevice, AllocatorPtr>& allocators) const override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ struct EinsumRocmAssets {
Stream* ort_stream,
AllocatorPtr gpu_allocator) : rocblas_handle_(rocblas_handle),
rocm_ep_(rocm_ep),
ort_stream_(ort_stream), gpu_allocator_(gpu_allocator) {}
ort_stream_(ort_stream),
gpu_allocator_(gpu_allocator) {}

hipStream_t GetRocmStream() {
return ort_stream_ ? static_cast<hipStream_t>(ort_stream_->GetHandle()) : nullptr;
Expand Down
11 changes: 5 additions & 6 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2327,7 +2327,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
return result;
}

void ROCMExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, std::map<OrtDevice, AllocatorPtr>& allocators) const {
void ROCMExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const {
// This allocator must be the same to the allocator
// used in AllocateBufferOnCPUPinned.
auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)];
Expand Down Expand Up @@ -2358,11 +2358,10 @@ std::vector<AllocatorPtr> ROCMExecutionProvider::CreatePreferredAllocators() {
// correct to use the GPU device id, unless we wanted to share the pinned memory allocator across devices,
// at the risk the lifetime isn't managed correctly if one of those devices go away.
0);
return std::vector<AllocatorPtr> {
CreateRocmAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy,
info_.external_allocator_info, info_.default_memory_arena_cfg),
CreateAllocator(pinned_memory_info)
};
return std::vector<AllocatorPtr>{
CreateRocmAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy,
info_.external_allocator_info, info_.default_memory_arena_cfg),
CreateAllocator(pinned_memory_info)};
}

} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/rocm/rocm_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class ROCMExecutionProvider : public IExecutionProvider {

std::unique_ptr<profiling::EpProfiler> GetProfiler() override;

void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, std::map<OrtDevice, AllocatorPtr>& allocators) const override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;

Expand Down
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.