Skip to content

Commit

Permalink
Fix DML regression from allocator refactor and enable unrounded weight allocation in ORT API (#17030)
Browse files Browse the repository at this point in the history

This addresses a DML performance regression introduced by the following
PR, which resulted in allocations not being rounded and pooled in the DML
execution provider.

#15833

This also fixes a pre-existing limitation: allocations during session
initialization (primarily large weights and persistent resources)
bypassed rounding and pooling only when using the WinML API.
The allocator now also respects a caller's rounding mode parameter when
provided.
  • Loading branch information
jeffbloo authored Aug 11, 2023
1 parent 9cd4e5a commit 0180c04
Show file tree
Hide file tree
Showing 13 changed files with 19 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ namespace Dml

ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr);
void FlushContext(onnxruntime::IExecutionProvider* provider);
void SetDefaultRoundingMode(onnxruntime::IExecutionProvider* provider, AllocatorRoundingMode roundingMode);
void ReleaseCompletedReferences(onnxruntime::IExecutionProvider* provider);

onnxruntime::common::Status CopyTensor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ namespace Dml
}

// Single-argument allocation overload (declared `final` under the onnxruntime::IAllocator
// section of the class). Forwards to the two-argument overload, passing the allocator's
// current default rounding mode (m_defaultRoundingMode).
void* BucketizedBufferAllocator::Alloc(size_t size)
{
return Alloc(size, m_defaultRoundingMode);
}

void* BucketizedBufferAllocator::Alloc(size_t size, AllocatorRoundingMode roundingMode)
{
// For some reason lotus likes requesting 0 bytes of memory
size = std::max<size_t>(1, size);
Expand All @@ -90,7 +95,7 @@ namespace Dml
uint64_t bucketSize = 0;

// Use a pooled resource if the size (post rounding, if requested) matches a bucket size
if (m_defaultRoundingMode == AllocatorRoundingMode::Enabled || size == GetBucketSizeFromIndex(GetBucketIndexFromSize(size)))
if (roundingMode == AllocatorRoundingMode::Enabled || size == GetBucketSizeFromIndex(GetBucketIndexFromSize(size)))
{
Bucket* bucket = nullptr;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ namespace Dml
void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);

public: // onnxruntime::IAllocator
void* Alloc(size_t size, AllocatorRoundingMode roundingMode);
void* Alloc(size_t size) final;
void Free(void* p) final;

Expand Down Expand Up @@ -82,7 +83,12 @@ namespace Dml
std::vector<Bucket> m_pool;
size_t m_currentAllocationId = 0;
uint64_t m_currentResourceId = 0;
AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Enabled;

// Unless specifically requested, allocation sizes are not rounded to enable pooling
// until SetDefaultRoundingMode is called. This should be done at completion of session
// initialization.
AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Disabled;

std::shared_ptr<ExecutionContext> m_context;
std::unique_ptr<DmlSubAllocator> m_subAllocator;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ namespace Dml
ORT_TRY
{
ComPtr<IUnknown> allocation;
allocation.Attach(static_cast<IUnknown* >(m_allocator->Alloc(size)));
allocation.Attach(static_cast<IUnknown* >(m_allocator->Alloc(size, roundingMode)));

const auto* allocInfo = m_allocator->DecodeDataHandle(allocation.Get());

Expand Down Expand Up @@ -204,7 +204,6 @@ namespace Dml
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
std::make_unique<DmlCommittedResourceAllocator>(m_d3d12Device.Get()));
m_allocator->SetDefaultRoundingMode(m_defaultRoundingMode); // TODO(leca): REVIEW: the original code is able to set roudingMode multiple times during alloc's life time. Double check this case is not happening
m_context->SetAllocator(m_allocator);
// CPU Allocator used to create buffers for the MemcpyFromHost, Shape and Size operators.
m_cpuInputAllocator = std::make_shared<CPUAllocator>(OrtMemType::OrtMemTypeCPUInput);
Expand Down Expand Up @@ -963,11 +962,6 @@ namespace Dml
m_context->Flush();
}

void ExecutionProviderImpl::SetDefaultRoundingMode(AllocatorRoundingMode roundingMode)
{
m_defaultRoundingMode = roundingMode;
}

void ExecutionProviderImpl::ReleaseCompletedReferences()
{
m_context->ReleaseCompletedReferences();
Expand Down Expand Up @@ -1125,6 +1119,10 @@ namespace Dml
m_context->ReleaseCompletedReferences();
m_uploadHeap->Trim();

// Allocations after this point are potentially transient and their sizes are
// rounded to enable pooling.
m_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled);

return onnxruntime::common::Status::OK();
}

Expand All @@ -1148,12 +1146,6 @@ namespace Dml
dmlexecutionprovider->Flush();
}

void SetDefaultRoundingMode(onnxruntime::IExecutionProvider* provider, AllocatorRoundingMode roundingMode)
{
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
dmlexecutionprovider->SetDefaultRoundingMode(roundingMode);
}

void ReleaseCompletedReferences(onnxruntime::IExecutionProvider * provider)
{
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ namespace Dml
STDMETHOD_(D3D12_COMMAND_LIST_TYPE, GetCommandListTypeForQueue)() const override;
STDMETHOD_(void, Flush)() const override;

void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);

// Waits for flushed work, discards unflushed work, and discards associated references to
// prevent circular references. Must be the last call on the object before destruction.
void Close() override;
Expand Down Expand Up @@ -198,7 +196,6 @@ namespace Dml
bool m_closed = false;
mutable std::chrono::time_point<std::chrono::steady_clock> m_lastUploadFlushTime;
static constexpr std::chrono::milliseconds m_batchFlushInterval = std::chrono::milliseconds(10);
AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Enabled;
};

class DataTransfer : public onnxruntime::IDataTransfer
Expand Down Expand Up @@ -287,11 +284,6 @@ namespace Dml
return m_impl->Flush();
}

void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode)
{
return m_impl->SetDefaultRoundingMode(roundingMode);
}

void ReleaseCompletedReferences()
{
return m_impl->ReleaseCompletedReferences();
Expand Down
12 changes: 0 additions & 12 deletions onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,20 @@ struct DMLProviderFactory : IExecutionProviderFactory {
~DMLProviderFactory() override {}

std::unique_ptr<IExecutionProvider> CreateProvider() override;
void SetDefaultRoundingMode(AllocatorRoundingMode rounding_mode);

void SetMetacommandsEnabled(bool metacommands_enabled);

private:
ComPtr<IDMLDevice> dml_device_{};
ComPtr<ID3D12CommandQueue> cmd_queue_{};
AllocatorRoundingMode rounding_mode_ = AllocatorRoundingMode::Enabled;
bool metacommands_enabled_ = true;
};

std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {
auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_);
Dml::SetDefaultRoundingMode(provider.get(), rounding_mode_);
return provider;
}

void DMLProviderFactory::SetDefaultRoundingMode(AllocatorRoundingMode rounding_mode) {
rounding_mode_ = rounding_mode;
}

// Stores the metacommands flag on the factory. CreateProvider() passes the stored value
// to Dml::CreateExecutionProvider when it builds a provider.
void DMLProviderFactory::SetMetacommandsEnabled(bool metacommands_enabled) {
metacommands_enabled_ = metacommands_enabled;
}
Expand Down Expand Up @@ -80,11 +73,6 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(ID
return std::make_shared<onnxruntime::DMLProviderFactory>(dml_device, cmd_queue);
}

void DmlConfigureProviderFactoryDefaultRoundingMode(IExecutionProviderFactory* factory, AllocatorRoundingMode rounding_mode) {
auto dml_provider_factory = static_cast<DMLProviderFactory*>(factory);
dml_provider_factory->SetDefaultRoundingMode(rounding_mode);
}

void DmlConfigureProviderFactoryMetacommandsEnabled(IExecutionProviderFactory* factory, bool metacommandsEnabled) {
auto dml_provider_factory = static_cast<DMLProviderFactory*>(factory);
dml_provider_factory->SetMetacommandsEnabled(metacommandsEnabled);
Expand Down
3 changes: 0 additions & 3 deletions winml/adapter/winml_adapter_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,6 @@ ORT_API_STATUS(
);

// Dml methods (TODO need to figure out how these need to move to session somehow...)
ORT_API_STATUS(
DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled
);
ORT_API_STATUS(DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider);
ORT_API_STATUS(DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider);

Expand Down
1 change: 0 additions & 1 deletion winml/adapter/winml_adapter_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {
&winmla::SessionGetNamedDimensionsOverrides,

// Dml methods (TODO need to figure out how these need to move to session somehow...)
&winmla::DmlExecutionProviderSetDefaultRoundingMode,
&winmla::DmlExecutionProviderFlushContext,
&winmla::DmlExecutionProviderReleaseCompletedReferences,
&winmla::DmlCopyTensor,
Expand Down
10 changes: 0 additions & 10 deletions winml/adapter/winml_adapter_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,16 +412,6 @@ struct WinmlAdapterApi {
_Out_ winrt::Windows::Foundation::Collections::IMapView<winrt::hstring, uint32_t>& overrides
)NO_EXCEPTION;

/**
* DmlExecutionProviderSetDefaultRoundingMode
* This api is used to configure the DML EP to turn on/off rounding.
*
* WinML uses this to disable rounding during session initialization and then enables it again post initialization.
*/
OrtStatus*(ORT_API_CALL* DmlExecutionProviderSetDefaultRoundingMode)(
_In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled
)NO_EXCEPTION;

/**
* DmlExecutionProviderFlushContext
* This api is used to flush the DML EP.
Expand Down
25 changes: 1 addition & 24 deletions winml/adapter/winml_adapter_dml.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Licensed under the MIT License.

#pragma once
#include "adapter/pch.h"
Expand Down Expand Up @@ -68,9 +68,6 @@ Microsoft::WRL::ComPtr<IDMLDevice> CreateDmlDevice(ID3D12Device* d3d12Device) {
}

namespace onnxruntime {
void DmlConfigureProviderFactoryDefaultRoundingMode(
onnxruntime::IExecutionProviderFactory* factory, AllocatorRoundingMode rounding_mode
);
void DmlConfigureProviderFactoryMetacommandsEnabled(IExecutionProviderFactory* factory, bool metacommandsEnabled);
}// namespace onnxruntime

Expand All @@ -91,32 +88,12 @@ ORT_API_STATUS_IMPL(
}
auto factory = options->provider_factories.back().get();

// OnnxRuntime uses the default rounding mode when calling the session's allocator.
// During initialization, OnnxRuntime allocates weights, which are permanent across session
// lifetime and can be large, so shouldn't be rounded.
// So we create the provider with rounding disabled, and expect the caller to enable it after.
onnxruntime::DmlConfigureProviderFactoryDefaultRoundingMode(factory, AllocatorRoundingMode::Disabled);

onnxruntime::DmlConfigureProviderFactoryMetacommandsEnabled(factory, metacommands_enabled);
#endif // USE_DML
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(
winmla::DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled
) {
API_IMPL_BEGIN
#ifdef USE_DML
auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider);
Dml::SetDefaultRoundingMode(
dml_provider_internal, is_enabled ? AllocatorRoundingMode::Enabled : AllocatorRoundingMode::Disabled
);
#endif
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider) {
API_IMPL_BEGIN
#ifdef USE_DML
Expand Down
4 changes: 0 additions & 4 deletions winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ HRESULT OnnxruntimeDmlSessionBuilder::Initialize(OrtSession* session) {
winml_adapter_api->SessionGetExecutionProvider(session, 0, &ort_provider), engine_factory_->UseOrtApi()
);

RETURN_HR_IF_NOT_OK_MSG(
winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, true), engine_factory_->UseOrtApi()
);

// Flush the D3D12 work from the DML execution provider
RETURN_HR_IF_NOT_OK_MSG(
winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider), engine_factory_->UseOrtApi()
Expand Down
9 changes: 0 additions & 9 deletions winml/test/adapter/AdapterDmlEpTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,6 @@ UniqueOrtSession CreateCpuSession() {
return CreateUniqueOrtSession(FileHelpers::GetModulePath() + L"fns-candy.onnx", session_options);
}

void DmlExecutionProviderSetDefaultRoundingMode() {
GPUTEST;
auto session = CreateDmlSession();
OrtExecutionProvider* ort_provider;
THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api);
THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderSetDefaultRoundingMode(ort_provider, false), ort_api);
}

void DmlExecutionProviderFlushContext() {
GPUTEST;
auto session = CreateDmlSession();
Expand Down Expand Up @@ -364,7 +356,6 @@ const AdapterDmlEpTestApi& getapi() {
static constexpr AdapterDmlEpTestApi api = {
AdapterDmlEpTestSetup,
AdapterDmlEpTestTeardown,
DmlExecutionProviderSetDefaultRoundingMode,
DmlExecutionProviderFlushContext,
DmlExecutionProviderReleaseCompletedReferences,
DmlCreateAndFreeGPUAllocationFromD3DResource,
Expand Down
2 changes: 0 additions & 2 deletions winml/test/adapter/AdapterDmlEpTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
struct AdapterDmlEpTestApi {
SetupTest AdapterDmlEpTestSetup;
TeardownClass AdapterDmlEpTestTeardown;
VoidTest DmlExecutionProviderSetDefaultRoundingMode;
VoidTest DmlExecutionProviderFlushContext;
VoidTest DmlExecutionProviderReleaseCompletedReferences;
VoidTest DmlCreateGPUAllocationFromD3DResource;
Expand All @@ -23,7 +22,6 @@ WINML_TEST_CLASS_BEGIN(AdapterDmlEpTest)
WINML_TEST_CLASS_SETUP_METHOD(AdapterDmlEpTestSetup)
WINML_TEST_CLASS_TEARDOWN_METHOD(AdapterDmlEpTestTeardown)
WINML_TEST_CLASS_BEGIN_TESTS
WINML_TEST(AdapterDmlEpTest, DmlExecutionProviderSetDefaultRoundingMode)
WINML_TEST(AdapterDmlEpTest, DmlExecutionProviderFlushContext)
WINML_TEST(AdapterDmlEpTest, DmlExecutionProviderReleaseCompletedReferences)
WINML_TEST(AdapterDmlEpTest, DmlCreateGPUAllocationFromD3DResource)
Expand Down

0 comments on commit 0180c04

Please sign in to comment.