From 98e021976231699c17211110f1b36b56cb22d622 Mon Sep 17 00:00:00 2001 From: Vishal Agarwal Date: Fri, 9 Feb 2024 15:42:27 +0530 Subject: [PATCH 1/2] add convolution sample --- Samples/DirectMLConv/DirectMLConv.sln | 31 + .../DirectMLConv/DirectMLConv.vcxproj | 160 + .../DirectMLConv/DirectMLConv.vcxproj.filters | 40 + Samples/DirectMLConv/DirectMLConv/DirectMLX.h | 4367 +++++++++++++++++ Samples/DirectMLConv/DirectMLConv/d3dx12.h | 3439 +++++++++++++ Samples/DirectMLConv/DirectMLConv/main.cpp | 596 +++ .../DirectMLConv/DirectMLConv/packages.config | 5 + Samples/DirectMLConv/DirectMLConv/pch.cpp | 4 + Samples/DirectMLConv/DirectMLConv/pch.h | 34 + Samples/DirectMLConv/README.md | 14 + 10 files changed, 8690 insertions(+) create mode 100644 Samples/DirectMLConv/DirectMLConv.sln create mode 100644 Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj create mode 100644 Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj.filters create mode 100644 Samples/DirectMLConv/DirectMLConv/DirectMLX.h create mode 100644 Samples/DirectMLConv/DirectMLConv/d3dx12.h create mode 100644 Samples/DirectMLConv/DirectMLConv/main.cpp create mode 100644 Samples/DirectMLConv/DirectMLConv/packages.config create mode 100644 Samples/DirectMLConv/DirectMLConv/pch.cpp create mode 100644 Samples/DirectMLConv/DirectMLConv/pch.h create mode 100644 Samples/DirectMLConv/README.md diff --git a/Samples/DirectMLConv/DirectMLConv.sln b/Samples/DirectMLConv/DirectMLConv.sln new file mode 100644 index 00000000..99e81ad2 --- /dev/null +++ b/Samples/DirectMLConv/DirectMLConv.sln @@ -0,0 +1,31 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.8.34525.116 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "DirectMLConv", "DirectMLConv\DirectMLConv.vcxproj", "{A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|x64 = Debug|x64 + Debug|x86 = Debug|x86 + Release|x64 = Release|x64 + Release|x86 = Release|x86 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Debug|x64.ActiveCfg = Debug|x64 + {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Debug|x64.Build.0 = Debug|x64 + {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Debug|x86.ActiveCfg = Debug|Win32 + {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Debug|x86.Build.0 = Debug|Win32 + {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Release|x64.ActiveCfg = Release|x64 + {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Release|x64.Build.0 = Release|x64 + {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Release|x86.ActiveCfg = Release|Win32 + {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Release|x86.Build.0 = Release|Win32 + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {72CE4CBB-7C71-40C5-956F-D5A9DCCE48A1} + EndGlobalSection +EndGlobal diff --git a/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj b/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj new file mode 100644 index 00000000..e5f1791c --- /dev/null +++ b/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj @@ -0,0 +1,160 @@ + + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 17.0 + Win32Proj + {a3aa531b-daf1-4d6c-8389-ea88fa2082e6} + DirectMLConv + 10.0 + + + + Application + true + v143 + Unicode + + + Application + false + v143 + true + Unicode + + + Application + true + v143 + Unicode + + + Application + false + v143 + true + Unicode + + + + + + + + + + + + + + + + + + + + + + Level3 + true + WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + + + Console + true + + + + + Level3 + true + true + true + WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + + + Console + true + true + true + + + + + Level3 + true + _DEBUG;_CONSOLE;%(PreprocessorDefinitions) + false + stdcpp17 + Create + pch.h + + + Console + true + WindowsApp.lib;dxgi.lib;d3d12.lib;$(CoreLibraryDependencies);%(AdditionalDependencies) + + + + + Level3 + true + true + true + NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + + + Console + true + true + true + + + + + + + + + + + + + + + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + \ No newline at end of file diff --git a/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj.filters b/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj.filters new file mode 100644 index 00000000..dd8b06c4 --- /dev/null +++ b/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj.filters @@ -0,0 +1,40 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Header Files + + + Header Files + + + Header Files + + + + + Source Files + + + Source Files + + + + + + + \ No newline at end of file diff --git a/Samples/DirectMLConv/DirectMLConv/DirectMLX.h b/Samples/DirectMLConv/DirectMLConv/DirectMLX.h new file mode 100644 index 00000000..5a1308c3 --- /dev/null +++ b/Samples/DirectMLConv/DirectMLConv/DirectMLX.h @@ -0,0 +1,4367 @@ +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License (MIT). +// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF +// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY +// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR +// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. +// +//********************************************************* +// clang-format off + +#pragma once +#include "DirectML.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include // For Microsoft::WRL::ComPtr + +#if DMLX_USE_ABSEIL + #if __cpp_lib_span + #include + #endif +#elif __cplusplus >= 201703L && __has_include() + // stl optional is only available in cpp17 and above. + #include +#elif __has_include("dml_optional_extensions.h") + #include "dml_optional_extensions.h" + #define DMLX_OPTIONAL_EXTENDED +#endif + +#if __cpp_exceptions + #include +#endif + +#if __cplusplus >= 201703L && __has_include() + #include +#endif + +/** Calculates the minimum number of bytes required to store a buffer tensor with the specified type, sizes, and + strides. The formula can be expressed as the following: + + IndexOfLastElement = dot(Sizes - 1, Strides); + MinimumImpliedSizeInBytes = roundup((IndexOfLastElement + 1) * ElementSizeInBytes, 4) + + In other words, the minimum size of a tensor is the index of the one-past-the-end element, multiplied by the + element size (e.g. 2 bytes for a FLOAT16 tensor). Additionally DirectML requires that all buffers bound must have + a total size which is DWORD-aligned, and hence the minimum implied size in bytes must be rounded up to the nearest + 4-byte boundary. + */ + +inline UINT64 DMLCalcBufferTensorSize( + DML_TENSOR_DATA_TYPE dataType, + UINT dimensionCount, + _In_reads_(dimensionCount) const UINT* sizes, + _In_reads_opt_(dimensionCount) const UINT* strides) +{ + UINT elementSizeInBytes = 0; + switch (dataType) + { + case DML_TENSOR_DATA_TYPE_FLOAT32: + case DML_TENSOR_DATA_TYPE_UINT32: + case DML_TENSOR_DATA_TYPE_INT32: + elementSizeInBytes = 4; + break; + + case DML_TENSOR_DATA_TYPE_FLOAT16: + case DML_TENSOR_DATA_TYPE_UINT16: + case DML_TENSOR_DATA_TYPE_INT16: + elementSizeInBytes = 2; + break; + + case DML_TENSOR_DATA_TYPE_UINT8: + case DML_TENSOR_DATA_TYPE_INT8: + elementSizeInBytes = 1; + break; + + case DML_TENSOR_DATA_TYPE_FLOAT64: + case DML_TENSOR_DATA_TYPE_UINT64: + case DML_TENSOR_DATA_TYPE_INT64: + elementSizeInBytes = 8; + break; + + default: + return 0; // Invalid data type + } + + UINT64 minimumImpliedSizeInBytes = 0; + if (!strides) + { + minimumImpliedSizeInBytes = sizes[0]; + for (UINT i = 1; i < dimensionCount; ++i) + { + minimumImpliedSizeInBytes *= sizes[i]; + } + minimumImpliedSizeInBytes *= elementSizeInBytes; + } + else + { + UINT indexOfLastElement = 0; + for (UINT i = 0; i < dimensionCount; ++i) + { + indexOfLastElement += (sizes[i] - 1) * strides[i]; + } + + minimumImpliedSizeInBytes = (static_cast(indexOfLastElement) + 1) * elementSizeInBytes; + } + + // Round up to the nearest 4 bytes. + minimumImpliedSizeInBytes = (minimumImpliedSizeInBytes + 3) & ~3ull; + + return minimumImpliedSizeInBytes; +} + +namespace dml +{ + namespace detail + { + // Provide non-member size() and data(). Defaults to standard library implementation (if available) +#if __cpp_lib_nonmember_container_access + template + constexpr auto size(const C& c) -> decltype(c.size()) + { + return std::size(c); + } + + template + constexpr std::size_t size(const T(&array)[N]) noexcept + { + return std::size(array); + } + + template + constexpr auto data(C& c) -> decltype(c.data()) + { + return std::data(c); + } + + template + constexpr T* data(T(&array)[N]) noexcept + { + return std::data(array); + } +#else + template + constexpr auto size(const C& c) -> decltype(c.size()) + { + return c.size(); + } + + template + constexpr std::size_t size(const T(&array)[N]) noexcept + { + return N; + } + + template + constexpr auto data(C& c) -> decltype(c.data()) + { + return c.data(); + } + + template + constexpr T* data(T(&array)[N]) noexcept + { + return array; + } +#endif + + template + class span + { + public: + span() = default; + + constexpr span(std::initializer_list i) : m_begin(i.begin()), m_end(i.end()) {} + constexpr span(T* begin, T* end) : m_begin(begin), m_end(end) {} + constexpr span(T* begin, size_t elementCount) : m_begin(begin), m_end(begin + elementCount) {} + + template + constexpr span(ContiguousContainer&& container) + : m_begin(dml::detail::data(container)), m_end(m_begin + dml::detail::size(container)) {} + + template + constexpr span(T(&a)[N]) noexcept : span(a, N) {} + + T* data() noexcept { return m_begin; } + T* begin() noexcept { return m_begin; } + T* end() noexcept { return m_end; } + T const* data() const noexcept { return m_begin; } + T const* begin() const noexcept { return m_begin; } + T const* end() const noexcept { return m_end; } + bool empty() const noexcept { return m_end == m_begin; } + size_t size() const noexcept { return m_end - m_begin; } + size_t size_bytes() const noexcept { return sizeof(T) * size(); } + T& operator[](size_t index) const noexcept { return m_begin[index]; } + span subspan(size_t index, size_t count) { return span(m_begin + index, m_begin + index + count); } + + protected: + T* m_begin = nullptr; + T* m_end = nullptr; + }; + } + +#if DMLX_USE_ABSEIL + template + using Optional = absl::optional; + + constexpr absl::nullopt_t NullOpt = absl::nullopt; + + template + using SmallVector = absl::InlinedVector; + + template + using Span = absl::Span; + + using absl::make_unique; +#else + #ifndef DMLX_OPTIONAL_EXTENDED + template + using Optional = std::optional; + constexpr std::nullopt_t NullOpt = std::nullopt; + #endif + + template + using SmallVector = std::vector; + + #if __cpp_lib_span + template + using Span = std::span; + #elif DMLX_USE_GSL + template + using Span = gsl::span; + #else + template + using Span = dml::detail::span; + #endif + + using std::make_unique; +#endif + +#if __cplusplus >= 201703L && __has_include() + using StringView = std::string_view; +#else + using StringView = const std::string&; +#endif + +#if __cpp_exceptions + #if DMLX_USE_WIL + #define DMLX_THROW_IF_FAILED(_hr) THROW_IF_FAILED(_hr) + #define DMLX_THROW(_hr) THROW_HR(_hr) + #else + #define DMLX_THROW_IF_FAILED(_hr) if (FAILED(_hr)) { throw std::runtime_error(#_hr); } + #define DMLX_THROW(_hr) throw std::runtime_error(#_hr); + #endif +#else + #define DMLX_THROW_IF_FAILED(_hr) if (FAILED(_hr)) { std::abort(); } + #define DMLX_THROW(_hr) { std::abort(); } +#endif + + class Graph; + class Expression; + + using TensorDimensions = SmallVector; + using TensorStrides = SmallVector; + + // The custom properties returned by a TensorPolicy. + struct TensorProperties + { + Optional strides; + uint64_t totalTensorSizeInBytes; + uint32_t guaranteedBaseOffsetAlignment; + }; + + // Provides a way to customize the properties that DMLX automatically sets on tensors. Callers may provide their + // own TensorPolicy implementation to provide custom strides, total tensor sizes, and alignment. TensorPolicy + // objects can be set using Graph::SetTensorPolicy(). + class TensorPolicy + { + public: + // A function type that returns a TensorProperties object given a tensor data type, flags, and sizes. + using Func = std::function< + TensorProperties (DML_TENSOR_DATA_TYPE dataType, DML_TENSOR_FLAGS flags, Span sizes) + >; + + TensorPolicy() = default; + /*implicit*/ TensorPolicy(Func impl) + : m_impl(impl) + {} + + TensorProperties Get( + DML_TENSOR_DATA_TYPE dataType, + DML_TENSOR_FLAGS flags, + Span sizes) const + { + // Empty/uninitialized policy falls back to default. + if (!m_impl) + { + return ComputeDefault(dataType, flags, sizes); + } + + return m_impl(dataType, flags, sizes); + } + + // Returns the default tensor policy, which doesn't produce any changes to tensor layout, has no guaranteed + // alignment, and which uses DMLCalcBufferTensorSize to compute the total tensor size. + static TensorPolicy Default() + { + return TensorPolicy(); + } + + // A tensor policy that returns strides which produce tensors with a layout transposed to dimension order + // (0, 2, ..., n, 1). This is often referred to as "NHWC" or "interleaved channel" layout. This is useful, + // for example, when applied to 2D Convolution to produce outputs in an NHWC layout (as opposed to NCHW, which + // is the DirectML default for 2D Convolution). + // + // Examples of the transposes produced by this policy: + // NCW -> NWC + // NCHW -> NHWC + // NCDHW -> NDHWC + static TensorPolicy InterleavedChannel() + { + return TensorPolicy(&ComputeInterleavedChannel); + } + + private: + static TensorProperties ComputeDefault( + DML_TENSOR_DATA_TYPE dataType, + DML_TENSOR_FLAGS /*flags*/, + Span sizes) + { + uint32_t dimensionCount = static_cast(sizes.size()); + TensorProperties props; + props.strides = NullOpt; // no strides + props.totalTensorSizeInBytes = DMLCalcBufferTensorSize(dataType, dimensionCount, sizes.data(), nullptr); + props.guaranteedBaseOffsetAlignment = 0; + return props; + } + + static TensorProperties ComputeInterleavedChannel( + DML_TENSOR_DATA_TYPE dataType, + DML_TENSOR_FLAGS /*flags*/, + Span sizes) + { + uint32_t dimensionCount = static_cast(sizes.size()); + TensorStrides strides(dimensionCount); + + enum Axes { N, C, /* spatial dimensions ... */ }; + + // N dimension strides + if (dimensionCount >= 1) + { + strides[N] = 1; + for (uint32_t i = 1; i < dimensionCount; ++i) + { + strides[N] *= sizes[i]; + } + } + + // C dimension strides + if (dimensionCount >= 2) + { + strides[C] = 1; + } + + // Spatial dimension strides + if (dimensionCount >= 3) + { + uint32_t stride = sizes[C]; + for (uint32_t i = dimensionCount - 1; i >= 2; --i) + { + strides[i] = stride; + stride *= sizes[i]; + } + } + + TensorProperties props; + props.strides = std::move(strides); + props.totalTensorSizeInBytes = DMLCalcBufferTensorSize(dataType, dimensionCount, sizes.data(), props.strides->data()); + props.guaranteedBaseOffsetAlignment = 0; + return props; + } + + Func m_impl; + }; + + struct TensorDesc + { + public: + using Dimensions = TensorDimensions; + using Strides = TensorStrides; + + DML_TENSOR_DATA_TYPE dataType = DML_TENSOR_DATA_TYPE_UNKNOWN; + DML_TENSOR_FLAGS flags = DML_TENSOR_FLAG_NONE; + Dimensions sizes; + Optional strides; + uint64_t totalTensorSizeInBytes = 0; + uint32_t guaranteedBaseOffsetAlignment = 0; + + TensorDesc() = default; + + TensorDesc(DML_TENSOR_DATA_TYPE dataType, Dimensions sizes, const TensorPolicy& policy = {}) + : TensorDesc(dataType, DML_TENSOR_FLAG_NONE, sizes, policy) + {} + + TensorDesc(DML_TENSOR_DATA_TYPE dataType, DML_TENSOR_FLAGS flags, Dimensions sizes, const TensorPolicy& policy = {}) + { + TensorProperties props = policy.Get(dataType, flags, sizes); + Initialize( + dataType, + flags, + std::move(sizes), + std::move(props.strides), + props.totalTensorSizeInBytes, + props.guaranteedBaseOffsetAlignment); + } + + TensorDesc( + DML_TENSOR_DATA_TYPE dataType, + DML_TENSOR_FLAGS flags, + Dimensions sizes, + Optional strides, + uint64_t totalTensorSizeInBytes, + uint32_t guaranteedBaseOffsetAlignment) + { + Initialize(dataType, flags, std::move(sizes), std::move(strides), totalTensorSizeInBytes, guaranteedBaseOffsetAlignment); + } + + /* implicit */ TensorDesc(const DML_TENSOR_DESC& desc) + : TensorDesc(*static_cast(desc.Desc)) + { + assert(desc.Type == DML_TENSOR_TYPE_BUFFER); + assert(desc.Desc != nullptr); + } + + /* implicit */ TensorDesc(const DML_BUFFER_TENSOR_DESC& desc) + { + this->dataType = desc.DataType; + this->flags = desc.Flags; + this->sizes.assign(desc.Sizes, desc.Sizes + desc.DimensionCount); + if (desc.Strides) + { + this->strides.emplace(); + this->strides->assign(desc.Strides, desc.Strides + desc.DimensionCount); + } + this->totalTensorSizeInBytes = desc.TotalTensorSizeInBytes; + this->guaranteedBaseOffsetAlignment = desc.GuaranteedBaseOffsetAlignment; + } + + // Returns an equivalent DML_TENSOR_DESC or DML_BUFFER_TENSOR_DESC. The returned object contains pointers + // into the TensorDesc, so it is only valid as long as the TensorDesc itself is alive. + template + T* AsPtr() + { + // "sizeof(T) == -1" is always false; this is just to make the static_assert dependent on the template + // parameter and therefore not evaluated until template instantiation + static_assert(sizeof(T) == -1, "Invalid type"); + } + + template <> + DML_BUFFER_TENSOR_DESC* AsPtr() + { + assert(!strides || sizes.size() == strides->size()); + + m_bufferDesc.DataType = this->dataType; + m_bufferDesc.Flags = this->flags; + m_bufferDesc.DimensionCount = static_cast(sizes.size()); + m_bufferDesc.Sizes = this->sizes.data(); + m_bufferDesc.Strides = this->strides ? this->strides->data() : nullptr; + m_bufferDesc.TotalTensorSizeInBytes = this->totalTensorSizeInBytes; + m_bufferDesc.GuaranteedBaseOffsetAlignment = this->guaranteedBaseOffsetAlignment; + return &m_bufferDesc; + } + + template <> + DML_TENSOR_DESC* AsPtr() + { + m_tensorDesc = DML_TENSOR_DESC{ DML_TENSOR_TYPE_BUFFER, AsPtr() }; + return &m_tensorDesc; + } + + private: + DML_BUFFER_TENSOR_DESC m_bufferDesc; + DML_TENSOR_DESC m_tensorDesc; + + void Initialize( + DML_TENSOR_DATA_TYPE tensorDataType, + DML_TENSOR_FLAGS tensorFlags, + Dimensions tensorSizes, + Optional tensorStrides, + uint64_t totalTensorSizeInBytesVal, + uint32_t guaranteedBaseOffsetAlignmentVal) + { + assert(!tensorStrides || tensorStrides->size() == static_cast(tensorSizes.size())); + + this->dataType = tensorDataType; + this->flags = tensorFlags; + this->sizes = std::move(tensorSizes); + this->strides = std::move(tensorStrides); + this->totalTensorSizeInBytes = totalTensorSizeInBytesVal; + this->guaranteedBaseOffsetAlignment = guaranteedBaseOffsetAlignmentVal; + } + }; + + namespace detail + { + class GraphBuilder; + class NodeOutput; + + // A node in the graph which represents a graph input. + struct InputNode + { + uint32_t inputIndex; + }; + + // A node in the graph which represents a DML operator. + struct OperatorNode + { + Microsoft::WRL::ComPtr op; + + // The inputs to this node + std::vector inputs; + + std::string name; + }; + + // Used for representing reshapes and type punning + struct ReinterpretNode + { + NodeOutput* input; + }; + + enum class NodeType + { + Invalid, + Input, + Operator, + Reinterpret, + }; + + // Identifies a node in the graph. + struct NodeID + { + NodeType type; + uint32_t index; // The index of this node in the GraphBuilder + }; + + // Represents one of the outputs of a node. + class NodeOutput + { + public: + NodeOutput(GraphBuilder* owner, NodeID node, uint32_t outputIndex, TensorDesc tensorDesc) + : m_owner(owner) + , m_node(node) + , m_outputIndex(outputIndex) + , m_tensorDesc(std::move(tensorDesc)) + {} + + // Retrieves the GraphBuilder that owns this object. + GraphBuilder* GetGraphBuilder() const { return m_owner; } + + NodeID GetNode() const { return m_node; } + uint32_t GetOutputIndex() const { return m_outputIndex; } + const TensorDesc& GetOutputDesc() const { return m_tensorDesc; } + + private: + GraphBuilder* m_owner; + NodeID m_node; + + // An operator can have multiple outputs; this index identifies which one of the operator's outputs this + // NodeOutput represents. + uint32_t m_outputIndex; + + TensorDesc m_tensorDesc; + }; + + struct GraphDesc + { + uint32_t inputCount; + uint32_t outputCount; + std::vector nodes; + std::vector inputEdges; + std::vector outputEdges; + std::vector intermediateEdges; + }; + + class GraphBuilder + { + public: + GraphBuilder(IDMLDevice* device, TensorPolicy tensorPolicy = {}) + : m_device(device) + , m_tensorPolicy(tensorPolicy) + {} + + IDMLDevice* GetDevice() const + { + return m_device.Get(); + } + + void PushName(StringView name) + { + m_nameSubLengths.push(m_name.size()); + if (!m_name.empty()) + { + m_name += "_"; + } + m_name += name; + } + + void PopName() + { + if (!m_nameSubLengths.empty()) + { + m_name.resize(m_nameSubLengths.top()); + m_nameSubLengths.pop(); + } + } + + void SetTensorPolicy(TensorPolicy policy) { m_tensorPolicy = std::move(policy); } + const TensorPolicy& GetTensorPolicy() const { return m_tensorPolicy; } + TensorPolicy& GetTensorPolicy() { return m_tensorPolicy; } + + // Creates a DML operator node owned by this graph builder and returns a NodeInfo identifier. The + // inputs to this node must be supplied in the correct order matching the DML operator. + NodeID CreateOperatorNode(DML_OPERATOR_TYPE type, const void* desc, Span inputs); + NodeID CreateInputNode(uint32_t inputIndex); + NodeID CreateReinterpretNode(NodeOutput* input); + NodeOutput* CreateNodeOutput(NodeID node, uint32_t outputIndex, TensorDesc tensorDesc); + GraphDesc GetGraphDesc(Span outputs) const; + + private: + Microsoft::WRL::ComPtr m_device; + TensorPolicy m_tensorPolicy; + std::vector m_inputNodes; + std::vector m_operatorNodes; + std::vector m_reinterpretNodes; + std::deque m_nodeOutputs; // deque doesn't invalidate references to elements when it resizes + + std::string m_name; + std::stack m_nameSubLengths; + }; + + } // namespace detail + + class Expression + { + public: + /*implicit*/ Expression(detail::NodeOutput* nodeOutput = nullptr) + : m_nodeOutput(nodeOutput) + {} + + // Returns a struct containing the required properties of the tensor to hold the output of this expression, + // once evaluated. + const TensorDesc& GetOutputDesc() const { return Impl()->GetOutputDesc(); } + + // For internal use only + detail::NodeOutput* Impl() const { return m_nodeOutput; } + + explicit operator bool() const + { + return m_nodeOutput != nullptr; + } + + private: + detail::NodeOutput* m_nodeOutput; // weak; this is owned by the GraphBuilder + }; + + class NameScope + { + public: + detail::GraphBuilder* m_builder = nullptr; + + NameScope(detail::GraphBuilder* builder, StringView name) : m_builder(builder) + { + if (m_builder) m_builder->PushName(name); + } + + ~NameScope() + { + if (m_builder) m_builder->PopName(); + } + }; + + class Graph + { + public: + explicit Graph(IDMLDevice* device, TensorPolicy tensorPolicy = {}) + : m_graphBuilder(make_unique(device, tensorPolicy)) + {} + + // For internal use only + detail::GraphBuilder* Impl() { return m_graphBuilder.get(); } + + // Sets/gets the tensor policy. If not set, defaults to TensorPolicy::Default(). Tensor policies can be used + // to control properties (such as strides) on output tensors produced by this Graph. + void SetTensorPolicy(TensorPolicy policy) { m_graphBuilder->SetTensorPolicy(std::move(policy)); } + const TensorPolicy& GetTensorPolicy() const { return m_graphBuilder->GetTensorPolicy(); } + TensorPolicy& GetTensorPolicy() { return m_graphBuilder->GetTensorPolicy(); } + + NameScope CreateNameScope(StringView name) { return NameScope(m_graphBuilder.get(), name); } + + void PushName(StringView name) { m_graphBuilder->PushName(name); } + void PopName() { m_graphBuilder->PopName(); } + + Microsoft::WRL::ComPtr Compile( + DML_EXECUTION_FLAGS flags, + Span outputs, + uint32_t inputCount = 0) const + { + detail::GraphDesc graph = m_graphBuilder->GetGraphDesc(outputs); + + // If supplied, the requested number of inputs to the compiled operator can be larger than the actual + // number of input nodes on the graph (e.g. in the case of unused empty inputs), but never smaller. + assert(inputCount == 0 || inputCount >= graph.inputCount); + + std::vector graphNodes(graph.nodes.size()); + for (size_t i = 0; i < graphNodes.size(); ++i) + { + graphNodes[i] = { DML_GRAPH_NODE_TYPE_OPERATOR, &graph.nodes[i] }; + } + + std::vector inputEdges(graph.inputEdges.size()); + for (size_t i = 0; i < inputEdges.size(); ++i) + { + inputEdges[i] = { DML_GRAPH_EDGE_TYPE_INPUT, &graph.inputEdges[i] }; + } + + std::vector outputEdges(graph.outputEdges.size()); + for (size_t i = 0; i < outputEdges.size(); ++i) + { + outputEdges[i] = { DML_GRAPH_EDGE_TYPE_OUTPUT, &graph.outputEdges[i] }; + } + + std::vector intermediateEdges(graph.intermediateEdges.size()); + for (size_t i = 0; i < intermediateEdges.size(); ++i) + { + intermediateEdges[i] = { DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graph.intermediateEdges[i] }; + } + + DML_GRAPH_DESC graphDesc = {}; + graphDesc.InputCount = inputCount ? inputCount : graph.inputCount; + graphDesc.OutputCount = graph.outputCount; + graphDesc.NodeCount = static_cast(graphNodes.size()); + graphDesc.Nodes = graphNodes.data(); + graphDesc.InputEdgeCount = static_cast(inputEdges.size()); + graphDesc.InputEdges = inputEdges.data(); + graphDesc.OutputEdgeCount = static_cast(outputEdges.size()); + graphDesc.OutputEdges = outputEdges.data(); + graphDesc.IntermediateEdgeCount = static_cast(intermediateEdges.size()); + graphDesc.IntermediateEdges = intermediateEdges.data(); + + Microsoft::WRL::ComPtr device1; + DMLX_THROW_IF_FAILED(m_graphBuilder->GetDevice()->QueryInterface(IID_PPV_ARGS(&device1))); + + Microsoft::WRL::ComPtr compiledGraph; + DMLX_THROW_IF_FAILED(device1->CompileGraph(&graphDesc, flags, IID_PPV_ARGS(&compiledGraph))); + + return compiledGraph; + } + + private: + std::unique_ptr m_graphBuilder; + }; + + // Represents an activation to be fused with an existing operator. The meaning of param1 and param2 depend on the + // activation to be fused. + // + // For HARD_SIGMOID, LINEAR, PARAMETRIC_SOFTPLUS, and SCALED_TANH: param1 = Alpha and param2 = Beta + // For ELU, LEAKY_RELU, THRESHOLDED_RELU, and CELU: param1 = Alpha. param2 is unused. + // For SCALED_ELU, param1 = Alpha and param2 = Gamma. + // For SHRINK, param1 = Bias and param2 = Threshold + // For SOFTPLUS, param1 = Steepness. + // For all other activations, both param1 and param2 are unused. + struct FusedActivation + { + DML_OPERATOR_TYPE activation = DML_OPERATOR_INVALID; + float param1 = 0.0f; + float param2 = 0.0f; + + FusedActivation() = default; + + explicit FusedActivation(DML_OPERATOR_TYPE activation, float param1 = 0.0f, float param2 = 0.0f) + : activation(activation), param1(param1), param2(param2) + {} + + static FusedActivation None() + { + return FusedActivation(); + } + + static FusedActivation Elu(float alpha = 1.0f) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_ELU, alpha); + } + + static FusedActivation HardSigmoid(float alpha = 0.2f, float beta = 0.5f) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_HARD_SIGMOID, alpha, beta); + } + + static FusedActivation Identity() + { + return FusedActivation(DML_OPERATOR_ACTIVATION_IDENTITY); + } + + static FusedActivation LeakyRelu(float alpha = 0.01f) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_LEAKY_RELU, alpha); + } + + static FusedActivation Linear(float alpha, float beta) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_LINEAR, alpha, beta); + } + + static FusedActivation ParametricSoftplus(float alpha, float beta) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS, alpha, beta); + } + + static FusedActivation Relu() + { + return FusedActivation(DML_OPERATOR_ACTIVATION_RELU); + } + + static FusedActivation ScaledElu(float alpha = 1.67326319217681884765625f, float gamma = 1.05070102214813232421875f) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_SCALED_ELU, alpha, gamma); + } + + static FusedActivation ScaledTanh(float alpha = 1.0f, float beta = 0.5f) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_SCALED_TANH, alpha, beta); + } + + static FusedActivation Sigmoid() + { + return FusedActivation(DML_OPERATOR_ACTIVATION_SIGMOID); + } + + static FusedActivation Softplus(float steepness = 1.0f) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_SOFTPLUS, steepness); + } + + static FusedActivation Softsign() + { + return FusedActivation(DML_OPERATOR_ACTIVATION_SOFTSIGN); + } + + static FusedActivation Tanh() + { + return FusedActivation(DML_OPERATOR_ACTIVATION_TANH); + } + + static FusedActivation ThresholdedRelu(float alpha = 1.0f) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU, alpha); + } + + static FusedActivation Shrink(float bias = 0.0f, float threshold = 0.5f) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_SHRINK, bias, threshold); + } + + static FusedActivation Celu(float alpha = 1.0f) + { + return FusedActivation(DML_OPERATOR_ACTIVATION_CELU, alpha); + } + }; + + // Implementation detail helper for determining if a list of expressions share the same GraphBuilder. + namespace detail + { + inline bool HasSameOwner(Span exprs) + { + if (exprs.size() == 0) + { + return true; + } + + detail::GraphBuilder* owner = exprs.begin()->Impl()->GetGraphBuilder(); + for (Expression expr : exprs) + { + if (expr.Impl()->GetGraphBuilder() != owner) + { + return false; + } + } + + return true; + } + + inline bool HasSameOwner(std::initializer_list exprs) + { + Span span(exprs.begin(), exprs.size()); + return HasSameOwner(span); + } + + inline bool HasSameDataType(Span exprs) + { + if (exprs.size() == 0) + { + return true; + } + + DML_TENSOR_DATA_TYPE dataType = exprs.begin()->Impl()->GetOutputDesc().dataType; + for (Expression expr : exprs) + { + if (expr.Impl()->GetOutputDesc().dataType != dataType) + { + return false; + } + } + + return true; + } + + inline bool HasSameDataType(std::initializer_list exprs) + { + Span span(exprs.begin(), exprs.size()); + return HasSameDataType(span); + } + } // namespace detail + + // Expression implementation helpers + namespace detail + { + template + Expression ElementWiseUnary(Expression input, const Optional& scaleBias) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input + + TDesc desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.ScaleBias = scaleBias ? &scaleBias.value() : nullptr; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(OperatorType, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + template + Expression ElementWiseUnary(Expression input, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UNKNOWN) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + + if (outputDataType == DML_TENSOR_DATA_TYPE_UNKNOWN) + { + outputDataType = inputTensor.dataType; + } + TensorDesc outputTensor(outputDataType, inputTensor.sizes, builder->GetTensorPolicy()); + + TDesc desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(OperatorType, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + template + Expression ElementWiseBinary(Expression a, Expression b) + { + assert(detail::HasSameOwner({ a, b })); + detail::GraphBuilder* builder = a.Impl()->GetGraphBuilder(); + + TensorDesc aTensor = a.Impl()->GetOutputDesc(); + TensorDesc bTensor = b.Impl()->GetOutputDesc(); + TensorDesc outputTensor(aTensor.dataType, aTensor.sizes, builder->GetTensorPolicy()); // Same as input + + TDesc desc = {}; + desc.ATensor = aTensor.AsPtr(); + desc.BTensor = bTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeOutput* const inputs[] = { a.Impl(), b.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(OperatorType, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + template + Expression ElementWiseComparison(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8) + { + assert(detail::HasSameOwner({ a, b })); + detail::GraphBuilder* builder = a.Impl()->GetGraphBuilder(); + + TensorDesc aTensor = a.Impl()->GetOutputDesc(); + TensorDesc bTensor = b.Impl()->GetOutputDesc(); + TensorDesc outputTensor(outputDataType, aTensor.sizes, builder->GetTensorPolicy()); + + TDesc desc = {}; + desc.ATensor = aTensor.AsPtr(); + desc.BTensor = bTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeOutput* const inputs[] = { a.Impl(), b.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(OperatorType, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + // Used to reserve some space on the stack for setting up fused activation operator descs. + struct FusedActivationStorage + { + DML_OPERATOR_DESC opDesc; + + // All fuseable activation descs have a common layout: two tensor desc pointers and up to 2 optional + // float parameters, so just use LINEAR as an archetype + DML_ACTIVATION_LINEAR_OPERATOR_DESC activationDesc; + }; + + // Returns the correct value for filling out fused activation fields in the DML API, e.g. + // DML_CONVOLUTION_OPERATOR_DESC::FusedActivation. The descs themselves are stored in the `storage` outptr. + inline const DML_OPERATOR_DESC* GetFusedActivationPtr( + FusedActivation fusedActivation, + _Out_ FusedActivationStorage* storage) + { + if (fusedActivation.activation == DML_OPERATOR_INVALID) + { + // No fused activation + return nullptr; + } + + storage->activationDesc.InputTensor = nullptr; + storage->activationDesc.OutputTensor = nullptr; + storage->activationDesc.Alpha = fusedActivation.param1; + storage->activationDesc.Beta = fusedActivation.param2; + + storage->opDesc.Type = fusedActivation.activation; + storage->opDesc.Desc = &storage->activationDesc; + + return &storage->opDesc; + } + + } // namespace detail + + inline Expression InputTensor(Graph& graph, uint32_t inputIndex, TensorDesc desc) + { + detail::GraphBuilder* builder = graph.Impl(); + + detail::NodeID node = builder->CreateInputNode(inputIndex); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(desc)); + return output; + } + + inline Expression Identity(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Abs(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression ACos(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Add(Expression a, Expression b) + { + assert(detail::HasSameOwner({ a, b })); + detail::GraphBuilder* builder = a.Impl()->GetGraphBuilder(); + + TensorDesc aTensor = a.Impl()->GetOutputDesc(); + TensorDesc bTensor = b.Impl()->GetOutputDesc(); + TensorDesc outputTensor(aTensor.dataType, aTensor.sizes, builder->GetTensorPolicy()); // Same as input + + DML_ELEMENT_WISE_ADD_OPERATOR_DESC desc = {}; + desc.ATensor = aTensor.AsPtr(); + desc.BTensor = bTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeOutput* const inputs[] = { a.Impl(), b.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_ADD, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression Add(Expression a, Expression b, FusedActivation fusedActivation) + { + assert(detail::HasSameOwner({ a, b })); + detail::GraphBuilder* builder = a.Impl()->GetGraphBuilder(); + + TensorDesc aTensor = a.Impl()->GetOutputDesc(); + TensorDesc bTensor = b.Impl()->GetOutputDesc(); + TensorDesc outputTensor(aTensor.dataType, aTensor.sizes, builder->GetTensorPolicy()); // Same as input + detail::FusedActivationStorage storage; + + DML_ELEMENT_WISE_ADD1_OPERATOR_DESC desc = {}; + desc.ATensor = aTensor.AsPtr(); + desc.BTensor = bTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.FusedActivation = detail::GetFusedActivationPtr(fusedActivation, &storage); + + detail::NodeOutput* const inputs[] = { a.Impl(), b.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_ADD1, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression ASin(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression ATan(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + +#if DML_TARGET_VERSION >= 0x3100 + + inline Expression ATanYX(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + +#endif // DML_TARGET_VERSION >= 0x3100 + + inline Expression Ceil(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Clip(Expression input, float min, float max, const Optional& scaleBias = NullOpt) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input + + DML_ELEMENT_WISE_CLIP_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.ScaleBias = scaleBias ? &scaleBias.value() : nullptr; + desc.Min = min; + desc.Max = max; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_CLIP, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + +#if DML_TARGET_VERSION >= 0x3100 + + inline Expression ClipGrad(Expression input, Expression inputGradient, float min, float max) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc inputGradientTensor = inputGradient.Impl()->GetOutputDesc(); + TensorDesc outputGradientTensor(inputGradientTensor.dataType, inputGradientTensor.sizes, builder->GetTensorPolicy()); + + DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.InputGradientTensor = inputGradientTensor.AsPtr(); + desc.OutputGradientTensor = outputGradientTensor.AsPtr(); + desc.Min = min; + desc.Max = max; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputGradientTensor)); + + return output; + } + +#endif // DML_TARGET_VERSION >= 0x3100 + + inline Expression Cos(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Divide(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression Exp(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Floor(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Log(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression LogicalAnd(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression Equals(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8) + { + return detail::ElementWiseComparison< + DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS, + DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC>(a, b, outputDataType); + } + + inline Expression GreaterThan(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8) + { + return detail::ElementWiseComparison< + DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN, + DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC>(a, b, outputDataType); + } + + inline Expression GreaterThanOrEqual(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8) + { + return detail::ElementWiseComparison< + DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL, + DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC>(a, b, outputDataType); + } + + inline Expression LessThan(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8) + { + return detail::ElementWiseComparison< + DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN, + DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC>(a, b, outputDataType); + } + + inline Expression LessThanOrEqual(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8) + { + return detail::ElementWiseComparison< + DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL, + DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC>(a, b, outputDataType); + } + + inline Expression LogicalNot(Expression input) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input + + DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression LogicalOr(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression LogicalXor(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression Max(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression Mean(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression Min(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression Multiply(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression Pow(Expression input, Expression exponent, const Optional& scaleBias = NullOpt) + { + assert(detail::HasSameOwner({ input, exponent })); + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc exponentTensor = exponent.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input + + DML_ELEMENT_WISE_POW_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.ExponentTensor = exponentTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.ScaleBias = scaleBias ? &scaleBias.value() : nullptr; + + detail::NodeOutput* const inputs[] = { input.Impl(), exponent.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_POW, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression Pow(Expression input, float exponent, const Optional& scaleBias = NullOpt) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input + + DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.ScaleBias = scaleBias ? &scaleBias.value() : nullptr; + desc.Exponent = exponent; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression Recip(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Sin(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Sqrt(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + +#if DML_TARGET_VERSION >= 0x3100 + + inline Expression DifferenceSquare(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + +#endif // DML_TARGET_VERSION >= 0x3100 + + inline Expression Subtract(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression Tan(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Threshold(Expression input, float min, const Optional& scaleBias = NullOpt) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input + + DML_ELEMENT_WISE_THRESHOLD_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.ScaleBias = scaleBias ? &scaleBias.value() : nullptr; + desc.Min = min; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_THRESHOLD, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression QuantizeLinear(Expression input, Expression scale, Expression zeroPoint, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8) + { + assert(detail::HasSameOwner({ input, scale, zeroPoint })); + + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc scaleTensor = scale.Impl()->GetOutputDesc(); + TensorDesc zeroPointTensor = zeroPoint.Impl()->GetOutputDesc(); + TensorDesc outputTensor(outputDataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.ScaleTensor = scaleTensor.AsPtr(); + desc.ZeroPointTensor = zeroPointTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeOutput* const inputs[] = { input.Impl(), scale.Impl(), zeroPoint.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression DequantizeLinear(Expression input, Expression scale, Expression zeroPoint) + { + assert(detail::HasSameOwner({ input, scale, zeroPoint })); + + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc scaleTensor = scale.Impl()->GetOutputDesc(); + TensorDesc zeroPointTensor = zeroPoint.Impl()->GetOutputDesc(); + TensorDesc outputTensor(scaleTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.ScaleTensor = scaleTensor.AsPtr(); + desc.ZeroPointTensor = zeroPointTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeOutput* const inputs[] = { input.Impl(), scale.Impl(), zeroPoint.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression Sign(Expression a) + { + return detail::ElementWiseUnary(a); + } + + inline Expression IsNaN(Expression input, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8) + { + return detail::ElementWiseUnary(input, outputDataType); + } + + inline Expression Erf(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Sinh(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Cosh(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression Tanh(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression ASinh(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression ACosh(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression ATanh(Expression input, const Optional& scaleBias = NullOpt) + { + return detail::ElementWiseUnary(input, scaleBias); + } + + inline Expression If(Expression condition, Expression a, Expression b) + { + assert(detail::HasSameOwner({ condition, a, b })); + assert(detail::HasSameDataType({ a, b })); + + detail::GraphBuilder* builder = condition.Impl()->GetGraphBuilder(); + + TensorDesc conditionTensor = condition.Impl()->GetOutputDesc(); + assert(conditionTensor.dataType == DML_TENSOR_DATA_TYPE_UINT8); + + TensorDesc aTensor = a.Impl()->GetOutputDesc(); + TensorDesc bTensor = b.Impl()->GetOutputDesc(); + TensorDesc outputTensor(aTensor.dataType, aTensor.sizes, builder->GetTensorPolicy()); + + DML_ELEMENT_WISE_IF_OPERATOR_DESC desc = {}; + desc.ConditionTensor = conditionTensor.AsPtr(); + desc.ATensor = aTensor.AsPtr(); + desc.BTensor = bTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeOutput* const inputs[] = { condition.Impl(), a.Impl(), b.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_IF, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression BitShiftLeft(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression BitShiftRight(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression BitAnd(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression BitOr(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression BitXor(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression BitNot(Expression a) + { + return detail::ElementWiseUnary(a); + } + + inline Expression BitCount(Expression a, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8) + { + return detail::ElementWiseUnary(a, outputDataType); + } + + inline Expression Round(Expression input, DML_ROUNDING_MODE roundingMode = DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input + + DML_ELEMENT_WISE_ROUND_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.RoundingMode = roundingMode; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_ROUND, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression IsInfinity( + Expression input, + DML_IS_INFINITY_MODE infinityMode = DML_IS_INFINITY_MODE_EITHER, + DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(outputDataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.InfinityMode = infinityMode; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_IS_INFINITY, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression ModulusTruncate(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + + inline Expression ModulusFloor(Expression a, Expression b) + { + return detail::ElementWiseBinary(a, b); + } + +#pragma region detail +#define DMLX_ACTIVATION_IMPL(_name) \ + do { \ + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); \ + \ + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); \ + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); \ + \ + DML_##_name##_OPERATOR_DESC desc = {}; \ + desc.InputTensor = inputTensor.AsPtr(); \ + desc.OutputTensor = outputTensor.AsPtr(); \ + \ + detail::NodeOutput* const inputs[] = { input.Impl() }; \ + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_##_name, &desc, inputs); \ + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); \ + \ + return output; \ + } while(0) + +#define DMLX_ACTIVATION_IMPL_1(_name, _param1Name, _param1) \ + do { \ + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); \ + \ + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); \ + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); \ + \ + DML_##_name##_OPERATOR_DESC desc = {}; \ + desc.InputTensor = inputTensor.AsPtr(); \ + desc.OutputTensor = outputTensor.AsPtr(); \ + desc._param1Name = _param1; \ + \ + detail::NodeOutput* const inputs[] = { input.Impl() }; \ + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_##_name, &desc, inputs); \ + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); \ + \ + return output; \ + } while(0) + +#define DMLX_ACTIVATION_IMPL_2(_name, _param1Name, _param1, _param2Name, _param2) \ + do { \ + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); \ + \ + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); \ + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); \ + \ + DML_##_name##_OPERATOR_DESC desc = {}; \ + desc.InputTensor = inputTensor.AsPtr(); \ + desc.OutputTensor = outputTensor.AsPtr(); \ + desc._param1Name = _param1; \ + desc._param2Name = _param2; \ + \ + detail::NodeOutput* const inputs[] = { input.Impl() }; \ + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_##_name, &desc, inputs); \ + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); \ + \ + return output; \ + } while(0) +#pragma endregion + + inline Expression ActivationElu(Expression input, float alpha = 1.0f) + { + DMLX_ACTIVATION_IMPL_1(ACTIVATION_ELU, Alpha, alpha); + } + + inline Expression ActivationHardmax(Expression input) + { + DMLX_ACTIVATION_IMPL(ACTIVATION_HARDMAX); + } + + inline Expression ActivationHardSigmoid(Expression input, float alpha = 0.2f, float beta = 0.5f) + { + DMLX_ACTIVATION_IMPL_2(ACTIVATION_HARD_SIGMOID, Alpha, alpha, Beta, beta); + } + + inline Expression ActivationIdentity(Expression input) + { + DMLX_ACTIVATION_IMPL(ACTIVATION_IDENTITY); + } + + inline Expression ActivationLeakyRelu(Expression input, float alpha = 0.01f) + { + DMLX_ACTIVATION_IMPL_1(ACTIVATION_LEAKY_RELU, Alpha, alpha); + } + + inline Expression ActivationLinear(Expression input, float alpha, float beta) + { + DMLX_ACTIVATION_IMPL_2(ACTIVATION_LINEAR, Alpha, alpha, Beta, beta); + } + + inline Expression ActivationLogSoftmax(Expression input) + { + DMLX_ACTIVATION_IMPL(ACTIVATION_LOG_SOFTMAX); + } + + inline Expression ActivationParameterizedRelu(Expression input, Expression slope) + { + assert(detail::HasSameOwner({ input, slope })); + + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc slopeTensor = slope.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.SlopeTensor = slopeTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeOutput* const inputs[] = { input.Impl(), slope.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression ActivationParametricSoftplus(Expression input, float alpha, float beta) + { + DMLX_ACTIVATION_IMPL_2(ACTIVATION_PARAMETRIC_SOFTPLUS, Alpha, alpha, Beta, beta); + } + + inline Expression ActivationRelu(Expression input) + { + DMLX_ACTIVATION_IMPL(ACTIVATION_RELU); + } + + inline Expression ActivationScaledElu(Expression input, float alpha = 1.67326319217681884765625f, float gamma = 1.05070102214813232421875f) + { + DMLX_ACTIVATION_IMPL_2(ACTIVATION_SCALED_ELU, Alpha, alpha, Gamma, gamma); + } + + inline Expression ActivationScaledTanh(Expression input, float alpha = 1.0f, float beta = 0.5f) + { + DMLX_ACTIVATION_IMPL_2(ACTIVATION_SCALED_TANH, Alpha, alpha, Beta, beta); + } + + inline Expression ActivationSigmoid(Expression input) + { + DMLX_ACTIVATION_IMPL(ACTIVATION_SIGMOID); + } + + inline Expression ActivationSoftmax(Expression input) + { + DMLX_ACTIVATION_IMPL(ACTIVATION_SOFTMAX); + } + + inline Expression ActivationSoftplus(Expression input, float steepness = 1.0f) + { + DMLX_ACTIVATION_IMPL_1(ACTIVATION_SOFTPLUS, Steepness, steepness); + } + + inline Expression ActivationSoftsign(Expression input) + { + DMLX_ACTIVATION_IMPL(ACTIVATION_SOFTSIGN); + } + + inline Expression ActivationTanh(Expression input) + { + DMLX_ACTIVATION_IMPL(ACTIVATION_TANH); + } + + inline Expression ActivationThresholdedRelu(Expression input, float alpha = 1.0f) + { + DMLX_ACTIVATION_IMPL_1(ACTIVATION_THRESHOLDED_RELU, Alpha, alpha); + } + + inline Expression ActivationShrink(Expression input, float bias = 0.0f, float threshold = 0.5f) + { + DMLX_ACTIVATION_IMPL_2(ACTIVATION_SHRINK, Bias, bias, Threshold, threshold); + } + + inline Expression ActivationCelu(Expression input, float alpha = 1.0f) + { + DMLX_ACTIVATION_IMPL_1(ACTIVATION_CELU, Alpha, alpha); + } + +#undef DMLX_ACTIVATION_IMPL +#undef DMLX_ACTIVATION_IMPL_1 +#undef DMLX_ACTIVATION_IMPL_2 + + // --------------------------------------------------------------------------------------------------------------- + + // If not specified, parameters are defaulted to the following values: + // Mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION + // Direction = DML_CONVOLUTION_DIRECTION_FORWARD + // Strides = { 1, 1 } for 2D convolution, { 1, 1, 1 } for 3D convolution + // Dilations = { 1, 1 } for 2D convolution, { 1, 1, 1 } for 3D convolution + // StartPadding = { 0, 0 } for 2D convolution, { 0, 0, 0 } for 3D convolution + // EndPadding = { 0, 0 } for 2D convolution, { 0, 0, 0 } for 3D convolution + // OutputPadding = { 0, 0 } for 2D convolution, { 0, 0, 0 } for 3D convolution + // GroupCount = 1 + // FusedActivation = nullptr + // OutputSizes = computed from other parameters + inline Expression Convolution( + Expression input, + Expression filter, + Optional bias = NullOpt, + DML_CONVOLUTION_MODE mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION, + DML_CONVOLUTION_DIRECTION direction = DML_CONVOLUTION_DIRECTION_FORWARD, + Span strides = {}, + Span dilations = {}, + Span startPadding = {}, + Span endPadding = {}, + Span outputPadding = {}, + uint32_t groupCount = 1, + FusedActivation fusedActivation = FusedActivation::None(), + TensorDimensions outputSizes = {}) + { + assert(detail::HasSameOwner({ input, filter })); + assert(!bias || detail::HasSameOwner({ input, *bias })); + + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc filterTensor = filter.Impl()->GetOutputDesc(); + TensorDesc biasTensor; + if (bias) + { + biasTensor = bias->Impl()->GetOutputDesc(); + } + + uint32_t dimensionCount = static_cast(inputTensor.sizes.size()); + + assert(dimensionCount == 4 || dimensionCount == 5); + uint32_t spatialDimensionCount = dimensionCount - 2; + + // If the spatial dimension count is 2, we'll just use the first two elements by setting + // DimensionCount = 2 in the desc + const uint32_t defaultStridesAndDilations[3] = { 1, 1, 1 }; + const uint32_t defaultPadding[3] = { 0, 0, 0 }; + + assert(strides.empty() || strides.size() == spatialDimensionCount); + assert(dilations.empty() || dilations.size() == spatialDimensionCount); + assert(startPadding.empty() || startPadding.size() == spatialDimensionCount); + assert(endPadding.empty() || endPadding.size() == spatialDimensionCount); + assert(outputPadding.empty() || outputPadding.size() == spatialDimensionCount); + assert(outputSizes.empty() || outputSizes.size() == inputTensor.sizes.size()); + + strides = strides.empty() ? Span{ defaultStridesAndDilations } : strides; + dilations = dilations.empty() ? Span{ defaultStridesAndDilations } : dilations; + startPadding = startPadding.empty() ? Span{ defaultPadding } : startPadding; + endPadding = endPadding.empty() ? Span{ defaultPadding } : endPadding; + outputPadding = outputPadding.empty() ? Span{ defaultPadding } : outputPadding; + + // Compute the output shapes + + if (outputSizes.empty()) + { + if (direction == DML_CONVOLUTION_DIRECTION_FORWARD) + { + outputSizes.push_back(inputTensor.sizes[0]); // output[N] = input[N] + outputSizes.push_back(filterTensor.sizes[0]); // output[C] = filter[N] + + for (uint32_t dim = 0; dim < spatialDimensionCount; ++dim) + { + uint32_t inputSize = inputTensor.sizes[dim + 2]; + uint32_t paddedSize = inputSize + startPadding[dim] + endPadding[dim]; + + uint32_t windowSize = filterTensor.sizes[dim + 2]; + uint32_t kernelSize = 1 + (windowSize - 1) * dilations[dim]; + + assert(kernelSize <= paddedSize); + assert(strides[dim] != 0); + + outputSizes.push_back(1 + (paddedSize - kernelSize) / strides[dim]); + } + } + else if (direction == DML_CONVOLUTION_DIRECTION_BACKWARD) + { + // TODO: implement me + assert(false); + } + else + { + assert(false); + DMLX_THROW(E_UNEXPECTED); + } + } + + TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + detail::FusedActivationStorage storage; + + DML_CONVOLUTION_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.FilterTensor = filterTensor.AsPtr(); + desc.BiasTensor = bias ? biasTensor.AsPtr() : nullptr; + desc.OutputTensor = outputTensor.AsPtr(); + desc.Mode = mode; + desc.Direction = direction; + desc.DimensionCount = spatialDimensionCount; + desc.Strides = strides.data(); + desc.Dilations = dilations.data(); + desc.StartPadding = startPadding.data(); + desc.EndPadding = endPadding.data(); + desc.OutputPadding = outputPadding.data(); + desc.GroupCount = groupCount; + desc.FusedActivation = detail::GetFusedActivationPtr(fusedActivation, &storage); + + detail::NodeOutput* const inputs[] = { input.Impl(), filter.Impl(), bias ? bias->Impl() : nullptr }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_CONVOLUTION, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + // Helper for setting parameters for the Convolution operator. Sample usage: + // + // auto conv = dml::ConvolutionBuilder(...) + // .StartPadding(...) + // .EndPadding(...) + // .Strides(...) + // .Build(); + // + // Parameters left unspecified will be defaulted with the same values as dml::Convolution(). + class ConvolutionBuilder + { + public: + ConvolutionBuilder(Expression input, Expression filter, Optional bias = NullOpt) + : m_input(input), m_filter(filter), m_bias(bias) + {} + + ConvolutionBuilder& Mode(DML_CONVOLUTION_MODE mode) { m_mode = mode; return *this; } + ConvolutionBuilder& Direction(DML_CONVOLUTION_DIRECTION direction) { m_direction = direction; return *this; } + ConvolutionBuilder& Strides(Span strides) { m_strides.assign(strides.begin(), strides.end()); return *this; } + ConvolutionBuilder& Dilations(Span dilations) { m_dilations.assign(dilations.begin(), dilations.end()); return *this; } + ConvolutionBuilder& StartPadding(Span startPadding) { m_startPadding.assign(startPadding.begin(), startPadding.end()); return *this; } + ConvolutionBuilder& EndPadding(Span endPadding) { m_endPadding.assign(endPadding.begin(), endPadding.end()); return *this; } + ConvolutionBuilder& OutputPadding(Span outputPadding) { m_outputPadding.assign(outputPadding.begin(), outputPadding.end()); return *this; } + ConvolutionBuilder& GroupCount(uint32_t groupCount) { m_groupCount = groupCount; return *this; } + ConvolutionBuilder& FusedActivation(FusedActivation fusedActivation) { m_fusedActivation = fusedActivation; return *this; } + ConvolutionBuilder& OutputSizes(TensorDimensions outputSizes) { m_outputSizes = std::move(outputSizes); return *this; } + + Expression Build() const + { + return Convolution( + m_input, + m_filter, + m_bias, + m_mode, + m_direction, + m_strides, + m_dilations, + m_startPadding, + m_endPadding, + m_outputPadding, + m_groupCount, + m_fusedActivation, + m_outputSizes); + } + + private: + Expression m_input; + Expression m_filter; + Optional m_bias; + DML_CONVOLUTION_MODE m_mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION; + DML_CONVOLUTION_DIRECTION m_direction = DML_CONVOLUTION_DIRECTION_FORWARD; + SmallVector m_strides = {}; + SmallVector m_dilations = {}; + SmallVector m_startPadding = {}; + SmallVector m_endPadding = {}; + SmallVector m_outputPadding = {}; + uint32_t m_groupCount = 1; + dml::FusedActivation m_fusedActivation; + TensorDimensions m_outputSizes = {}; + }; + + // --------------------------------------------------------------------------------------------------------------- + + inline Expression Gemm( + Expression a, + Expression b, + Optional c = NullOpt, + DML_MATRIX_TRANSFORM transA = DML_MATRIX_TRANSFORM_NONE, + DML_MATRIX_TRANSFORM transB = DML_MATRIX_TRANSFORM_NONE, + float alpha = 1.0f, + float beta = 1.0f, + FusedActivation fusedActivation = FusedActivation::None()) + { + assert(detail::HasSameOwner({ a, b })); + assert(!c || detail::HasSameOwner({ a, *c })); + + detail::GraphBuilder* builder = a.Impl()->GetGraphBuilder(); + + TensorDesc aTensor = a.Impl()->GetOutputDesc(); + TensorDesc bTensor = b.Impl()->GetOutputDesc(); + TensorDesc cTensor; + if (c) + { + cTensor = c->Impl()->GetOutputDesc(); + } + + TensorDimensions outputSizes; + outputSizes.push_back(aTensor.sizes[0]); // output[N] = input[N] + outputSizes.push_back(aTensor.sizes[1]); // output[C] = input[C] + outputSizes.push_back(transA == DML_MATRIX_TRANSFORM_NONE ? aTensor.sizes[2] : aTensor.sizes[3]); + outputSizes.push_back(transB == DML_MATRIX_TRANSFORM_NONE ? bTensor.sizes[3] : bTensor.sizes[2]); + + TensorDesc outputTensor(aTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + detail::FusedActivationStorage storage; + + DML_GEMM_OPERATOR_DESC desc = {}; + desc.ATensor = aTensor.AsPtr(); + desc.BTensor = bTensor.AsPtr(); + desc.CTensor = c ? cTensor.AsPtr() : nullptr; + desc.OutputTensor = outputTensor.AsPtr(); + desc.TransA = transA; + desc.TransB = transB; + desc.Alpha = alpha; + desc.Beta = beta; + desc.FusedActivation = detail::GetFusedActivationPtr(fusedActivation, &storage); + + detail::NodeOutput* const inputs[] = { a.Impl(), b.Impl(), c ? c->Impl() : nullptr }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_GEMM, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + // Helper for setting parameters for the Gemm operator. Parameters left unspecified will be defaulted with the + // same values as dml::Gemm(). + class GemmBuilder + { + public: + GemmBuilder(Expression a, Expression b, Optional c = NullOpt) + : m_a(a), m_b(b), m_c(c) + {} + + GemmBuilder& TransA(DML_MATRIX_TRANSFORM transA) { m_transA = transA; return *this; } + GemmBuilder& TransB(DML_MATRIX_TRANSFORM transB) { m_transB = transB; return *this; } + GemmBuilder& Alpha(float alpha) { m_alpha = alpha; return *this; } + GemmBuilder& Beta(float beta) { m_beta = beta; return *this; } + GemmBuilder& FusedActivation(FusedActivation fusedActivation) { m_fusedActivation = fusedActivation; return *this; } + + Expression Build() const + { + return Gemm(m_a, m_b, m_c, m_transA, m_transB, m_alpha, m_beta, m_fusedActivation); + } + + private: + Expression m_a; + Expression m_b; + Optional m_c; + DML_MATRIX_TRANSFORM m_transA = DML_MATRIX_TRANSFORM_NONE; + DML_MATRIX_TRANSFORM m_transB = DML_MATRIX_TRANSFORM_NONE; + float m_alpha = 1.0f; + float m_beta = 1.0f; + dml::FusedActivation m_fusedActivation; + }; + + // --------------------------------------------------------------------------------------------------------------- + + // If `axes` is not specified, by default this reduces the entire tensor to single element. + inline Expression Reduce( + Expression input, + DML_REDUCE_FUNCTION function, + Span axes = {}, + DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UNKNOWN) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + uint32_t dimensionCount = static_cast(inputTensor.sizes.size()); + + SmallVector defaultAxes; + if (axes.empty()) + { + for (uint32_t i = 0; i < dimensionCount; ++i) + { + defaultAxes.push_back(i); + } + axes = defaultAxes; + } + + // Compute the output tensor dimensions + TensorDimensions outputSizes; + for (uint32_t i = 0; i < dimensionCount; ++i) + { + // If the dimension is to be reduced, this dimension in the output tensor has a size of 1, otherwise + // it matches the input tensor. + const bool dimensionIsReduced = std::find(axes.begin(), axes.end(), i) != axes.end(); + if (dimensionIsReduced) + { + outputSizes.push_back(1); + } + else + { + outputSizes.push_back(inputTensor.sizes[i]); + } + } + + // All reductions other than ARGMIN and ARGMAX produce an output with the same type + // as the input. + if (outputDataType == DML_TENSOR_DATA_TYPE_UNKNOWN) + { + if (function == DML_REDUCE_FUNCTION_ARGMIN || function == DML_REDUCE_FUNCTION_ARGMAX) + { + // Default to UINT32 if the output type wasn't specified + outputDataType = DML_TENSOR_DATA_TYPE_UINT32; + } + else + { + outputDataType = inputTensor.dataType; + } + } + + TensorDesc outputTensor(outputDataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_REDUCE_OPERATOR_DESC desc = {}; + desc.Function = function; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.AxisCount = static_cast(axes.size()); + desc.Axes = axes.data(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_REDUCE, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression AveragePooling( + Expression input, + Span strides, + Span windowSizes, + Span startPadding, + Span endPadding, +#if DML_TARGET_VERSION >= 0x6200 + Span dilations, +#endif // DML_TARGET_VERSION >= 0x6200 + bool includePadding, + TensorDimensions outputSizes = {}) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + +#if DML_TARGET_VERSION >= 0x6200 + const uint32_t defaultStridesAndDilations[3] = { 1, 1, 1 }; + + DML_AVERAGE_POOLING1_OPERATOR_DESC averagePoolDesc = {}; + // dilations must be omitted or have the same rank as the spatial dimension count (inputTensor rank - 2) + assert(dilations.empty() || dilations.size() == inputTensor.sizes.size() - 2); + averagePoolDesc.Dilations = dilations.empty() ? defaultStridesAndDilations : dilations.data(); +#else + DML_AVERAGE_POOLING_OPERATOR_DESC averagePoolDesc = {}; +#endif // DML_TARGET_VERSION >= 0x6200 + + assert(strides.size() == windowSizes.size()); + assert(strides.size() == startPadding.size()); + assert(strides.size() == endPadding.size()); + + // Calculate output size, if not explicitly provided + if (outputSizes.empty()) + { + outputSizes.push_back(inputTensor.sizes[0]); // N + outputSizes.push_back(inputTensor.sizes[1]); // C + for (size_t i = 0; i < windowSizes.size(); ++i) + { + uint32_t paddedInputSize = inputTensor.sizes[2 + i] + startPadding[i] + endPadding[i]; + uint32_t outputSize = (paddedInputSize - windowSizes[i]) / strides[i] + 1; + outputSizes.push_back(outputSize); + } + } + + TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + + averagePoolDesc.InputTensor = inputTensor.AsPtr(); + averagePoolDesc.OutputTensor = outputTensor.AsPtr(); + averagePoolDesc.DimensionCount = static_cast(windowSizes.size()); + averagePoolDesc.Strides = strides.data(); + averagePoolDesc.WindowSize = windowSizes.data(); + averagePoolDesc.StartPadding = startPadding.data(); + averagePoolDesc.EndPadding = endPadding.data(); + averagePoolDesc.IncludePadding = includePadding; + + detail::NodeOutput* const inputs[] = { input.Impl() }; +#if DML_TARGET_VERSION >= 0x6200 + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_AVERAGE_POOLING1, &averagePoolDesc, inputs); +#else + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_AVERAGE_POOLING, &averagePoolDesc, inputs); +#endif // DML_TARGET_VERSION >= 0x6200 + + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + // + // TODO: LpPooling + // + + // --------------------------------------------------------------------------------------------------------------- + + struct MaxPoolingOutputs + { + Expression values; + Expression indices; // Only valid if outputIndices = true is supplied to MaxPooling() + }; + + // If not specified, parameters are defaulted to the following values: + // Strides = 1 for each spatial dimension + // StartPadding = 0 for each spatial dimension + // EndPadding = 0 for each spatial dimension + // Dilations = 1 for each spatial dimension + // OutputIndices = false + inline MaxPoolingOutputs MaxPooling( + Expression input, + Span windowSize, + Span strides = {}, + Span startPadding = {}, + Span endPadding = {}, + Span dilations = {}, + bool outputIndices = false, + TensorDimensions outputSizes = {}) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + + // If the spatial dimension count is 2, we'll just use the first two elements by setting + // DimensionCount = 2 in the desc + const uint32_t defaultStridesAndDilations[3] = { 1, 1, 1 }; + const uint32_t defaultPadding[3] = { 0, 0, 0 }; + + assert(windowSize.size() == 2 || windowSize.size() == 3); + assert(strides.empty() || strides.size() == windowSize.size()); + assert(dilations.empty() || dilations.size() == windowSize.size()); + assert(startPadding.empty() || startPadding.size() == windowSize.size()); + assert(endPadding.empty() || endPadding.size() == windowSize.size()); + + strides = strides.empty() ? Span{ defaultStridesAndDilations } : strides; + dilations = dilations.empty() ? Span{ defaultStridesAndDilations } : dilations; + startPadding = startPadding.empty() ? Span{ defaultPadding } : startPadding; + endPadding = endPadding.empty() ? Span{ defaultPadding } : endPadding; + + // Calculate output size, if not explicitly provided + if (outputSizes.empty()) + { + outputSizes.push_back(inputTensor.sizes[0]); // N + outputSizes.push_back(inputTensor.sizes[1]); // C + for (size_t i = 0; i < windowSize.size(); i++) + { + uint32_t paddedInputSize = inputTensor.sizes[2 + i] + startPadding[i] + endPadding[i]; + uint32_t dilatedWindowSize = 1 + (windowSize[i] - 1) * dilations[i]; + uint32_t outputSize = (dilatedWindowSize >= paddedInputSize) ? 1 : (paddedInputSize - dilatedWindowSize) / strides[i] + 1; + outputSizes.push_back(outputSize); + } + } + + TensorDesc outputTensor(inputTensor.dataType, outputSizes, builder->GetTensorPolicy()); + TensorDesc outputIndicesTensor(DML_TENSOR_DATA_TYPE_UINT32, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_MAX_POOLING2_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.OutputIndicesTensor = outputIndices ? outputIndicesTensor.AsPtr() : nullptr; + desc.DimensionCount = static_cast(windowSize.size()); + desc.Strides = strides.data(); + desc.WindowSize = windowSize.data(); + desc.StartPadding = startPadding.data(); + desc.EndPadding = endPadding.data(); + desc.Dilations = dilations.data(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_MAX_POOLING2, &desc, inputs); + + detail::NodeOutput* outputExpr = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + if (outputIndices) + { + detail::NodeOutput* outputIndicesExpr = builder->CreateNodeOutput(node, 1, std::move(outputIndicesTensor)); + return { outputExpr, outputIndicesExpr }; + } + return { outputExpr, Expression() }; + } + + // Helper for setting parameters for the MaxPooling operator. Sample usage: + // + // auto [out, outIndices] = dml::MaxPoolingBuilder(...) + // .StartPadding(...) + // .EndPadding(...) + // .OutputIndices(...) + // .Build(); + // + // Parameters left unspecified will be defaulted with the same values as dml::MaxPooling(). + class MaxPoolingBuilder + { + public: + MaxPoolingBuilder(Expression input, Span windowSize) + : m_input(input), m_windowSize(windowSize.begin(), windowSize.end()) + {} + + MaxPoolingBuilder& Strides(Span strides) { m_strides.assign(strides.begin(), strides.end()); return *this; } + MaxPoolingBuilder& StartPadding(Span startPadding) { m_startPadding.assign(startPadding.begin(), startPadding.end()); return *this; } + MaxPoolingBuilder& EndPadding(Span endPadding) { m_endPadding.assign(endPadding.begin(), endPadding.end()); return *this; } + MaxPoolingBuilder& Dilations(Span dilations) { m_dilations.assign(dilations.begin(), dilations.end()); return *this; } + MaxPoolingBuilder& OutputIndices(bool outputIndices) { m_outputIndices = outputIndices; return *this; } + MaxPoolingBuilder& OutputSizes(TensorDimensions outputSizes) { m_outputSizes = std::move(outputSizes); return *this; } + + MaxPoolingOutputs Build() const + { + return MaxPooling( + m_input, + m_windowSize, + m_strides, + m_startPadding, + m_endPadding, + m_dilations, + m_outputIndices, + m_outputSizes); + } + + private: + Expression m_input; + SmallVector m_windowSize; + SmallVector m_strides = {}; + SmallVector m_startPadding = {}; + SmallVector m_endPadding = {}; + SmallVector m_dilations = {}; + bool m_outputIndices = false; + TensorDimensions m_outputSizes = {}; + }; + + // --------------------------------------------------------------------------------------------------------------- + + // + // TODO: MaxUnpooling + // + + // + // TODO: ROIPooling + // + + inline Expression Slice( + Expression input, + Span inputWindowOffsets, + Span inputWindowSizes, + Span inputWindowStrides) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDimensions outputSizes(inputTensor.sizes); + + assert(inputWindowOffsets.size() == outputSizes.size()); + assert(inputWindowOffsets.size() == inputWindowStrides.size()); + assert(inputWindowOffsets.size() == inputWindowSizes.size()); + + for (size_t i = 0; i < outputSizes.size(); i++) + { + uint32_t minimumInputSize = (inputWindowSizes[i] - 1) / abs(inputWindowStrides[i]) + 1; + outputSizes[i] = minimumInputSize; + } + + TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_SLICE1_OPERATOR_DESC sliceDesc = {}; + sliceDesc.InputTensor = inputTensor.AsPtr(); + sliceDesc.OutputTensor = outputTensor.AsPtr(); + sliceDesc.DimensionCount = static_cast(inputWindowOffsets.size()); + sliceDesc.InputWindowOffsets = inputWindowOffsets.data(); + sliceDesc.InputWindowSizes = inputWindowSizes.data(); + sliceDesc.InputWindowStrides = inputWindowStrides.data(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_SLICE1, &sliceDesc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression Cast(Expression input, DML_TENSOR_DATA_TYPE targetDataType) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(targetDataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_CAST_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_CAST, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline std::vector Split( + Expression input, + uint32_t axis, + Span outputAxisSizes) + { + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + uint32_t axisSizeSum = 0; + + std::vector outputTensors; + outputTensors.reserve(outputAxisSizes.size()); + + std::vector outputDescs; + outputDescs.reserve(outputAxisSizes.size()); + + for (uint32_t outputAxisSize : outputAxisSizes) + { + TensorDimensions outputSizes = inputTensor.sizes; + outputSizes[axis] = outputAxisSize; + + TensorDesc tensorDesc(inputTensor.dataType, outputSizes, builder->GetTensorPolicy()); + outputTensors.push_back(std::move(tensorDesc)); + outputDescs.push_back(*outputTensors.back().AsPtr()); + + axisSizeSum += outputAxisSize; + } + + assert(axisSizeSum == inputTensor.sizes[axis]); + + DML_SPLIT_OPERATOR_DESC desc = {}; + desc.Axis = axis; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensors = outputDescs.data(); + desc.OutputCount = static_cast(outputAxisSizes.size()); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_SPLIT, &desc, inputs); + + std::vector outputs; + outputs.reserve(outputAxisSizes.size()); + + for (uint32_t i = 0; i < outputAxisSizes.size(); ++i) + { + outputs.push_back(builder->CreateNodeOutput(node, i, std::move(outputTensors[i]))); + } + + return outputs; + } + + inline Expression Join( + Span inputs, + uint32_t axis) + { + assert(!inputs.empty()); + + detail::GraphBuilder* builder = inputs[0].Impl()->GetGraphBuilder(); + DML_TENSOR_DATA_TYPE dataType = inputs[0].Impl()->GetOutputDesc().dataType; + + TensorDimensions outputSizes = inputs[0].Impl()->GetOutputDesc().sizes; + outputSizes[axis] = 0; + + std::vector inputTensors; + inputTensors.reserve(inputs.size()); + + std::vector inputDescs; + inputDescs.reserve(inputs.size()); + + std::vector inputNodes; + inputNodes.reserve(inputs.size()); + + for (Expression input : inputs) + { + inputTensors.push_back(input.Impl()->GetOutputDesc()); + TensorDesc& inputTensor = inputTensors.back(); + outputSizes[axis] += inputTensor.sizes[axis]; + inputDescs.push_back(*inputTensor.AsPtr()); + inputNodes.push_back(input.Impl()); + } + + TensorDesc outputTensor(dataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_JOIN_OPERATOR_DESC desc = {}; + desc.Axis = axis; + desc.InputCount = static_cast(inputDescs.size()); + desc.InputTensors = inputDescs.data(); + desc.OutputTensor = outputTensor.AsPtr(); + + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_JOIN, &desc, inputNodes); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression Padding( + Expression input, + DML_PADDING_MODE paddingMode, + float paddingValue, + Span startPadding, + Span endPadding) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDimensions outputSizes = inputTensor.sizes; + + assert(outputSizes.size() == startPadding.size()); + assert(outputSizes.size() == endPadding.size()); + + for (size_t i = 0; i < outputSizes.size(); i++) + { + outputSizes[i] += startPadding[i] + endPadding[i]; + } + + TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_PADDING_OPERATOR_DESC paddingDesc = {}; + paddingDesc.InputTensor = inputTensor.AsPtr(); + paddingDesc.OutputTensor = outputTensor.AsPtr(); + paddingDesc.PaddingMode = paddingMode; + paddingDesc.PaddingValue = paddingValue; + paddingDesc.DimensionCount = static_cast(startPadding.size()); + paddingDesc.StartPadding = startPadding.data(); + paddingDesc.EndPadding = endPadding.data(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_PADDING, &paddingDesc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression ValueScale2D( + Expression input, + float scale, + Span bias) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_VALUE_SCALE_2D_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.Scale = scale; + desc.ChannelCount = static_cast(bias.size()); + desc.Bias = bias.data(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_VALUE_SCALE_2D, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression Upsample2D(Expression input, DML_SIZE_2D scaleSize, DML_INTERPOLATION_MODE interpolationMode) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + assert(inputTensor.sizes.size() == 4 || inputTensor.sizes.size() == 5); + + uint32_t i = 0; + TensorDimensions outputSizes; + outputSizes.push_back(inputTensor.sizes[i++]); // output[N] = input[N] + outputSizes.push_back(inputTensor.sizes[i++]); // output[C] = input[C] + if (inputTensor.sizes.size() == 5) + { + outputSizes.push_back(inputTensor.sizes[i++]); // output[D] = input[D] + } + outputSizes.push_back(inputTensor.sizes[i++] * scaleSize.Height); // output[H] = input[H] * scaleH + outputSizes.push_back(inputTensor.sizes[i++] * scaleSize.Width); // output[W] = input[W] * scaleW + TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_UPSAMPLE_2D_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.ScaleSize = scaleSize; + desc.InterpolationMode = interpolationMode; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_UPSAMPLE_2D, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression Gather( + Expression input, + Expression indices, + uint32_t axis, + uint32_t indexDimensions) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc indicesTensor = indices.Impl()->GetOutputDesc(); + + uint32_t dimensionCount = static_cast(inputTensor.sizes.size()); + assert(indicesTensor.sizes.size() == dimensionCount); + assert(axis < dimensionCount); + assert(indexDimensions <= dimensionCount); + + TensorDimensions outputSizes(dimensionCount, 1); + + // All dimensions after the axis should be the same as the input + int outputDim = static_cast(dimensionCount) - 1; + for (; static_cast(outputDim) > axis; --outputDim) + { + outputSizes[outputDim] = inputTensor.sizes[outputDim]; + } + + // All dimensions within the range [axis - indexDimensions, axis] should be the same as the indices + int indexDim = static_cast(dimensionCount) - 1; + for (; outputDim > static_cast(axis) - static_cast(indexDimensions); --outputDim, --indexDim) + { + outputSizes[outputDim] = indicesTensor.sizes[indexDim]; + } + + // All dimensions before (axis - indexDimensions) should be the same as the input + int inputDim = axis - 1; + for (; outputDim >= 0 && inputDim >= 0; --outputDim, --inputDim) + { + outputSizes[outputDim] = inputTensor.sizes[inputDim]; + } + + TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_GATHER_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.IndicesTensor = indicesTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.Axis = axis; + desc.IndexDimensions = indexDimensions; + + detail::NodeOutput* const inputs[] = { input.Impl(), indices.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_GATHER, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression GatherElements( + Expression input, + Expression indices, + uint32_t axis) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc indicesTensor = indices.Impl()->GetOutputDesc(); + + TensorDesc outputTensor(inputTensor.dataType, indicesTensor.sizes, builder->GetTensorPolicy()); + + DML_GATHER_ELEMENTS_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.IndicesTensor = indicesTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.Axis = axis; + + detail::NodeOutput* const inputs[] = { input.Impl(), indices.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_GATHER_ELEMENTS, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression GatherND( + Expression input, + Expression indices, + uint32_t inputDimensionCount, + uint32_t indicesDimensionCount, + uint32_t batchDimensionCount) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc indicesTensor = indices.Impl()->GetOutputDesc(); + + assert(inputDimensionCount >= 1u && inputDimensionCount <= inputTensor.sizes.size()); + assert(indicesDimensionCount >= 1u && indicesDimensionCount <= indicesTensor.sizes.size()); + assert(batchDimensionCount < inputDimensionCount); + assert(batchDimensionCount < indicesDimensionCount); + + uint32_t numberOfCoordinatesPerIndex = indicesTensor.sizes.back(); + assert(numberOfCoordinatesPerIndex >= 1u && numberOfCoordinatesPerIndex <= inputDimensionCount - batchDimensionCount); + + uint32_t numberOfOutputDimensionsFromInput = inputDimensionCount - batchDimensionCount - numberOfCoordinatesPerIndex; + uint32_t outputPaddingAmount = static_cast(inputTensor.sizes.size()) - (indicesDimensionCount + numberOfOutputDimensionsFromInput - 1); + + TensorDimensions outputSizes(outputPaddingAmount, 1); + outputSizes.insert(outputSizes.end(), indicesTensor.sizes.end() - indicesDimensionCount, indicesTensor.sizes.end() - 1); + outputSizes.insert(outputSizes.end(), inputTensor.sizes.end() - numberOfOutputDimensionsFromInput, inputTensor.sizes.end()); + + TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_GATHER_ND1_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.IndicesTensor = indicesTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.InputDimensionCount = inputDimensionCount; + desc.IndicesDimensionCount = indicesDimensionCount; + desc.BatchDimensionCount = batchDimensionCount; + + detail::NodeOutput* const inputs[] = { input.Impl(), indices.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_GATHER_ND1, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression ScatterElements( + Expression input, + Expression indices, + Expression updates, + uint32_t axis) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc indicesTensor = indices.Impl()->GetOutputDesc(); + TensorDesc updatesTensor = updates.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_SCATTER_ELEMENTS_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.IndicesTensor = indicesTensor.AsPtr(); + desc.UpdatesTensor = updatesTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.Axis = axis; + + detail::NodeOutput* const inputs[] = { input.Impl(), indices.Impl(), updates.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_SCATTER_ELEMENTS, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression ScatterND( + Expression input, + Expression indices, + Expression updates, + uint32_t inputDimensionCount, + uint32_t indicesDimensionCount) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc indicesTensor = indices.Impl()->GetOutputDesc(); + TensorDesc updatesTensor = updates.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_SCATTER_ND_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.IndicesTensor = indicesTensor.AsPtr(); + desc.UpdatesTensor = updatesTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.InputDimensionCount = inputDimensionCount; + desc.IndicesDimensionCount = indicesDimensionCount; + + detail::NodeOutput* const inputs[] = { input.Impl(), indices.Impl(), updates.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_SCATTER_ND, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression SpaceToDepth( + Expression input, + uint32_t blockSize, + DML_DEPTH_SPACE_ORDER order = DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + + assert(inputTensor.sizes.size() == 4); + + dml::TensorDesc::Dimensions outputSizes = { + inputTensor.sizes[0], + inputTensor.sizes[1] * blockSize * blockSize, + inputTensor.sizes[2] / blockSize, + inputTensor.sizes[3] / blockSize + }; + + TensorDesc outputTensor(inputTensor.dataType, outputSizes, builder->GetTensorPolicy()); + + DML_SPACE_TO_DEPTH1_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.BlockSize = blockSize; + desc.Order = order; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_SPACE_TO_DEPTH1, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression DepthToSpace( + Expression input, + uint32_t blockSize, + DML_DEPTH_SPACE_ORDER order = DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + + assert(inputTensor.sizes.size() == 4); + + dml::TensorDesc::Dimensions outputSizes = { + inputTensor.sizes[0], + inputTensor.sizes[1] / (blockSize * blockSize), + inputTensor.sizes[2] * blockSize, + inputTensor.sizes[3] * blockSize + }; + + TensorDesc outputTensor(inputTensor.dataType, outputSizes, builder->GetTensorPolicy()); + + DML_DEPTH_TO_SPACE1_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.BlockSize = blockSize; + desc.Order = order; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_DEPTH_TO_SPACE1, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression Tile(Expression input, Span repeats) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDimensions outputSizes = input.GetOutputDesc().sizes; + + assert(repeats.size() == outputSizes.size()); + + for (size_t i = 0; i < repeats.size(); ++i) + { + outputSizes[i] *= repeats[i]; + } + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_TILE_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.RepeatsCount = static_cast(repeats.size()); + desc.Repeats = repeats.data(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_TILE, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + struct TopKOutputs + { + Expression value; + Expression index; + }; + + inline TopKOutputs TopK(Expression input, uint32_t axis, uint32_t k, DML_AXIS_DIRECTION axisDirection) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + + TensorDimensions outputSizes = inputTensor.sizes; + outputSizes.back() = k; + + TensorDesc outputValueTensor(inputTensor.dataType, outputSizes, builder->GetTensorPolicy()); + TensorDesc outputIndexTensor(DML_TENSOR_DATA_TYPE_UINT32, outputSizes, builder->GetTensorPolicy()); + + DML_TOP_K1_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputValueTensor = outputValueTensor.AsPtr(); + desc.OutputIndexTensor = outputIndexTensor.AsPtr(); + desc.Axis = axis; + desc.K = k; + desc.AxisDirection = axisDirection; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_TOP_K1, &desc, inputs); + detail::NodeOutput* outputValue = builder->CreateNodeOutput(node, 0, std::move(outputValueTensor)); + detail::NodeOutput* outputIndex = builder->CreateNodeOutput(node, 1, std::move(outputIndexTensor)); + + return { outputValue, outputIndex }; + } + + inline Expression BatchNormalization( + Expression input, + Expression mean, + Expression variance, + Expression scale, + Expression bias, + bool spatial, + float epsilon, + FusedActivation fusedActivation = FusedActivation::None()) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc meanTensor = mean.Impl()->GetOutputDesc(); + TensorDesc varianceTensor = variance.Impl()->GetOutputDesc(); + TensorDesc scaleTensor = scale.Impl()->GetOutputDesc(); + TensorDesc biasTensor = bias.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + + detail::FusedActivationStorage storage; + + DML_BATCH_NORMALIZATION_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.MeanTensor = meanTensor.AsPtr(); + desc.VarianceTensor = varianceTensor.AsPtr(); + desc.ScaleTensor = scaleTensor.AsPtr(); + desc.BiasTensor = biasTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.Spatial = spatial; + desc.Epsilon = epsilon; + desc.FusedActivation = detail::GetFusedActivationPtr(fusedActivation, &storage); + + detail::NodeOutput* const inputs[] = { input.Impl(), mean.Impl(), variance.Impl(), scale.Impl(), bias.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + +#if DML_TARGET_VERSION >= 0x3100 + + struct BatchNormalizationGradOutputs + { + Expression gradient; + Expression scaleGradient; + Expression biasGradient; + }; + + inline BatchNormalizationGradOutputs BatchNormalizationGrad( + Expression input, + Expression inputGradient, + Expression mean, + Expression variance, + Expression scale, + float epsilon) + { + dml::detail::GraphBuilder* builder = mean.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc inputGradientTensor = inputGradient.Impl()->GetOutputDesc(); + TensorDesc meanTensor = mean.Impl()->GetOutputDesc(); + TensorDesc varianceTensor = variance.Impl()->GetOutputDesc(); + TensorDesc scaleTensor = scale.Impl()->GetOutputDesc(); + TensorDesc outputGradientTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + TensorDesc outputScaleGradientTensor(meanTensor.dataType, meanTensor.sizes, builder->GetTensorPolicy()); + TensorDesc outputBiasGradientTensor(meanTensor.dataType, meanTensor.sizes, builder->GetTensorPolicy()); + + DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.InputGradientTensor = inputGradientTensor.AsPtr(); + desc.MeanTensor = meanTensor.AsPtr(); + desc.VarianceTensor = varianceTensor.AsPtr(); + desc.ScaleTensor = scaleTensor.AsPtr(); + desc.Epsilon = epsilon; + + desc.OutputGradientTensor = outputGradientTensor.AsPtr(); + desc.OutputScaleGradientTensor = outputScaleGradientTensor.AsPtr(); + desc.OutputBiasGradientTensor = outputBiasGradientTensor.AsPtr(); + + dml::detail::NodeOutput* const inputs[] = { input.Impl(), inputGradient.Impl(), mean.Impl(), variance.Impl(), scale.Impl() }; + dml::detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION_GRAD, &desc, inputs); + + BatchNormalizationGradOutputs outputValues; + outputValues.gradient = builder->CreateNodeOutput(node, 0, *desc.OutputGradientTensor); + outputValues.scaleGradient = builder->CreateNodeOutput(node, 1, *desc.OutputScaleGradientTensor); + outputValues.biasGradient = builder->CreateNodeOutput(node, 2, *desc.OutputBiasGradientTensor); + + return outputValues; + } + +#endif // DML_TARGET_VERSION >= 0x3100 + +#if DML_TARGET_VERSION >= 0x4100 + struct BatchNormalizationTrainingOutputs + { + Expression output; + Expression mean; + Expression variance; + }; + + inline BatchNormalizationTrainingOutputs BatchNormalizationTraining( + Expression input, + Expression scale, + Expression bias, + Optional fusedAdd, + float epsilon, + FusedActivation fusedActivation = FusedActivation::None()) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc scaleTensor = scale.Impl()->GetOutputDesc(); + TensorDesc biasTensor = bias.Impl()->GetOutputDesc(); + + TensorDesc fusedAddTensor; + if (fusedAdd) + { + fusedAddTensor = fusedAdd->Impl()->GetOutputDesc(); + } + + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + TensorDesc outputMeanTensor(inputTensor.dataType, scaleTensor.sizes, builder->GetTensorPolicy()); + TensorDesc outputVarianceTensor(inputTensor.dataType, scaleTensor.sizes, builder->GetTensorPolicy()); + + detail::FusedActivationStorage storage; + + DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.ScaleTensor = scaleTensor.AsPtr(); + desc.BiasTensor = biasTensor.AsPtr(); + desc.FusedAddTensor = fusedAdd.has_value() ? fusedAddTensor.AsPtr() : nullptr; + desc.OutputTensor = outputTensor.AsPtr(); + desc.OutputMeanTensor = outputMeanTensor.AsPtr(); + desc.OutputVarianceTensor = outputVarianceTensor.AsPtr(); + desc.Epsilon = epsilon; + desc.FusedActivation = detail::GetFusedActivationPtr(fusedActivation, &storage); + + detail::NodeOutput* const inputs[] = { input.Impl(), scale.Impl(), bias.Impl(), fusedAdd ? fusedAdd->Impl() : nullptr }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION_TRAINING, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + detail::NodeOutput* outputMean = builder->CreateNodeOutput(node, 1, std::move(outputMeanTensor)); + detail::NodeOutput* outputVariance = builder->CreateNodeOutput(node, 2, std::move(outputVarianceTensor)); + + return {output, outputMean, outputVariance}; + } + + inline BatchNormalizationGradOutputs BatchNormalizationTrainingGrad( + Expression input, + Expression inputGradient, + Expression mean, + Expression variance, + Expression scale, + float epsilon) + { + dml::detail::GraphBuilder* builder = mean.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc inputGradientTensor = inputGradient.Impl()->GetOutputDesc(); + TensorDesc meanTensor = mean.Impl()->GetOutputDesc(); + TensorDesc varianceTensor = variance.Impl()->GetOutputDesc(); + TensorDesc scaleTensor = scale.Impl()->GetOutputDesc(); + TensorDesc outputGradientTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + TensorDesc outputScaleGradientTensor(meanTensor.dataType, meanTensor.sizes, builder->GetTensorPolicy()); + TensorDesc outputBiasGradientTensor(meanTensor.dataType, meanTensor.sizes, builder->GetTensorPolicy()); + + DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.InputGradientTensor = inputGradientTensor.AsPtr(); + desc.MeanTensor = meanTensor.AsPtr(); + desc.VarianceTensor = varianceTensor.AsPtr(); + desc.ScaleTensor = scaleTensor.AsPtr(); + desc.Epsilon = epsilon; + + desc.OutputGradientTensor = outputGradientTensor.AsPtr(); + desc.OutputScaleGradientTensor = outputScaleGradientTensor.AsPtr(); + desc.OutputBiasGradientTensor = outputBiasGradientTensor.AsPtr(); + + dml::detail::NodeOutput* const inputs[] = { input.Impl(), inputGradient.Impl(), mean.Impl(), variance.Impl(), scale.Impl() }; + dml::detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD, &desc, inputs); + + BatchNormalizationGradOutputs outputValues; + outputValues.gradient = builder->CreateNodeOutput(node, 0, *desc.OutputGradientTensor); + outputValues.scaleGradient = builder->CreateNodeOutput(node, 1, *desc.OutputScaleGradientTensor); + outputValues.biasGradient = builder->CreateNodeOutput(node, 2, *desc.OutputBiasGradientTensor); + + return outputValues; + } +#endif // DML_TARGET_VERSION >= 0x4100 + + inline Expression MeanVarianceNormalization( + Expression input, + Optional scale, + Optional bias, + Span axes, + bool normalizeVariance, + float epsilon, + FusedActivation fusedActivation = FusedActivation::None()) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + TensorDesc scaleTensor; + TensorDesc biasTensor; + + if (scale) + { + scaleTensor = scale->Impl()->GetOutputDesc(); + } + if (bias) + { + biasTensor = bias->Impl()->GetOutputDesc(); + } + + detail::FusedActivationStorage storage; + + DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.ScaleTensor = scale ? scaleTensor.AsPtr() : nullptr; + desc.BiasTensor = bias ? biasTensor.AsPtr() : nullptr; + desc.OutputTensor = outputTensor.AsPtr(); + desc.AxisCount = static_cast(axes.size()); + desc.Axes = axes.data(); + desc.NormalizeVariance = normalizeVariance; + desc.Epsilon = epsilon; + desc.FusedActivation = detail::GetFusedActivationPtr(fusedActivation, &storage); + + detail::NodeOutput* const inputs[] = + { + input.Impl(), + scale ? scale->Impl() : nullptr, + bias ? bias->Impl() : nullptr + }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression LocalResponseNormalization( + Expression input, + bool crossChannel, + uint32_t localSize, + float alpha, + float beta, + float bias) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.CrossChannel = crossChannel; + desc.LocalSize = localSize; + desc.Alpha = alpha; + desc.Beta = beta; + desc.Bias = bias; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + // + // TODO: LpNormalization + // + + // + // TODO: RNN + // + + // + // TODO: LSTM + // + + enum class GRUOutputOptions + { + Both, + Sequence, + Single, + }; + + struct GRUOutputs + { + Expression sequence; + Expression single; + }; + + inline GRUOutputs GRU( + Expression input, + Expression weight, + Expression recurrence, + Optional bias, + Optional hiddenInit, + Optional sequenceLengths, + Span activationDescs, + DML_RECURRENT_NETWORK_DIRECTION direction, + bool linearBeforeReset, + GRUOutputOptions outputOptions) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc weightTensor = weight.Impl()->GetOutputDesc(); + TensorDesc recurrenceTensor = recurrence.Impl()->GetOutputDesc(); + TensorDesc biasTensor; + TensorDesc hiddenInitTensor; + TensorDesc sequenceLengthsTensor; + TensorDesc outputSequenceTensor; + TensorDesc outputSingleTensor; + if (bias) + { + biasTensor = bias->Impl()->GetOutputDesc(); + } + if (hiddenInit) + { + hiddenInitTensor = hiddenInit->Impl()->GetOutputDesc(); + } + if (sequenceLengths) + { + sequenceLengthsTensor = sequenceLengths->Impl()->GetOutputDesc(); + } + + TensorDesc::Dimensions outputSequenceSizes(4); + TensorDesc::Dimensions outputSingleSizes(4); + uint32_t directionCount = (direction == DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL) ? 2 : 1; + if (outputOptions == GRUOutputOptions::Sequence || outputOptions == GRUOutputOptions::Both) + { + outputSequenceSizes[0] = inputTensor.sizes[1]; // SequenceLength + outputSequenceSizes[1] = directionCount; + outputSequenceSizes[2] = inputTensor.sizes[2]; // BatchSize + outputSequenceSizes[3] = recurrenceTensor.sizes[3]; // HiddenSize + outputSequenceTensor = TensorDesc(inputTensor.dataType, outputSequenceSizes, builder->GetTensorPolicy()); + } + if (outputOptions == GRUOutputOptions::Single || outputOptions == GRUOutputOptions::Both) + { + outputSingleSizes[0] = 1; + outputSingleSizes[1] = directionCount; + outputSingleSizes[2] = inputTensor.sizes[2]; // BatchSize + outputSingleSizes[3] = recurrenceTensor.sizes[3]; // HiddenSize + outputSingleTensor = TensorDesc(inputTensor.dataType, outputSingleSizes, builder->GetTensorPolicy()); + } + + uint32_t activationCount = static_cast(activationDescs.size()); + if (activationCount > 4) + { + DMLX_THROW(E_INVALIDARG); + } + + detail::FusedActivationStorage storage[4]; + DML_OPERATOR_DESC activationDescArray[4]; + for (uint32_t i = 0; i < activationCount; ++i) + { + activationDescArray[i] = *detail::GetFusedActivationPtr(activationDescs[i], &storage[i]); + } + + DML_GRU_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.WeightTensor = weightTensor.AsPtr(); + desc.RecurrenceTensor = recurrenceTensor.AsPtr(); + desc.BiasTensor = bias ? biasTensor.AsPtr() : nullptr; + desc.HiddenInitTensor = hiddenInit ? hiddenInitTensor.AsPtr() : nullptr; + desc.SequenceLengthsTensor = sequenceLengths ? sequenceLengthsTensor.AsPtr() : nullptr; + desc.OutputSequenceTensor = outputSequenceTensor.sizes.empty() ? nullptr : outputSequenceTensor.AsPtr(); + desc.OutputSingleTensor = outputSingleTensor.sizes.empty() ? nullptr : outputSingleTensor.AsPtr(); + desc.ActivationDescCount = activationCount; + desc.ActivationDescs = activationDescArray; + desc.Direction = direction; + desc.LinearBeforeReset = linearBeforeReset; + + detail::NodeOutput* const inputs[] = + { + input.Impl(), + weight.Impl(), + recurrence.Impl(), + bias ? bias->Impl() : nullptr, + hiddenInit ? hiddenInit->Impl() : nullptr, + sequenceLengths ? sequenceLengths->Impl() : nullptr + }; + + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_GRU, &desc, inputs); + + detail::NodeOutput* outputSequenceExpr = nullptr; + detail::NodeOutput* outputSingleExpr = nullptr; + if (outputOptions == GRUOutputOptions::Sequence || outputOptions == GRUOutputOptions::Both) + { + outputSequenceExpr = builder->CreateNodeOutput(node, 0, std::move(outputSequenceTensor)); + } + if (outputOptions == GRUOutputOptions::Single || outputOptions == GRUOutputOptions::Both) + { + outputSingleExpr = builder->CreateNodeOutput(node, 1, std::move(outputSingleTensor)); + } + return { outputSequenceExpr, outputSingleExpr }; + } + + // + // TODO: DiagonalMatrix + // + + inline Expression OneHot( + Expression indices, + Expression values, + uint32_t outputLength, + uint32_t axis) + { + detail::GraphBuilder* builder = indices.Impl()->GetGraphBuilder(); + TensorDesc indicesTensor = indices.Impl()->GetOutputDesc(); + TensorDesc valuesTensor = values.Impl()->GetOutputDesc(); + + assert(axis < static_cast(indicesTensor.sizes.size())); + + // The output and indices sizes must all match except for the active axis, which is supplied as outputLength. + TensorDimensions outputSizes = indicesTensor.sizes; + outputSizes[axis] = outputLength; + + TensorDesc outputTensor(valuesTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_ONE_HOT_OPERATOR_DESC desc = {}; + desc.IndicesTensor = indicesTensor.AsPtr(); + desc.ValuesTensor = valuesTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.Axis = axis; + + detail::NodeOutput* const inputs[] = { indices.Impl(), values.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ONE_HOT, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + // If not specified, parameters are defaulted to the following values: + // Scales = computed by dividing the output sizes by the input sizes + // InputPixelOffsets = 0.5f for each dimension + // OutputPixelOffsets = -0.5f for each dimension + // Antialiased = false + inline Expression Resample( + Expression input, + TensorDimensions outputSizes, + DML_INTERPOLATION_MODE mode, +#if DML_TARGET_VERSION >= 0x5100 + DML_AXIS_DIRECTION roundingDirection = DML_AXIS_DIRECTION_INCREASING, +#endif // DML_TARGET_VERSION >= 0x5100 + Span scales = {}, + Span inputPixelOffsets = {}, + Span outputPixelOffsets = {} +#if DML_TARGET_VERSION >= 0x6300 + , bool antialiased = false +#endif + ) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + uint32_t dimensionCount = static_cast(inputTensor.sizes.size()); + assert(outputSizes.size() == dimensionCount); + assert(scales.empty() || scales.size() == dimensionCount); + assert(inputPixelOffsets.empty() || inputPixelOffsets.size() == dimensionCount); + assert(outputPixelOffsets.empty() || outputPixelOffsets.size() == dimensionCount); + + SmallVector defaultScales; + if (scales.empty()) + { + for (uint32_t i = 0; i < dimensionCount; ++i) + { + defaultScales.push_back(static_cast(outputSizes[i]) / static_cast(inputTensor.sizes[i])); + } + scales = defaultScales; + } + + SmallVector defaultInputPixelOffsets; + if (inputPixelOffsets.empty()) + { + defaultInputPixelOffsets.assign(dimensionCount, 0.5f); + inputPixelOffsets = defaultInputPixelOffsets; + } + + SmallVector defaultOutputPixelOffsets; + if (outputPixelOffsets.empty()) + { + defaultOutputPixelOffsets.assign(dimensionCount, -0.5f); + outputPixelOffsets = defaultOutputPixelOffsets; + } + + TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + +#if DML_TARGET_VERSION >= 0x6300 + DML_RESAMPLE3_OPERATOR_DESC desc = {}; + desc.RoundingDirection = roundingDirection; + desc.Antialiased = antialiased; +#elif DML_TARGET_VERSION >= 0x5100 + DML_RESAMPLE2_OPERATOR_DESC desc = {}; + desc.RoundingDirection = roundingDirection; +#else + DML_RESAMPLE1_OPERATOR_DESC desc = {}; +#endif // DML_TARGET_VERSION >= 0x5100 + + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.InterpolationMode = mode; + desc.DimensionCount = dimensionCount; + desc.Scales = scales.data(); + desc.InputPixelOffsets = inputPixelOffsets.data(); + desc.OutputPixelOffsets = outputPixelOffsets.data(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + +#if DML_TARGET_VERSION >= 0x6300 + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RESAMPLE3, &desc, inputs); +#elif DML_TARGET_VERSION >= 0x5100 + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RESAMPLE2, &desc, inputs); +#else + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RESAMPLE1, &desc, inputs); +#endif // DML_TARGET_VERSION >= 0x5100 + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression FillValueConstant( + Graph& graph, + TensorDimensions outputSizes, + DML_TENSOR_DATA_TYPE valueDataType, + DML_SCALAR_UNION value) + { + detail::GraphBuilder* builder = graph.Impl(); + TensorDesc outputTensor(valueDataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_FILL_VALUE_CONSTANT_OPERATOR_DESC desc = {}; + desc.OutputTensor = outputTensor.AsPtr(); + desc.ValueDataType = valueDataType; + desc.Value = value; + + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_FILL_VALUE_CONSTANT, &desc, {}); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression FillValueSequence( + Graph& graph, + TensorDimensions outputSizes, + DML_TENSOR_DATA_TYPE valueDataType, + DML_SCALAR_UNION valueStart, + DML_SCALAR_UNION valueDelta) + { + detail::GraphBuilder* builder = graph.Impl(); + TensorDesc outputTensor(valueDataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC desc = {}; + desc.OutputTensor = outputTensor.AsPtr(); + desc.ValueDataType = valueDataType; + desc.ValueStart = valueStart; + desc.ValueDelta = valueDelta; + + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_FILL_VALUE_SEQUENCE, &desc, {}); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression CumulativeSummation( + Expression input, + uint32_t axis, + DML_AXIS_DIRECTION axisDirection, + bool hasExclusiveSum) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_CUMULATIVE_SUMMATION_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.Axis = axis; + desc.AxisDirection = axisDirection; + desc.HasExclusiveSum = hasExclusiveSum; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_CUMULATIVE_SUMMATION, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + +#if DML_TARGET_VERSION >= 0x3100 + + inline Expression CumulativeProduct( + Expression input, + uint32_t axis, + DML_AXIS_DIRECTION axisDirection, + bool hasExclusiveProduct) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_CUMULATIVE_PRODUCT_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.Axis = axis; + desc.AxisDirection = axisDirection; + desc.HasExclusiveProduct = hasExclusiveProduct; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_CUMULATIVE_PRODUCT, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + +#endif // DML_TARGET_VERSION >= 0x3100 + + inline Expression ReverseSubsequences( + Expression input, + Expression sequenceLengths, + uint32_t axis) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc sequenceLengthsTensor = sequenceLengths.Impl()->GetOutputDesc(); + TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); + + DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC reverseDesc = {}; + reverseDesc.InputTensor = inputTensor.AsPtr(); + reverseDesc.SequenceLengthsTensor = sequenceLengthsTensor.AsPtr(); + reverseDesc.OutputTensor = outputTensor.AsPtr(); + reverseDesc.Axis = axis; + + detail::NodeOutput* const inputs[] = { input.Impl(), sequenceLengths.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_REVERSE_SUBSEQUENCES, &reverseDesc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + // + // TODO: MatrixMultiplyInteger + // + + // + // TODO: QuantizedLinearMatrixMultiply + // + + inline Expression ConvolutionInteger( + Expression input, + Optional inputZeroPoint, + Expression filter, + Optional filterZeroPoint, + Span strides = {}, + Span dilations = {}, + Span startPadding = {}, + Span endPadding = {}, + uint32_t groupCount = 1, + TensorDimensions outputSizes = {}) + { + assert(detail::HasSameOwner({ input, filter })); + assert(!inputZeroPoint || detail::HasSameOwner({ input, *inputZeroPoint })); + assert(!filterZeroPoint || detail::HasSameOwner({ filter, *filterZeroPoint })); + + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc filterTensor = filter.Impl()->GetOutputDesc(); + + TensorDesc inputZeroPointTensor; + if (inputZeroPoint) { inputZeroPointTensor = inputZeroPoint->Impl()->GetOutputDesc(); } + + TensorDesc filterZeroPointTensor; + if (filterZeroPoint) { filterZeroPointTensor = filterZeroPoint->Impl()->GetOutputDesc(); } + + uint32_t dimensionCount = static_cast(inputTensor.sizes.size()); + + // todo: support 1d convolution? + assert(dimensionCount == 4 || dimensionCount == 5); + uint32_t spatialDimensionCount = dimensionCount - 2; + + // If the spatial dimension count is 2, we'll just use the first two elements by setting + // DimensionCount = 2 in the desc + const uint32_t defaultStridesAndDilations[3] = { 1, 1, 1 }; + const uint32_t defaultPadding[3] = { 0, 0, 0 }; + + assert(strides.empty() || strides.size() == spatialDimensionCount); + assert(dilations.empty() || dilations.size() == spatialDimensionCount); + assert(startPadding.empty() || startPadding.size() == spatialDimensionCount); + assert(endPadding.empty() || endPadding.size() == spatialDimensionCount); + assert(outputSizes.empty() || outputSizes.size() == inputTensor.sizes.size()); + + strides = strides.empty() ? Span{ defaultStridesAndDilations } : strides; + dilations = dilations.empty() ? Span{ defaultStridesAndDilations } : dilations; + startPadding = startPadding.empty() ? Span{ defaultPadding } : startPadding; + endPadding = endPadding.empty() ? Span{ defaultPadding } : endPadding; + + if (outputSizes.empty()) + { + outputSizes.push_back(inputTensor.sizes[0]); // output[N] = input[N] + outputSizes.push_back(filterTensor.sizes[0]); // output[C] = filter[N] + + for (uint32_t dim = 0; dim < spatialDimensionCount; ++dim) + { + uint32_t inputSize = inputTensor.sizes[dim + 2]; + uint32_t paddedSize = inputSize + startPadding[dim] + endPadding[dim]; + + uint32_t windowSize = filterTensor.sizes[dim + 2]; + uint32_t kernelSize = 1 + (windowSize - 1) * dilations[dim]; + + assert(kernelSize <= paddedSize); + assert(strides[dim] != 0); + + outputSizes.push_back(1 + (paddedSize - kernelSize) / strides[dim]); + } + } + + TensorDesc outputTensor(DML_TENSOR_DATA_TYPE_INT32, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_CONVOLUTION_INTEGER_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.FilterTensor = filterTensor.AsPtr(); + desc.InputZeroPointTensor = inputZeroPoint ? inputZeroPointTensor.AsPtr() : nullptr; + desc.FilterZeroPointTensor = filterZeroPoint ? filterZeroPointTensor.AsPtr() : nullptr; + desc.OutputTensor = outputTensor.AsPtr(); + desc.DimensionCount = spatialDimensionCount; + desc.Strides = strides.data(); + desc.Dilations = dilations.data(); + desc.StartPadding = startPadding.data(); + desc.EndPadding = endPadding.data(); + desc.GroupCount = groupCount; + + detail::NodeOutput* const inputs[] = { + input.Impl(), + inputZeroPoint ? inputZeroPoint->Impl() : nullptr, + filter.Impl(), + filterZeroPoint ? filterZeroPoint->Impl() : nullptr + }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_CONVOLUTION_INTEGER, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + + + inline Expression QuantizedLinearConvolution( + Expression input, + Expression inputScale, + Optional inputZeroPoint, + Expression filter, + Expression filterScale, + Optional filterZeroPoint, + Optional bias, + Expression outputScale, + Optional outputZeroPoint, + DML_TENSOR_DATA_TYPE outputDataType, // INT8 or UINT8, must match outputZeroPoint dtype if present + Span strides = {}, + Span dilations = {}, + Span startPadding = {}, + Span endPadding = {}, + uint32_t groupCount = 1, + TensorDimensions outputSizes = {}) + { + assert(detail::HasSameOwner({input, inputScale, filter, filterScale, outputScale})); + assert(!inputZeroPoint || detail::HasSameOwner({ input, *inputZeroPoint})); + assert(!filterZeroPoint || detail::HasSameOwner({ input, *filterZeroPoint})); + assert(!bias || detail::HasSameOwner({ input, *bias})); + assert(!outputZeroPoint || detail::HasSameOwner({ input, *outputZeroPoint})); + + if (outputZeroPoint) { + assert(outputZeroPoint->GetOutputDesc().dataType == outputDataType); + } + + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + const auto getOptional = [](Optional& e) { + if (e) return e->Impl()->GetOutputDesc(); + return TensorDesc{}; + }; + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc inputScaleTensor = inputScale.Impl()->GetOutputDesc(); + TensorDesc inputZeroPointTensor = getOptional(inputZeroPoint); + TensorDesc filterTensor = filter.Impl()->GetOutputDesc(); + TensorDesc filterScaleTensor = filterScale.Impl()->GetOutputDesc(); + TensorDesc filterZeroPointTensor = getOptional(filterZeroPoint); + TensorDesc biasTensor = getOptional(bias); + TensorDesc outputScaleTensor = outputScale.Impl()->GetOutputDesc(); + TensorDesc outputZeroPointTensor = getOptional(outputZeroPoint); + + uint32_t dimensionCount = static_cast(inputTensor.sizes.size()); + + // todo: suppord 1d convolution? + assert(dimensionCount == 4 || dimensionCount == 5); + const uint32_t spatialDimensionCount = dimensionCount - 2; + + // If the spatial dimension count is 2, we'll just use the first two elements by setting + // DimensionCount = 2 in the desc + const uint32_t defaultStridesAndDilations[3] = { 1, 1, 1 }; + const uint32_t defaultPadding[3] = { 0, 0, 0 }; + + assert(strides.empty() || strides.size() == spatialDimensionCount); + assert(dilations.empty() || dilations.size() == spatialDimensionCount); + assert(startPadding.empty() || startPadding.size() == spatialDimensionCount); + assert(endPadding.empty() || endPadding.size() == spatialDimensionCount); + assert(outputSizes.empty() || outputSizes.size() == inputTensor.sizes.size()); + + strides = strides.empty() ? Span{ defaultStridesAndDilations } : strides; + dilations = dilations.empty() ? Span{ defaultStridesAndDilations } : dilations; + startPadding = startPadding.empty() ? Span{ defaultPadding } : startPadding; + endPadding = endPadding.empty() ? Span{ defaultPadding } : endPadding; + + if (outputSizes.empty()) + { + outputSizes.push_back(inputTensor.sizes[0]); // output[N] = input[N] + outputSizes.push_back(filterTensor.sizes[0]); // output[C] = filter[N] + + for (uint32_t dim = 0; dim < spatialDimensionCount; ++dim) + { + uint32_t inputSize = inputTensor.sizes[dim + 2]; + uint32_t paddedSize = inputSize + startPadding[dim] + endPadding[dim]; + + uint32_t windowSize = filterTensor.sizes[dim + 2]; + uint32_t kernelSize = 1 + (windowSize - 1) * dilations[dim]; + + assert(kernelSize <= paddedSize); + assert(strides[dim] != 0); + + outputSizes.push_back(1 + (paddedSize - kernelSize) / strides[dim]); + } + } + + + TensorDesc outputTensor(outputDataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.InputScaleTensor = inputScaleTensor.AsPtr(); + desc.InputZeroPointTensor = inputZeroPoint ? inputZeroPointTensor.AsPtr() : nullptr; + desc.FilterTensor = filterTensor.AsPtr(); + desc.FilterScaleTensor = filterScaleTensor.AsPtr(); + desc.FilterZeroPointTensor = filterZeroPoint ? filterZeroPointTensor.AsPtr() : nullptr; + desc.BiasTensor = bias ? biasTensor.AsPtr() : nullptr; + desc.OutputScaleTensor = outputScaleTensor.AsPtr(); + desc.OutputZeroPointTensor = outputZeroPoint ? outputZeroPointTensor.AsPtr() : nullptr; + desc.OutputTensor = outputTensor.AsPtr(); + + desc.DimensionCount = spatialDimensionCount; + desc.Strides = strides.data(); + desc.Dilations = dilations.data(); + desc.StartPadding = startPadding.data(); + desc.EndPadding = endPadding.data(); + desc.GroupCount = groupCount; + + detail::NodeOutput* const inputs[] = { + input.Impl(), + inputScale.Impl(), + inputZeroPoint ? inputZeroPoint->Impl() : nullptr, + filter.Impl(), + filterScale.Impl(), + filterZeroPoint ? filterZeroPoint->Impl() : nullptr, + bias ? bias->Impl() : nullptr, + outputScale.Impl(), + outputZeroPoint ? outputZeroPoint->Impl() : nullptr, + }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + + // + // TODO: ReluGrad + // + + // + // TODO: AveragePoolingGrad + // + + // + // TODO: MaxPoolingGrad + // + + struct RandomGeneratorOutputs + { + Expression values; + Expression state; // Only valid if outputState = true is supplied to RandomGenerator + }; + + inline RandomGeneratorOutputs RandomGenerator( + Expression inputState, + TensorDimensions outputSizes, + bool outputState = true, + DML_RANDOM_GENERATOR_TYPE type = DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10) + { + detail::GraphBuilder* builder = inputState.Impl()->GetGraphBuilder(); + + TensorDesc inputStateTensor = inputState.Impl()->GetOutputDesc(); + TensorDesc outputTensor(DML_TENSOR_DATA_TYPE_UINT32, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_RANDOM_GENERATOR_OPERATOR_DESC desc = {}; + desc.Type = type; + desc.InputStateTensor = inputStateTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + if (outputState) + { + // Input and output state have the same TensorDesc. + desc.OutputStateTensor = inputStateTensor.AsPtr(); + } + + RandomGeneratorOutputs out; + + detail::NodeOutput* const inputs[] = { inputState.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RANDOM_GENERATOR, &desc, inputs); + out.values = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + if (outputState) + { + TensorDesc outputStateTensor = inputStateTensor; + out.state = builder->CreateNodeOutput(node, 1, std::move(outputStateTensor)); + } + + return out; + } + + struct NonZeroCoordinatesOutputs + { + Expression count; + Expression coordinates; + }; + inline NonZeroCoordinatesOutputs NonZeroCoordinates(Expression input) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + const auto& inputTensorSizes = inputTensor.sizes; + uint32_t dimensionCount = static_cast(inputTensorSizes.size()); + + TensorDimensions outputCountSizes = {1}; + uint32_t totalElements = 1; + for (uint32_t i = 0; i < dimensionCount; ++i) + { + totalElements *= inputTensorSizes[i]; + } + TensorDesc outputCountTensor(DML_TENSOR_DATA_TYPE_UINT32, outputCountSizes, builder->GetTensorPolicy()); + TensorDesc outputCoordinatesTensor(DML_TENSOR_DATA_TYPE_UINT32, {totalElements, dimensionCount}, builder->GetTensorPolicy()); + + DML_NONZERO_COORDINATES_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.OutputCountTensor = outputCountTensor.AsPtr(); + desc.OutputCoordinatesTensor = outputCoordinatesTensor.AsPtr(); + + NonZeroCoordinatesOutputs output; + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_NONZERO_COORDINATES, &desc, inputs); + output.count = builder->CreateNodeOutput(node, 0, std::move(outputCountTensor)); + output.coordinates = builder->CreateNodeOutput(node, 1, std::move(outputCoordinatesTensor)); + return output; + } + + // If not specified, parameters are defaulted to the following values: + // Scales = computed by dividing the input sizes by the output sizes + // InputPixelOffsets = 0.5f for each dimension + // OutputPixelOffsets = -0.5f for each dimension + inline Expression ResampleGrad( + Expression input, + TensorDimensions outputSizes, + DML_INTERPOLATION_MODE mode, + Span scales = {}, + Span inputPixelOffsets = {}, + Span outputPixelOffsets = {}) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + uint32_t dimensionCount = static_cast(inputTensor.sizes.size()); + assert(outputSizes.size() == dimensionCount); + + SmallVector defaultScales; + if (scales.empty()) + { + for (uint32_t i = 0; i < dimensionCount; ++i) + { + defaultScales.push_back(static_cast(inputTensor.sizes[i]) / static_cast(outputSizes[i])); + } + scales = defaultScales; + } + + SmallVector defaultInputPixelOffsets; + if (inputPixelOffsets.empty()) + { + defaultInputPixelOffsets.assign(dimensionCount, 0.5f); + inputPixelOffsets = defaultInputPixelOffsets; + } + + SmallVector defaultOutputPixelOffsets; + if (outputPixelOffsets.empty()) + { + defaultOutputPixelOffsets.assign(dimensionCount, -0.5f); + outputPixelOffsets = defaultOutputPixelOffsets; + } + + TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy()); + + DML_RESAMPLE_GRAD_OPERATOR_DESC desc = {}; + desc.InputGradientTensor = inputTensor.AsPtr(); + desc.OutputGradientTensor = outputTensor.AsPtr(); + desc.InterpolationMode = mode; + desc.DimensionCount = static_cast(scales.size()); + desc.Scales = scales.data(); + desc.InputPixelOffsets = inputPixelOffsets.data(); + desc.OutputPixelOffsets = outputPixelOffsets.data(); + + detail::NodeOutput* const inputs[] = { input.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_RESAMPLE_GRAD, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + + inline Expression SliceGrad( + Expression inputGradient, + TensorDimensions outputGradientSizes, + Span inputWindowOffsets, + Span inputWindowSizes, + Span inputWindowStrides) + { + detail::GraphBuilder* builder = inputGradient.Impl()->GetGraphBuilder(); + + TensorDesc inputGradientTensor = inputGradient.Impl()->GetOutputDesc(); + + assert(inputWindowOffsets.size() == inputGradientTensor.sizes.size()); + assert(inputWindowOffsets.size() == outputGradientSizes.size()); + assert(inputWindowOffsets.size() == inputWindowStrides.size()); + assert(inputWindowOffsets.size() == inputWindowSizes.size()); + + TensorDesc outputGradientTensor(inputGradientTensor.dataType, std::move(outputGradientSizes), builder->GetTensorPolicy()); + + DML_SLICE_GRAD_OPERATOR_DESC sliceGradDesc = {}; + sliceGradDesc.InputGradientTensor = inputGradientTensor.AsPtr(); + sliceGradDesc.OutputGradientTensor = outputGradientTensor.AsPtr(); + sliceGradDesc.DimensionCount = static_cast(inputWindowOffsets.size()); + sliceGradDesc.InputWindowOffsets = inputWindowOffsets.data(); + sliceGradDesc.InputWindowSizes = inputWindowSizes.data(); + sliceGradDesc.InputWindowStrides = inputWindowStrides.data(); + + detail::NodeOutput* const inputs[] = { inputGradient.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_SLICE_GRAD, &sliceGradDesc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputGradientTensor)); + + return output; + } + + // + // TODO: AdamOptimizer + // + + // + // TODO: Argmin + // + + // + // TODO: Argmax + // + +#if DML_TARGET_VERSION >= 0x4000 + + inline Expression RoiAlign( + Expression input, + Expression roi, + Expression batchIndices, + DML_REDUCE_FUNCTION reductionFunction, + DML_INTERPOLATION_MODE interpolationMode, + float spatialScaleX, + float spatialScaleY, + float inputPixelOffset, + float outputPixelOffset, + float outOfBoundsInputValue, + uint32_t minimumSamplesPerOutput, + uint32_t maximumSamplesPerOutput, + bool alignRegionsToCorners, + uint32_t outputHeight, + uint32_t outputWidth) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc roiTensor = roi.Impl()->GetOutputDesc(); + TensorDesc batchIndicesTensor = batchIndices.Impl()->GetOutputDesc(); + + uint32_t channelCount = inputTensor.sizes[1]; + uint32_t roiCount = roiTensor.sizes.size() < 2 ? 1u : roiTensor.sizes[roiTensor.sizes.size() - 2]; + + TensorDesc::Dimensions outputSizes({ + roiCount, + channelCount, + outputHeight, + outputWidth, + }); + + TensorDesc outputTensor(inputTensor.dataType, outputSizes, builder->GetTensorPolicy()); + + DML_ROI_ALIGN1_OPERATOR_DESC desc = {}; + desc.InputTensor = inputTensor.AsPtr(); + desc.ROITensor = roiTensor.AsPtr(); + desc.BatchIndicesTensor = batchIndicesTensor.AsPtr(); + desc.OutputTensor = outputTensor.AsPtr(); + desc.ReductionFunction = reductionFunction; + desc.InterpolationMode = interpolationMode; + desc.SpatialScaleX = spatialScaleX; + desc.SpatialScaleY = spatialScaleY; + desc.InputPixelOffset = inputPixelOffset; + desc.OutputPixelOffset = outputPixelOffset; + desc.OutOfBoundsInputValue = outOfBoundsInputValue; + desc.MinimumSamplesPerOutput = minimumSamplesPerOutput; + desc.MaximumSamplesPerOutput = maximumSamplesPerOutput; + desc.AlignRegionsToCorners = alignRegionsToCorners; + + detail::NodeOutput* const inputs[] = { input.Impl(), roi.Impl(), batchIndices.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ROI_ALIGN1, &desc, inputs); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); + + return output; + } + +#endif // DML_TARGET_VERSION >= 0x4000 + +#if DML_TARGET_VERSION >= 0x4100 + struct RoiAlignGradOutputs + { + Expression outputGradient; + Expression outputROIGradient; + }; + + inline RoiAlignGradOutputs RoiAlignGrad( + Optional input, + Expression inputGradient, + Expression roi, + Expression batchIndices, + DML_REDUCE_FUNCTION reductionFunction, + DML_INTERPOLATION_MODE interpolationMode, + float spatialScaleX, + float spatialScaleY, + float inputPixelOffset, + float outputPixelOffset, + uint32_t minimumSamplesPerOutput, + uint32_t maximumSamplesPerOutput, + bool alignRegionsToCorners, + uint32_t batchSize, + uint32_t imageHeight, + uint32_t imageWidth, + bool computeOutputGradient, + bool computeOutputROIGradient) + { + detail::GraphBuilder* builder = inputGradient.Impl()->GetGraphBuilder(); + + TensorDesc inputTensor = input.has_value() ? input->Impl()->GetOutputDesc() : TensorDesc(); + TensorDesc inputGradientTensor = inputGradient.Impl()->GetOutputDesc(); + TensorDesc roiTensor = roi.Impl()->GetOutputDesc(); + TensorDesc batchIndicesTensor = batchIndices.Impl()->GetOutputDesc(); + + assert(computeOutputGradient || computeOutputROIGradient); + assert(inputGradientTensor.sizes.size() > 1); + + TensorDesc outputGradientTensor; + if (computeOutputGradient) + { + TensorDesc::Dimensions outputGradientSizes({ + batchSize, + inputGradientTensor.sizes[1], + imageHeight, + imageWidth, + }); + + outputGradientTensor = TensorDesc(inputGradientTensor.dataType, outputGradientSizes, builder->GetTensorPolicy()); + } + + TensorDesc outputROIGradientTensor = computeOutputROIGradient ? TensorDesc(roiTensor.dataType, roiTensor.sizes, builder->GetTensorPolicy()) : TensorDesc(); + assert(!computeOutputROIGradient || outputROIGradientTensor.sizes == roiTensor.sizes); + + DML_ROI_ALIGN_GRAD_OPERATOR_DESC desc = {}; + desc.InputTensor = input ? inputTensor.AsPtr() : nullptr; + desc.InputGradientTensor = inputGradientTensor.AsPtr(); + desc.ROITensor = roiTensor.AsPtr(); + desc.BatchIndicesTensor = batchIndicesTensor.AsPtr(); + desc.OutputGradientTensor = computeOutputGradient ? outputGradientTensor.AsPtr() : nullptr; + desc.OutputROIGradientTensor = computeOutputROIGradient ? outputROIGradientTensor.AsPtr() : nullptr; + desc.ReductionFunction = reductionFunction; + desc.InterpolationMode = interpolationMode; + desc.SpatialScaleX = spatialScaleX; + desc.SpatialScaleY = spatialScaleY; + desc.InputPixelOffset = inputPixelOffset; + desc.OutputPixelOffset = outputPixelOffset; + desc.MinimumSamplesPerOutput = minimumSamplesPerOutput; + desc.MaximumSamplesPerOutput = maximumSamplesPerOutput; + desc.AlignRegionsToCorners = alignRegionsToCorners; + + detail::NodeOutput* const inputs[] = { input ? input->Impl() : nullptr, inputGradient.Impl(), roi.Impl(), batchIndices.Impl() }; + detail::NodeID node = builder->CreateOperatorNode(static_cast(DML_OPERATOR_ROI_ALIGN_GRAD), &desc, inputs); + + RoiAlignGradOutputs outputs {}; + + if (computeOutputGradient) + { + outputs.outputGradient = builder->CreateNodeOutput(node, 0, std::move(outputGradientTensor)); + } + + if (computeOutputROIGradient) + { + outputs.outputROIGradient = builder->CreateNodeOutput(node, 1, std::move(outputROIGradientTensor)); + } + + return outputs; + } +#endif + +#if DML_TARGET_VERSION >= 0x5000 + + inline Expression Negate(Expression input) + { + return detail::ElementWiseUnary(input); + } + +#endif // DML_TARGET_VERSION >= 0x5000 + + // Reinterprets the memory of a tensor with a different type and dimensions (analogously to using + // reinterpret_cast to access raw bits). Note that this is different to the DML Cast operator, which performs + // a type cast on the contents of a tensor (analogously to static_cast). The total tensor size of the output + // (which depends on the supplied type/sizes/strides) must match the input. + inline Expression Reinterpret( + Expression input, + DML_TENSOR_DATA_TYPE newType, + TensorDimensions newSizes, + Optional newStrides) + { + detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + TensorDesc newTensor( + newType, + inputTensor.flags, + std::move(newSizes), + std::move(newStrides), + inputTensor.totalTensorSizeInBytes, + inputTensor.guaranteedBaseOffsetAlignment); + + detail::NodeID node = builder->CreateReinterpretNode(input.Impl()); + detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(newTensor)); + + return output; + } + + // Same as Reinterpret above, but only adjusts tensor dimensions without affecting type. + inline Expression Reinterpret( + Expression input, + TensorDimensions newSizes, + Optional newStrides) + { + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + + return Reinterpret(input, inputTensor.dataType, std::move(newSizes), std::move(newStrides)); + } + + // Same as Reinterpret above, but only adjusts tensor type without affecting sizes or strides. + inline Expression Reinterpret(Expression input, DML_TENSOR_DATA_TYPE newType) + { + TensorDesc inputTensor = input.Impl()->GetOutputDesc(); + + return Reinterpret(input, newType, inputTensor.sizes, inputTensor.strides); + } + + // Operator overloads for convenience, which merely map to one of the functions above + inline Expression operator+(Expression a, Expression b) { return dml::Add(a, b); } + inline Expression operator-(Expression a, Expression b) { return dml::Subtract(a, b); } + inline Expression operator*(Expression a, Expression b) { return dml::Multiply(a, b); } + inline Expression operator/(Expression a, Expression b) { return dml::Divide(a, b); } + inline Expression operator%(Expression a, Expression b) { return dml::ModulusTruncate(a, b); } + inline Expression operator&(Expression a, Expression b) { return dml::BitAnd(a, b); } + inline Expression operator|(Expression a, Expression b) { return dml::BitOr(a, b); } + inline Expression operator^(Expression a, Expression b) { return dml::BitXor(a, b); } + inline Expression operator<<(Expression a, Expression b) { return dml::BitShiftLeft(a, b); } + inline Expression operator>>(Expression a, Expression b) { return dml::BitShiftRight(a, b); } + inline Expression& operator+=(Expression& a, Expression b) { a = a + b; return a; } + inline Expression& operator-=(Expression& a, Expression b) { a = a - b; return a; } + inline Expression& operator*=(Expression& a, Expression b) { a = a * b; return a; } + inline Expression& operator/=(Expression& a, Expression b) { a = a / b; return a; } + inline Expression& operator%=(Expression& a, Expression b) { a = a % b; return a; } + inline Expression& operator&=(Expression& a, Expression b) { a = a & b; return a; } + inline Expression& operator|=(Expression& a, Expression b) { a = a | b; return a; } + inline Expression& operator^=(Expression& a, Expression b) { a = a ^ b; return a; } + inline Expression& operator<<=(Expression& a, Expression b) { a = a << b; return a; } + inline Expression& operator>>=(Expression& a, Expression b) { a = a >> b; return a; } + + // Operations involving scalars can be reduced to elementwise identity + inline Expression operator+(Expression a, float b) { return dml::Identity(a, DML_SCALE_BIAS{ 1.0f, b }); } + inline Expression operator-(Expression a, float b) { return dml::Identity(a, DML_SCALE_BIAS{ 1.0f, -b }); } + inline Expression operator*(Expression a, float b) { return dml::Identity(a, DML_SCALE_BIAS{ b, 0.0f }); } + inline Expression operator/(Expression a, float b) { return dml::Identity(a, DML_SCALE_BIAS{ 1.0f / b, 0.0f }); } + inline Expression operator+(float a, Expression b) { return dml::Identity(b, DML_SCALE_BIAS{ 1.0f, a }); } + inline Expression operator-(float a, Expression b) { return dml::Identity(b, DML_SCALE_BIAS{ -1.0f, a }); } + inline Expression operator*(float a, Expression b) { return dml::Identity(b, DML_SCALE_BIAS{ a, 0.0f }); } + inline Expression operator/(float a, Expression b) { return dml::Recip(b, DML_SCALE_BIAS{ a, 0.0f }); } + inline Expression& operator+=(Expression& a, float b) { a = a + b; return a; } + inline Expression& operator-=(Expression& a, float b) { a = a - b; return a; } + inline Expression& operator*=(Expression& a, float b) { a = a * b; return a; } + inline Expression& operator/=(Expression& a, float b) { a = a / b; return a; } + + // Unary + inline Expression operator~(Expression input) { return dml::BitNot(input); } + inline Expression operator+(Expression input) { return dml::Identity(input); } + +#if DML_TARGET_VERSION >= 0x5000 + + inline Expression operator-(Expression input) { return dml::Negate(input); } + +#else + + inline Expression operator-(Expression input) { return dml::Identity(input, DML_SCALE_BIAS{ -1.0f, 0.0f }); } + +#endif // DML_TARGET_VERSION >= 0x5000 + + // Logical + inline Expression operator!(Expression a) { return dml::LogicalNot(a); } + inline Expression operator&&(Expression a, Expression b) { return dml::LogicalAnd(a, b); } + inline Expression operator||(Expression a, Expression b) { return dml::LogicalOr(a, b); } + inline Expression operator>(Expression a, Expression b) { return dml::GreaterThan(a, b); } + inline Expression operator<(Expression a, Expression b) { return dml::LessThan(a, b); } + inline Expression operator==(Expression a, Expression b) { return dml::Equals(a, b); } + inline Expression operator!=(Expression a, Expression b) { return !(a == b); } + inline Expression operator>=(Expression a, Expression b) { return dml::GreaterThanOrEqual(a, b); } + inline Expression operator<=(Expression a, Expression b) { return dml::LessThanOrEqual(a, b); } + + // GraphBuilder implementation details + namespace detail + { + inline NodeID GraphBuilder::CreateOperatorNode( + DML_OPERATOR_TYPE type, + const void* desc, + Span inputs) + { + DML_OPERATOR_DESC opDesc = { type, desc }; + + Microsoft::WRL::ComPtr op; + DMLX_THROW_IF_FAILED(m_device->CreateOperator(&opDesc, IID_PPV_ARGS(&op))); + + OperatorNode node = {}; + node.op = std::move(op); + node.inputs.assign(inputs.begin(), inputs.end()); + if (!m_name.empty()) + { + node.name = m_name; + } + + uint32_t index = static_cast(m_operatorNodes.size()); + m_operatorNodes.push_back(std::move(node)); + + return { NodeType::Operator, index }; + } + + inline NodeID GraphBuilder::CreateInputNode(uint32_t inputIndex) + { + uint32_t index = static_cast(m_inputNodes.size()); + m_inputNodes.push_back(InputNode{ inputIndex }); + return { NodeType::Input, index }; + } + + inline NodeID GraphBuilder::CreateReinterpretNode(NodeOutput* input) + { + uint32_t index = static_cast(m_reinterpretNodes.size()); + m_reinterpretNodes.push_back(ReinterpretNode{ input }); + return { NodeType::Reinterpret, index }; + } + + inline NodeOutput* GraphBuilder::CreateNodeOutput(NodeID node, uint32_t outputIndex, TensorDesc tensorDesc) + { + // Construct the object in the deque, which doesn't invalidate references to elements as it grows + m_nodeOutputs.emplace_back(this, node, outputIndex, std::move(tensorDesc)); + + return &m_nodeOutputs.back(); + } + + inline GraphDesc GraphBuilder::GetGraphDesc(Span outputs) const + { + GraphDesc desc = {}; + desc.inputCount = static_cast(m_inputNodes.size()); + desc.outputCount = static_cast(outputs.size()); + + for (const OperatorNode& node : m_operatorNodes) + { + uint32_t nodeIndex = static_cast(desc.nodes.size()); + + desc.nodes.push_back(DML_OPERATOR_GRAPH_NODE_DESC{ node.op.Get(), (!node.name.empty() ? node.name.c_str() : nullptr) }); + + // Walk through each of this node's inputs and add it as an edge + const uint32_t inputCount = static_cast(node.inputs.size()); + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + NodeOutput* input = node.inputs[inputIndex]; + if (input == nullptr) + { + continue; + } + NodeID inputNode = input->GetNode(); + + // Reinterpret nodes aren't "real" nodes, they're just used to modify TensorDescs across + // edges. So we follow this node backwards until it hits a real node. + while (inputNode.type == NodeType::Reinterpret) + { + input = m_reinterpretNodes[inputNode.index].input; + inputNode = input->GetNode(); + } + + if (inputNode.type == NodeType::Input) + { + DML_INPUT_GRAPH_EDGE_DESC inputEdge = {}; + inputEdge.GraphInputIndex = m_inputNodes[inputNode.index].inputIndex; + inputEdge.ToNodeIndex = nodeIndex; + inputEdge.ToNodeInputIndex = inputIndex; + + desc.inputEdges.push_back(inputEdge); + } + else if (inputNode.type == NodeType::Operator) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {}; + intermediateEdge.FromNodeIndex = inputNode.index; + intermediateEdge.FromNodeOutputIndex = input->GetOutputIndex(); + intermediateEdge.ToNodeIndex = nodeIndex; + intermediateEdge.ToNodeInputIndex = inputIndex; + + desc.intermediateEdges.push_back(intermediateEdge); + } + else + { + assert(false); // Invalid node type + DMLX_THROW(E_UNEXPECTED); + } + } + } + + // Add output edges + for (uint32_t outputIndex = 0; outputIndex < desc.outputCount; ++outputIndex) + { + NodeOutput* output = outputs[outputIndex].Impl(); + if (output == nullptr) + { + continue; + } + NodeID outputNode = output->GetNode(); + + // Reinterpret nodes are meaningless on outputs (they're no-ops), so just follow them back until we + // get to a real operator node. + while (outputNode.type == NodeType::Reinterpret) + { + output = m_reinterpretNodes[outputNode.index].input; + outputNode = output->GetNode(); + } + + if (outputNode.type == NodeType::Input) + { + // It's not valid to connect an output of the graph directly to an input without an intervening + // node. If this behavior is desired, it should instead be accomplished with a copy e.g. using + // the elementwise identity operator. + DMLX_THROW(E_INVALIDARG); + } + + assert(outputNode.type == NodeType::Operator); + + DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {}; + outputEdge.FromNodeIndex = output->GetNode().index; + outputEdge.FromNodeOutputIndex = output->GetOutputIndex(); + outputEdge.GraphOutputIndex = outputIndex; + + desc.outputEdges.push_back(outputEdge); + } + + // Sanity + assert(desc.nodes.size() == m_operatorNodes.size()); + assert(desc.outputEdges.size() == desc.outputCount); + assert(desc.outputCount == outputs.size()); + + return desc; + } + } // namespace detail + +} // namespace dml diff --git a/Samples/DirectMLConv/DirectMLConv/d3dx12.h b/Samples/DirectMLConv/DirectMLConv/d3dx12.h new file mode 100644 index 00000000..ff8465a6 --- /dev/null +++ b/Samples/DirectMLConv/DirectMLConv/d3dx12.h @@ -0,0 +1,3439 @@ +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License (MIT). +// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF +// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY +// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR +// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT. +// +//********************************************************* + +#ifndef __D3DX12_H__ +#define __D3DX12_H__ + +#include "d3d12.h" + +#if defined( __cplusplus ) + +struct CD3DX12_DEFAULT {}; +extern const DECLSPEC_SELECTANY CD3DX12_DEFAULT D3D12_DEFAULT; + +//------------------------------------------------------------------------------------------------ +inline bool operator==( const D3D12_VIEWPORT& l, const D3D12_VIEWPORT& r ) +{ + return l.TopLeftX == r.TopLeftX && l.TopLeftY == r.TopLeftY && l.Width == r.Width && + l.Height == r.Height && l.MinDepth == r.MinDepth && l.MaxDepth == r.MaxDepth; +} + +//------------------------------------------------------------------------------------------------ +inline bool operator!=( const D3D12_VIEWPORT& l, const D3D12_VIEWPORT& r ) +{ return !( l == r ); } + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_RECT : public D3D12_RECT +{ + CD3DX12_RECT() = default; + explicit CD3DX12_RECT( const D3D12_RECT& o ) : + D3D12_RECT( o ) + {} + explicit CD3DX12_RECT( + LONG Left, + LONG Top, + LONG Right, + LONG Bottom ) + { + left = Left; + top = Top; + right = Right; + bottom = Bottom; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_VIEWPORT : public D3D12_VIEWPORT +{ + CD3DX12_VIEWPORT() = default; + explicit CD3DX12_VIEWPORT( const D3D12_VIEWPORT& o ) : + D3D12_VIEWPORT( o ) + {} + explicit CD3DX12_VIEWPORT( + FLOAT topLeftX, + FLOAT topLeftY, + FLOAT width, + FLOAT height, + FLOAT minDepth = D3D12_MIN_DEPTH, + FLOAT maxDepth = D3D12_MAX_DEPTH ) + { + TopLeftX = topLeftX; + TopLeftY = topLeftY; + Width = width; + Height = height; + MinDepth = minDepth; + MaxDepth = maxDepth; + } + explicit CD3DX12_VIEWPORT( + _In_ ID3D12Resource* pResource, + UINT mipSlice = 0, + FLOAT topLeftX = 0.0f, + FLOAT topLeftY = 0.0f, + FLOAT minDepth = D3D12_MIN_DEPTH, + FLOAT maxDepth = D3D12_MAX_DEPTH ) + { + auto Desc = pResource->GetDesc(); + const UINT64 SubresourceWidth = Desc.Width >> mipSlice; + const UINT64 SubresourceHeight = Desc.Height >> mipSlice; + switch (Desc.Dimension) + { + case D3D12_RESOURCE_DIMENSION_BUFFER: + TopLeftX = topLeftX; + TopLeftY = 0.0f; + Width = Desc.Width - topLeftX; + Height = 1.0f; + break; + case D3D12_RESOURCE_DIMENSION_TEXTURE1D: + TopLeftX = topLeftX; + TopLeftY = 0.0f; + Width = (SubresourceWidth ? SubresourceWidth : 1.0f) - topLeftX; + Height = 1.0f; + break; + case D3D12_RESOURCE_DIMENSION_TEXTURE2D: + case D3D12_RESOURCE_DIMENSION_TEXTURE3D: + TopLeftX = topLeftX; + TopLeftY = topLeftY; + Width = (SubresourceWidth ? SubresourceWidth : 1.0f) - topLeftX; + Height = (SubresourceHeight ? SubresourceHeight: 1.0f) - topLeftY; + break; + default: break; + } + + MinDepth = minDepth; + MaxDepth = maxDepth; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_BOX : public D3D12_BOX +{ + CD3DX12_BOX() = default; + explicit CD3DX12_BOX( const D3D12_BOX& o ) : + D3D12_BOX( o ) + {} + explicit CD3DX12_BOX( + LONG Left, + LONG Right ) + { + left = Left; + top = 0; + front = 0; + right = Right; + bottom = 1; + back = 1; + } + explicit CD3DX12_BOX( + LONG Left, + LONG Top, + LONG Right, + LONG Bottom ) + { + left = Left; + top = Top; + front = 0; + right = Right; + bottom = Bottom; + back = 1; + } + explicit CD3DX12_BOX( + LONG Left, + LONG Top, + LONG Front, + LONG Right, + LONG Bottom, + LONG Back ) + { + left = Left; + top = Top; + front = Front; + right = Right; + bottom = Bottom; + back = Back; + } +}; +inline bool operator==( const D3D12_BOX& l, const D3D12_BOX& r ) +{ + return l.left == r.left && l.top == r.top && l.front == r.front && + l.right == r.right && l.bottom == r.bottom && l.back == r.back; +} +inline bool operator!=( const D3D12_BOX& l, const D3D12_BOX& r ) +{ return !( l == r ); } + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_DEPTH_STENCIL_DESC : public D3D12_DEPTH_STENCIL_DESC +{ + CD3DX12_DEPTH_STENCIL_DESC() = default; + explicit CD3DX12_DEPTH_STENCIL_DESC( const D3D12_DEPTH_STENCIL_DESC& o ) : + D3D12_DEPTH_STENCIL_DESC( o ) + {} + explicit CD3DX12_DEPTH_STENCIL_DESC( CD3DX12_DEFAULT ) + { + DepthEnable = TRUE; + DepthWriteMask = D3D12_DEPTH_WRITE_MASK_ALL; + DepthFunc = D3D12_COMPARISON_FUNC_LESS; + StencilEnable = FALSE; + StencilReadMask = D3D12_DEFAULT_STENCIL_READ_MASK; + StencilWriteMask = D3D12_DEFAULT_STENCIL_WRITE_MASK; + const D3D12_DEPTH_STENCILOP_DESC defaultStencilOp = + { D3D12_STENCIL_OP_KEEP, D3D12_STENCIL_OP_KEEP, D3D12_STENCIL_OP_KEEP, D3D12_COMPARISON_FUNC_ALWAYS }; + FrontFace = defaultStencilOp; + BackFace = defaultStencilOp; + } + explicit CD3DX12_DEPTH_STENCIL_DESC( + BOOL depthEnable, + D3D12_DEPTH_WRITE_MASK depthWriteMask, + D3D12_COMPARISON_FUNC depthFunc, + BOOL stencilEnable, + UINT8 stencilReadMask, + UINT8 stencilWriteMask, + D3D12_STENCIL_OP frontStencilFailOp, + D3D12_STENCIL_OP frontStencilDepthFailOp, + D3D12_STENCIL_OP frontStencilPassOp, + D3D12_COMPARISON_FUNC frontStencilFunc, + D3D12_STENCIL_OP backStencilFailOp, + D3D12_STENCIL_OP backStencilDepthFailOp, + D3D12_STENCIL_OP backStencilPassOp, + D3D12_COMPARISON_FUNC backStencilFunc ) + { + DepthEnable = depthEnable; + DepthWriteMask = depthWriteMask; + DepthFunc = depthFunc; + StencilEnable = stencilEnable; + StencilReadMask = stencilReadMask; + StencilWriteMask = stencilWriteMask; + FrontFace.StencilFailOp = frontStencilFailOp; + FrontFace.StencilDepthFailOp = frontStencilDepthFailOp; + FrontFace.StencilPassOp = frontStencilPassOp; + FrontFace.StencilFunc = frontStencilFunc; + BackFace.StencilFailOp = backStencilFailOp; + BackFace.StencilDepthFailOp = backStencilDepthFailOp; + BackFace.StencilPassOp = backStencilPassOp; + BackFace.StencilFunc = backStencilFunc; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_DEPTH_STENCIL_DESC1 : public D3D12_DEPTH_STENCIL_DESC1 +{ + CD3DX12_DEPTH_STENCIL_DESC1() = default; + explicit CD3DX12_DEPTH_STENCIL_DESC1( const D3D12_DEPTH_STENCIL_DESC1& o ) : + D3D12_DEPTH_STENCIL_DESC1( o ) + {} + explicit CD3DX12_DEPTH_STENCIL_DESC1( const D3D12_DEPTH_STENCIL_DESC& o ) + { + DepthEnable = o.DepthEnable; + DepthWriteMask = o.DepthWriteMask; + DepthFunc = o.DepthFunc; + StencilEnable = o.StencilEnable; + StencilReadMask = o.StencilReadMask; + StencilWriteMask = o.StencilWriteMask; + FrontFace.StencilFailOp = o.FrontFace.StencilFailOp; + FrontFace.StencilDepthFailOp = o.FrontFace.StencilDepthFailOp; + FrontFace.StencilPassOp = o.FrontFace.StencilPassOp; + FrontFace.StencilFunc = o.FrontFace.StencilFunc; + BackFace.StencilFailOp = o.BackFace.StencilFailOp; + BackFace.StencilDepthFailOp = o.BackFace.StencilDepthFailOp; + BackFace.StencilPassOp = o.BackFace.StencilPassOp; + BackFace.StencilFunc = o.BackFace.StencilFunc; + DepthBoundsTestEnable = FALSE; + } + explicit CD3DX12_DEPTH_STENCIL_DESC1( CD3DX12_DEFAULT ) + { + DepthEnable = TRUE; + DepthWriteMask = D3D12_DEPTH_WRITE_MASK_ALL; + DepthFunc = D3D12_COMPARISON_FUNC_LESS; + StencilEnable = FALSE; + StencilReadMask = D3D12_DEFAULT_STENCIL_READ_MASK; + StencilWriteMask = D3D12_DEFAULT_STENCIL_WRITE_MASK; + const D3D12_DEPTH_STENCILOP_DESC defaultStencilOp = + { D3D12_STENCIL_OP_KEEP, D3D12_STENCIL_OP_KEEP, D3D12_STENCIL_OP_KEEP, D3D12_COMPARISON_FUNC_ALWAYS }; + FrontFace = defaultStencilOp; + BackFace = defaultStencilOp; + DepthBoundsTestEnable = FALSE; + } + explicit CD3DX12_DEPTH_STENCIL_DESC1( + BOOL depthEnable, + D3D12_DEPTH_WRITE_MASK depthWriteMask, + D3D12_COMPARISON_FUNC depthFunc, + BOOL stencilEnable, + UINT8 stencilReadMask, + UINT8 stencilWriteMask, + D3D12_STENCIL_OP frontStencilFailOp, + D3D12_STENCIL_OP frontStencilDepthFailOp, + D3D12_STENCIL_OP frontStencilPassOp, + D3D12_COMPARISON_FUNC frontStencilFunc, + D3D12_STENCIL_OP backStencilFailOp, + D3D12_STENCIL_OP backStencilDepthFailOp, + D3D12_STENCIL_OP backStencilPassOp, + D3D12_COMPARISON_FUNC backStencilFunc, + BOOL depthBoundsTestEnable ) + { + DepthEnable = depthEnable; + DepthWriteMask = depthWriteMask; + DepthFunc = depthFunc; + StencilEnable = stencilEnable; + StencilReadMask = stencilReadMask; + StencilWriteMask = stencilWriteMask; + FrontFace.StencilFailOp = frontStencilFailOp; + FrontFace.StencilDepthFailOp = frontStencilDepthFailOp; + FrontFace.StencilPassOp = frontStencilPassOp; + FrontFace.StencilFunc = frontStencilFunc; + BackFace.StencilFailOp = backStencilFailOp; + BackFace.StencilDepthFailOp = backStencilDepthFailOp; + BackFace.StencilPassOp = backStencilPassOp; + BackFace.StencilFunc = backStencilFunc; + DepthBoundsTestEnable = depthBoundsTestEnable; + } + operator D3D12_DEPTH_STENCIL_DESC() const + { + D3D12_DEPTH_STENCIL_DESC D; + D.DepthEnable = DepthEnable; + D.DepthWriteMask = DepthWriteMask; + D.DepthFunc = DepthFunc; + D.StencilEnable = StencilEnable; + D.StencilReadMask = StencilReadMask; + D.StencilWriteMask = StencilWriteMask; + D.FrontFace.StencilFailOp = FrontFace.StencilFailOp; + D.FrontFace.StencilDepthFailOp = FrontFace.StencilDepthFailOp; + D.FrontFace.StencilPassOp = FrontFace.StencilPassOp; + D.FrontFace.StencilFunc = FrontFace.StencilFunc; + D.BackFace.StencilFailOp = BackFace.StencilFailOp; + D.BackFace.StencilDepthFailOp = BackFace.StencilDepthFailOp; + D.BackFace.StencilPassOp = BackFace.StencilPassOp; + D.BackFace.StencilFunc = BackFace.StencilFunc; + return D; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_BLEND_DESC : public D3D12_BLEND_DESC +{ + CD3DX12_BLEND_DESC() = default; + explicit CD3DX12_BLEND_DESC( const D3D12_BLEND_DESC& o ) : + D3D12_BLEND_DESC( o ) + {} + explicit CD3DX12_BLEND_DESC( CD3DX12_DEFAULT ) + { + AlphaToCoverageEnable = FALSE; + IndependentBlendEnable = FALSE; + const D3D12_RENDER_TARGET_BLEND_DESC defaultRenderTargetBlendDesc = + { + FALSE,FALSE, + D3D12_BLEND_ONE, D3D12_BLEND_ZERO, D3D12_BLEND_OP_ADD, + D3D12_BLEND_ONE, D3D12_BLEND_ZERO, D3D12_BLEND_OP_ADD, + D3D12_LOGIC_OP_NOOP, + D3D12_COLOR_WRITE_ENABLE_ALL, + }; + for (UINT i = 0; i < D3D12_SIMULTANEOUS_RENDER_TARGET_COUNT; ++i) + RenderTarget[ i ] = defaultRenderTargetBlendDesc; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_RASTERIZER_DESC : public D3D12_RASTERIZER_DESC +{ + CD3DX12_RASTERIZER_DESC() = default; + explicit CD3DX12_RASTERIZER_DESC( const D3D12_RASTERIZER_DESC& o ) : + D3D12_RASTERIZER_DESC( o ) + {} + explicit CD3DX12_RASTERIZER_DESC( CD3DX12_DEFAULT ) + { + FillMode = D3D12_FILL_MODE_SOLID; + CullMode = D3D12_CULL_MODE_BACK; + FrontCounterClockwise = FALSE; + DepthBias = D3D12_DEFAULT_DEPTH_BIAS; + DepthBiasClamp = D3D12_DEFAULT_DEPTH_BIAS_CLAMP; + SlopeScaledDepthBias = D3D12_DEFAULT_SLOPE_SCALED_DEPTH_BIAS; + DepthClipEnable = TRUE; + MultisampleEnable = FALSE; + AntialiasedLineEnable = FALSE; + ForcedSampleCount = 0; + ConservativeRaster = D3D12_CONSERVATIVE_RASTERIZATION_MODE_OFF; + } + explicit CD3DX12_RASTERIZER_DESC( + D3D12_FILL_MODE fillMode, + D3D12_CULL_MODE cullMode, + BOOL frontCounterClockwise, + INT depthBias, + FLOAT depthBiasClamp, + FLOAT slopeScaledDepthBias, + BOOL depthClipEnable, + BOOL multisampleEnable, + BOOL antialiasedLineEnable, + UINT forcedSampleCount, + D3D12_CONSERVATIVE_RASTERIZATION_MODE conservativeRaster) + { + FillMode = fillMode; + CullMode = cullMode; + FrontCounterClockwise = frontCounterClockwise; + DepthBias = depthBias; + DepthBiasClamp = depthBiasClamp; + SlopeScaledDepthBias = slopeScaledDepthBias; + DepthClipEnable = depthClipEnable; + MultisampleEnable = multisampleEnable; + AntialiasedLineEnable = antialiasedLineEnable; + ForcedSampleCount = forcedSampleCount; + ConservativeRaster = conservativeRaster; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_RESOURCE_ALLOCATION_INFO : public D3D12_RESOURCE_ALLOCATION_INFO +{ + CD3DX12_RESOURCE_ALLOCATION_INFO() = default; + explicit CD3DX12_RESOURCE_ALLOCATION_INFO( const D3D12_RESOURCE_ALLOCATION_INFO& o ) : + D3D12_RESOURCE_ALLOCATION_INFO( o ) + {} + CD3DX12_RESOURCE_ALLOCATION_INFO( + UINT64 size, + UINT64 alignment ) + { + SizeInBytes = size; + Alignment = alignment; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_HEAP_PROPERTIES : public D3D12_HEAP_PROPERTIES +{ + CD3DX12_HEAP_PROPERTIES() = default; + explicit CD3DX12_HEAP_PROPERTIES(const D3D12_HEAP_PROPERTIES &o) : + D3D12_HEAP_PROPERTIES(o) + {} + CD3DX12_HEAP_PROPERTIES( + D3D12_CPU_PAGE_PROPERTY cpuPageProperty, + D3D12_MEMORY_POOL memoryPoolPreference, + UINT creationNodeMask = 1, + UINT nodeMask = 1 ) + { + Type = D3D12_HEAP_TYPE_CUSTOM; + CPUPageProperty = cpuPageProperty; + MemoryPoolPreference = memoryPoolPreference; + CreationNodeMask = creationNodeMask; + VisibleNodeMask = nodeMask; + } + explicit CD3DX12_HEAP_PROPERTIES( + D3D12_HEAP_TYPE type, + UINT creationNodeMask = 1, + UINT nodeMask = 1 ) + { + Type = type; + CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN; + MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN; + CreationNodeMask = creationNodeMask; + VisibleNodeMask = nodeMask; + } + bool IsCPUAccessible() const + { + return Type == D3D12_HEAP_TYPE_UPLOAD || Type == D3D12_HEAP_TYPE_READBACK || (Type == D3D12_HEAP_TYPE_CUSTOM && + (CPUPageProperty == D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE || CPUPageProperty == D3D12_CPU_PAGE_PROPERTY_WRITE_BACK)); + } +}; +inline bool operator==( const D3D12_HEAP_PROPERTIES& l, const D3D12_HEAP_PROPERTIES& r ) +{ + return l.Type == r.Type && l.CPUPageProperty == r.CPUPageProperty && + l.MemoryPoolPreference == r.MemoryPoolPreference && + l.CreationNodeMask == r.CreationNodeMask && + l.VisibleNodeMask == r.VisibleNodeMask; +} +inline bool operator!=( const D3D12_HEAP_PROPERTIES& l, const D3D12_HEAP_PROPERTIES& r ) +{ return !( l == r ); } + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_HEAP_DESC : public D3D12_HEAP_DESC +{ + CD3DX12_HEAP_DESC() = default; + explicit CD3DX12_HEAP_DESC(const D3D12_HEAP_DESC &o) : + D3D12_HEAP_DESC(o) + {} + CD3DX12_HEAP_DESC( + UINT64 size, + D3D12_HEAP_PROPERTIES properties, + UINT64 alignment = 0, + D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE ) + { + SizeInBytes = size; + Properties = properties; + Alignment = alignment; + Flags = flags; + } + CD3DX12_HEAP_DESC( + UINT64 size, + D3D12_HEAP_TYPE type, + UINT64 alignment = 0, + D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE ) + { + SizeInBytes = size; + Properties = CD3DX12_HEAP_PROPERTIES( type ); + Alignment = alignment; + Flags = flags; + } + CD3DX12_HEAP_DESC( + UINT64 size, + D3D12_CPU_PAGE_PROPERTY cpuPageProperty, + D3D12_MEMORY_POOL memoryPoolPreference, + UINT64 alignment = 0, + D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE ) + { + SizeInBytes = size; + Properties = CD3DX12_HEAP_PROPERTIES( cpuPageProperty, memoryPoolPreference ); + Alignment = alignment; + Flags = flags; + } + CD3DX12_HEAP_DESC( + const D3D12_RESOURCE_ALLOCATION_INFO& resAllocInfo, + D3D12_HEAP_PROPERTIES properties, + D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE ) + { + SizeInBytes = resAllocInfo.SizeInBytes; + Properties = properties; + Alignment = resAllocInfo.Alignment; + Flags = flags; + } + CD3DX12_HEAP_DESC( + const D3D12_RESOURCE_ALLOCATION_INFO& resAllocInfo, + D3D12_HEAP_TYPE type, + D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE ) + { + SizeInBytes = resAllocInfo.SizeInBytes; + Properties = CD3DX12_HEAP_PROPERTIES( type ); + Alignment = resAllocInfo.Alignment; + Flags = flags; + } + CD3DX12_HEAP_DESC( + const D3D12_RESOURCE_ALLOCATION_INFO& resAllocInfo, + D3D12_CPU_PAGE_PROPERTY cpuPageProperty, + D3D12_MEMORY_POOL memoryPoolPreference, + D3D12_HEAP_FLAGS flags = D3D12_HEAP_FLAG_NONE ) + { + SizeInBytes = resAllocInfo.SizeInBytes; + Properties = CD3DX12_HEAP_PROPERTIES( cpuPageProperty, memoryPoolPreference ); + Alignment = resAllocInfo.Alignment; + Flags = flags; + } + bool IsCPUAccessible() const + { return static_cast< const CD3DX12_HEAP_PROPERTIES* >( &Properties )->IsCPUAccessible(); } +}; +inline bool operator==( const D3D12_HEAP_DESC& l, const D3D12_HEAP_DESC& r ) +{ + return l.SizeInBytes == r.SizeInBytes && + l.Properties == r.Properties && + l.Alignment == r.Alignment && + l.Flags == r.Flags; +} +inline bool operator!=( const D3D12_HEAP_DESC& l, const D3D12_HEAP_DESC& r ) +{ return !( l == r ); } + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_CLEAR_VALUE : public D3D12_CLEAR_VALUE +{ + CD3DX12_CLEAR_VALUE() = default; + explicit CD3DX12_CLEAR_VALUE(const D3D12_CLEAR_VALUE &o) : + D3D12_CLEAR_VALUE(o) + {} + CD3DX12_CLEAR_VALUE( + DXGI_FORMAT format, + const FLOAT color[4] ) + { + Format = format; + memcpy( Color, color, sizeof( Color ) ); + } + CD3DX12_CLEAR_VALUE( + DXGI_FORMAT format, + FLOAT depth, + UINT8 stencil ) + { + Format = format; + /* Use memcpy to preserve NAN values */ + memcpy( &DepthStencil.Depth, &depth, sizeof( depth ) ); + DepthStencil.Stencil = stencil; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_RANGE : public D3D12_RANGE +{ + CD3DX12_RANGE() = default; + explicit CD3DX12_RANGE(const D3D12_RANGE &o) : + D3D12_RANGE(o) + {} + CD3DX12_RANGE( + SIZE_T begin, + SIZE_T end ) + { + Begin = begin; + End = end; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_RANGE_UINT64 : public D3D12_RANGE_UINT64 +{ + CD3DX12_RANGE_UINT64() = default; + explicit CD3DX12_RANGE_UINT64(const D3D12_RANGE_UINT64 &o) : + D3D12_RANGE_UINT64(o) + {} + CD3DX12_RANGE_UINT64( + UINT64 begin, + UINT64 end ) + { + Begin = begin; + End = end; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_SUBRESOURCE_RANGE_UINT64 : public D3D12_SUBRESOURCE_RANGE_UINT64 +{ + CD3DX12_SUBRESOURCE_RANGE_UINT64() = default; + explicit CD3DX12_SUBRESOURCE_RANGE_UINT64(const D3D12_SUBRESOURCE_RANGE_UINT64 &o) : + D3D12_SUBRESOURCE_RANGE_UINT64(o) + {} + CD3DX12_SUBRESOURCE_RANGE_UINT64( + UINT subresource, + const D3D12_RANGE_UINT64& range ) + { + Subresource = subresource; + Range = range; + } + CD3DX12_SUBRESOURCE_RANGE_UINT64( + UINT subresource, + UINT64 begin, + UINT64 end ) + { + Subresource = subresource; + Range.Begin = begin; + Range.End = end; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_SHADER_BYTECODE : public D3D12_SHADER_BYTECODE +{ + CD3DX12_SHADER_BYTECODE() = default; + explicit CD3DX12_SHADER_BYTECODE(const D3D12_SHADER_BYTECODE &o) : + D3D12_SHADER_BYTECODE(o) + {} + CD3DX12_SHADER_BYTECODE( + _In_ ID3DBlob* pShaderBlob ) + { + pShaderBytecode = pShaderBlob->GetBufferPointer(); + BytecodeLength = pShaderBlob->GetBufferSize(); + } + CD3DX12_SHADER_BYTECODE( + const void* _pShaderBytecode, + SIZE_T bytecodeLength ) + { + pShaderBytecode = _pShaderBytecode; + BytecodeLength = bytecodeLength; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_TILED_RESOURCE_COORDINATE : public D3D12_TILED_RESOURCE_COORDINATE +{ + CD3DX12_TILED_RESOURCE_COORDINATE() = default; + explicit CD3DX12_TILED_RESOURCE_COORDINATE(const D3D12_TILED_RESOURCE_COORDINATE &o) : + D3D12_TILED_RESOURCE_COORDINATE(o) + {} + CD3DX12_TILED_RESOURCE_COORDINATE( + UINT x, + UINT y, + UINT z, + UINT subresource ) + { + X = x; + Y = y; + Z = z; + Subresource = subresource; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_TILE_REGION_SIZE : public D3D12_TILE_REGION_SIZE +{ + CD3DX12_TILE_REGION_SIZE() = default; + explicit CD3DX12_TILE_REGION_SIZE(const D3D12_TILE_REGION_SIZE &o) : + D3D12_TILE_REGION_SIZE(o) + {} + CD3DX12_TILE_REGION_SIZE( + UINT numTiles, + BOOL useBox, + UINT width, + UINT16 height, + UINT16 depth ) + { + NumTiles = numTiles; + UseBox = useBox; + Width = width; + Height = height; + Depth = depth; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_SUBRESOURCE_TILING : public D3D12_SUBRESOURCE_TILING +{ + CD3DX12_SUBRESOURCE_TILING() = default; + explicit CD3DX12_SUBRESOURCE_TILING(const D3D12_SUBRESOURCE_TILING &o) : + D3D12_SUBRESOURCE_TILING(o) + {} + CD3DX12_SUBRESOURCE_TILING( + UINT widthInTiles, + UINT16 heightInTiles, + UINT16 depthInTiles, + UINT startTileIndexInOverallResource ) + { + WidthInTiles = widthInTiles; + HeightInTiles = heightInTiles; + DepthInTiles = depthInTiles; + StartTileIndexInOverallResource = startTileIndexInOverallResource; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_TILE_SHAPE : public D3D12_TILE_SHAPE +{ + CD3DX12_TILE_SHAPE() = default; + explicit CD3DX12_TILE_SHAPE(const D3D12_TILE_SHAPE &o) : + D3D12_TILE_SHAPE(o) + {} + CD3DX12_TILE_SHAPE( + UINT widthInTexels, + UINT heightInTexels, + UINT depthInTexels ) + { + WidthInTexels = widthInTexels; + HeightInTexels = heightInTexels; + DepthInTexels = depthInTexels; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_RESOURCE_BARRIER : public D3D12_RESOURCE_BARRIER +{ + CD3DX12_RESOURCE_BARRIER() = default; + explicit CD3DX12_RESOURCE_BARRIER(const D3D12_RESOURCE_BARRIER &o) : + D3D12_RESOURCE_BARRIER(o) + {} + static inline CD3DX12_RESOURCE_BARRIER Transition( + _In_ ID3D12Resource* pResource, + D3D12_RESOURCE_STATES stateBefore, + D3D12_RESOURCE_STATES stateAfter, + UINT subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES, + D3D12_RESOURCE_BARRIER_FLAGS flags = D3D12_RESOURCE_BARRIER_FLAG_NONE) + { + CD3DX12_RESOURCE_BARRIER result = {}; + D3D12_RESOURCE_BARRIER &barrier = result; + result.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION; + result.Flags = flags; + barrier.Transition.pResource = pResource; + barrier.Transition.StateBefore = stateBefore; + barrier.Transition.StateAfter = stateAfter; + barrier.Transition.Subresource = subresource; + return result; + } + static inline CD3DX12_RESOURCE_BARRIER Aliasing( + _In_ ID3D12Resource* pResourceBefore, + _In_ ID3D12Resource* pResourceAfter) + { + CD3DX12_RESOURCE_BARRIER result = {}; + D3D12_RESOURCE_BARRIER &barrier = result; + result.Type = D3D12_RESOURCE_BARRIER_TYPE_ALIASING; + barrier.Aliasing.pResourceBefore = pResourceBefore; + barrier.Aliasing.pResourceAfter = pResourceAfter; + return result; + } + static inline CD3DX12_RESOURCE_BARRIER UAV( + _In_ ID3D12Resource* pResource) + { + CD3DX12_RESOURCE_BARRIER result = {}; + D3D12_RESOURCE_BARRIER &barrier = result; + result.Type = D3D12_RESOURCE_BARRIER_TYPE_UAV; + barrier.UAV.pResource = pResource; + return result; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_PACKED_MIP_INFO : public D3D12_PACKED_MIP_INFO +{ + CD3DX12_PACKED_MIP_INFO() = default; + explicit CD3DX12_PACKED_MIP_INFO(const D3D12_PACKED_MIP_INFO &o) : + D3D12_PACKED_MIP_INFO(o) + {} + CD3DX12_PACKED_MIP_INFO( + UINT8 numStandardMips, + UINT8 numPackedMips, + UINT numTilesForPackedMips, + UINT startTileIndexInOverallResource ) + { + NumStandardMips = numStandardMips; + NumPackedMips = numPackedMips; + NumTilesForPackedMips = numTilesForPackedMips; + StartTileIndexInOverallResource = startTileIndexInOverallResource; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_SUBRESOURCE_FOOTPRINT : public D3D12_SUBRESOURCE_FOOTPRINT +{ + CD3DX12_SUBRESOURCE_FOOTPRINT() = default; + explicit CD3DX12_SUBRESOURCE_FOOTPRINT(const D3D12_SUBRESOURCE_FOOTPRINT &o) : + D3D12_SUBRESOURCE_FOOTPRINT(o) + {} + CD3DX12_SUBRESOURCE_FOOTPRINT( + DXGI_FORMAT format, + UINT width, + UINT height, + UINT depth, + UINT rowPitch ) + { + Format = format; + Width = width; + Height = height; + Depth = depth; + RowPitch = rowPitch; + } + explicit CD3DX12_SUBRESOURCE_FOOTPRINT( + const D3D12_RESOURCE_DESC& resDesc, + UINT rowPitch ) + { + Format = resDesc.Format; + Width = UINT( resDesc.Width ); + Height = resDesc.Height; + Depth = (resDesc.Dimension == D3D12_RESOURCE_DIMENSION_TEXTURE3D ? resDesc.DepthOrArraySize : 1); + RowPitch = rowPitch; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_TEXTURE_COPY_LOCATION : public D3D12_TEXTURE_COPY_LOCATION +{ + CD3DX12_TEXTURE_COPY_LOCATION() = default; + explicit CD3DX12_TEXTURE_COPY_LOCATION(const D3D12_TEXTURE_COPY_LOCATION &o) : + D3D12_TEXTURE_COPY_LOCATION(o) + {} + CD3DX12_TEXTURE_COPY_LOCATION(_In_ ID3D12Resource* pRes) + { + pResource = pRes; + Type = D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX; + PlacedFootprint = {}; + } + CD3DX12_TEXTURE_COPY_LOCATION(_In_ ID3D12Resource* pRes, D3D12_PLACED_SUBRESOURCE_FOOTPRINT const& Footprint) + { + pResource = pRes; + Type = D3D12_TEXTURE_COPY_TYPE_PLACED_FOOTPRINT; + PlacedFootprint = Footprint; + } + CD3DX12_TEXTURE_COPY_LOCATION(_In_ ID3D12Resource* pRes, UINT Sub) + { + pResource = pRes; + Type = D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX; + SubresourceIndex = Sub; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_DESCRIPTOR_RANGE : public D3D12_DESCRIPTOR_RANGE +{ + CD3DX12_DESCRIPTOR_RANGE() = default; + explicit CD3DX12_DESCRIPTOR_RANGE(const D3D12_DESCRIPTOR_RANGE &o) : + D3D12_DESCRIPTOR_RANGE(o) + {} + CD3DX12_DESCRIPTOR_RANGE( + D3D12_DESCRIPTOR_RANGE_TYPE rangeType, + UINT numDescriptors, + UINT baseShaderRegister, + UINT registerSpace = 0, + UINT offsetInDescriptorsFromTableStart = + D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) + { + Init(rangeType, numDescriptors, baseShaderRegister, registerSpace, offsetInDescriptorsFromTableStart); + } + + inline void Init( + D3D12_DESCRIPTOR_RANGE_TYPE rangeType, + UINT numDescriptors, + UINT baseShaderRegister, + UINT registerSpace = 0, + UINT offsetInDescriptorsFromTableStart = + D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) + { + Init(*this, rangeType, numDescriptors, baseShaderRegister, registerSpace, offsetInDescriptorsFromTableStart); + } + + static inline void Init( + _Out_ D3D12_DESCRIPTOR_RANGE &range, + D3D12_DESCRIPTOR_RANGE_TYPE rangeType, + UINT numDescriptors, + UINT baseShaderRegister, + UINT registerSpace = 0, + UINT offsetInDescriptorsFromTableStart = + D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) + { + range.RangeType = rangeType; + range.NumDescriptors = numDescriptors; + range.BaseShaderRegister = baseShaderRegister; + range.RegisterSpace = registerSpace; + range.OffsetInDescriptorsFromTableStart = offsetInDescriptorsFromTableStart; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_ROOT_DESCRIPTOR_TABLE : public D3D12_ROOT_DESCRIPTOR_TABLE +{ + CD3DX12_ROOT_DESCRIPTOR_TABLE() = default; + explicit CD3DX12_ROOT_DESCRIPTOR_TABLE(const D3D12_ROOT_DESCRIPTOR_TABLE &o) : + D3D12_ROOT_DESCRIPTOR_TABLE(o) + {} + CD3DX12_ROOT_DESCRIPTOR_TABLE( + UINT numDescriptorRanges, + _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE* _pDescriptorRanges) + { + Init(numDescriptorRanges, _pDescriptorRanges); + } + + inline void Init( + UINT numDescriptorRanges, + _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE* _pDescriptorRanges) + { + Init(*this, numDescriptorRanges, _pDescriptorRanges); + } + + static inline void Init( + _Out_ D3D12_ROOT_DESCRIPTOR_TABLE &rootDescriptorTable, + UINT numDescriptorRanges, + _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE* _pDescriptorRanges) + { + rootDescriptorTable.NumDescriptorRanges = numDescriptorRanges; + rootDescriptorTable.pDescriptorRanges = _pDescriptorRanges; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_ROOT_CONSTANTS : public D3D12_ROOT_CONSTANTS +{ + CD3DX12_ROOT_CONSTANTS() = default; + explicit CD3DX12_ROOT_CONSTANTS(const D3D12_ROOT_CONSTANTS &o) : + D3D12_ROOT_CONSTANTS(o) + {} + CD3DX12_ROOT_CONSTANTS( + UINT num32BitValues, + UINT shaderRegister, + UINT registerSpace = 0) + { + Init(num32BitValues, shaderRegister, registerSpace); + } + + inline void Init( + UINT num32BitValues, + UINT shaderRegister, + UINT registerSpace = 0) + { + Init(*this, num32BitValues, shaderRegister, registerSpace); + } + + static inline void Init( + _Out_ D3D12_ROOT_CONSTANTS &rootConstants, + UINT num32BitValues, + UINT shaderRegister, + UINT registerSpace = 0) + { + rootConstants.Num32BitValues = num32BitValues; + rootConstants.ShaderRegister = shaderRegister; + rootConstants.RegisterSpace = registerSpace; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_ROOT_DESCRIPTOR : public D3D12_ROOT_DESCRIPTOR +{ + CD3DX12_ROOT_DESCRIPTOR() = default; + explicit CD3DX12_ROOT_DESCRIPTOR(const D3D12_ROOT_DESCRIPTOR &o) : + D3D12_ROOT_DESCRIPTOR(o) + {} + CD3DX12_ROOT_DESCRIPTOR( + UINT shaderRegister, + UINT registerSpace = 0) + { + Init(shaderRegister, registerSpace); + } + + inline void Init( + UINT shaderRegister, + UINT registerSpace = 0) + { + Init(*this, shaderRegister, registerSpace); + } + + static inline void Init(_Out_ D3D12_ROOT_DESCRIPTOR &table, UINT shaderRegister, UINT registerSpace = 0) + { + table.ShaderRegister = shaderRegister; + table.RegisterSpace = registerSpace; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_ROOT_PARAMETER : public D3D12_ROOT_PARAMETER +{ + CD3DX12_ROOT_PARAMETER() = default; + explicit CD3DX12_ROOT_PARAMETER(const D3D12_ROOT_PARAMETER &o) : + D3D12_ROOT_PARAMETER(o) + {} + + static inline void InitAsDescriptorTable( + _Out_ D3D12_ROOT_PARAMETER &rootParam, + UINT numDescriptorRanges, + _In_reads_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE* pDescriptorRanges, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE; + rootParam.ShaderVisibility = visibility; + CD3DX12_ROOT_DESCRIPTOR_TABLE::Init(rootParam.DescriptorTable, numDescriptorRanges, pDescriptorRanges); + } + + static inline void InitAsConstants( + _Out_ D3D12_ROOT_PARAMETER &rootParam, + UINT num32BitValues, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS; + rootParam.ShaderVisibility = visibility; + CD3DX12_ROOT_CONSTANTS::Init(rootParam.Constants, num32BitValues, shaderRegister, registerSpace); + } + + static inline void InitAsConstantBufferView( + _Out_ D3D12_ROOT_PARAMETER &rootParam, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_CBV; + rootParam.ShaderVisibility = visibility; + CD3DX12_ROOT_DESCRIPTOR::Init(rootParam.Descriptor, shaderRegister, registerSpace); + } + + static inline void InitAsShaderResourceView( + _Out_ D3D12_ROOT_PARAMETER &rootParam, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_SRV; + rootParam.ShaderVisibility = visibility; + CD3DX12_ROOT_DESCRIPTOR::Init(rootParam.Descriptor, shaderRegister, registerSpace); + } + + static inline void InitAsUnorderedAccessView( + _Out_ D3D12_ROOT_PARAMETER &rootParam, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_UAV; + rootParam.ShaderVisibility = visibility; + CD3DX12_ROOT_DESCRIPTOR::Init(rootParam.Descriptor, shaderRegister, registerSpace); + } + + inline void InitAsDescriptorTable( + UINT numDescriptorRanges, + _In_reads_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE* pDescriptorRanges, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + InitAsDescriptorTable(*this, numDescriptorRanges, pDescriptorRanges, visibility); + } + + inline void InitAsConstants( + UINT num32BitValues, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + InitAsConstants(*this, num32BitValues, shaderRegister, registerSpace, visibility); + } + + inline void InitAsConstantBufferView( + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + InitAsConstantBufferView(*this, shaderRegister, registerSpace, visibility); + } + + inline void InitAsShaderResourceView( + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + InitAsShaderResourceView(*this, shaderRegister, registerSpace, visibility); + } + + inline void InitAsUnorderedAccessView( + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + InitAsUnorderedAccessView(*this, shaderRegister, registerSpace, visibility); + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_STATIC_SAMPLER_DESC : public D3D12_STATIC_SAMPLER_DESC +{ + CD3DX12_STATIC_SAMPLER_DESC() = default; + explicit CD3DX12_STATIC_SAMPLER_DESC(const D3D12_STATIC_SAMPLER_DESC &o) : + D3D12_STATIC_SAMPLER_DESC(o) + {} + CD3DX12_STATIC_SAMPLER_DESC( + UINT shaderRegister, + D3D12_FILTER filter = D3D12_FILTER_ANISOTROPIC, + D3D12_TEXTURE_ADDRESS_MODE addressU = D3D12_TEXTURE_ADDRESS_MODE_WRAP, + D3D12_TEXTURE_ADDRESS_MODE addressV = D3D12_TEXTURE_ADDRESS_MODE_WRAP, + D3D12_TEXTURE_ADDRESS_MODE addressW = D3D12_TEXTURE_ADDRESS_MODE_WRAP, + FLOAT mipLODBias = 0, + UINT maxAnisotropy = 16, + D3D12_COMPARISON_FUNC comparisonFunc = D3D12_COMPARISON_FUNC_LESS_EQUAL, + D3D12_STATIC_BORDER_COLOR borderColor = D3D12_STATIC_BORDER_COLOR_OPAQUE_WHITE, + FLOAT minLOD = 0.f, + FLOAT maxLOD = D3D12_FLOAT32_MAX, + D3D12_SHADER_VISIBILITY shaderVisibility = D3D12_SHADER_VISIBILITY_ALL, + UINT registerSpace = 0) + { + Init( + shaderRegister, + filter, + addressU, + addressV, + addressW, + mipLODBias, + maxAnisotropy, + comparisonFunc, + borderColor, + minLOD, + maxLOD, + shaderVisibility, + registerSpace); + } + + static inline void Init( + _Out_ D3D12_STATIC_SAMPLER_DESC &samplerDesc, + UINT shaderRegister, + D3D12_FILTER filter = D3D12_FILTER_ANISOTROPIC, + D3D12_TEXTURE_ADDRESS_MODE addressU = D3D12_TEXTURE_ADDRESS_MODE_WRAP, + D3D12_TEXTURE_ADDRESS_MODE addressV = D3D12_TEXTURE_ADDRESS_MODE_WRAP, + D3D12_TEXTURE_ADDRESS_MODE addressW = D3D12_TEXTURE_ADDRESS_MODE_WRAP, + FLOAT mipLODBias = 0, + UINT maxAnisotropy = 16, + D3D12_COMPARISON_FUNC comparisonFunc = D3D12_COMPARISON_FUNC_LESS_EQUAL, + D3D12_STATIC_BORDER_COLOR borderColor = D3D12_STATIC_BORDER_COLOR_OPAQUE_WHITE, + FLOAT minLOD = 0.f, + FLOAT maxLOD = D3D12_FLOAT32_MAX, + D3D12_SHADER_VISIBILITY shaderVisibility = D3D12_SHADER_VISIBILITY_ALL, + UINT registerSpace = 0) + { + samplerDesc.ShaderRegister = shaderRegister; + samplerDesc.Filter = filter; + samplerDesc.AddressU = addressU; + samplerDesc.AddressV = addressV; + samplerDesc.AddressW = addressW; + samplerDesc.MipLODBias = mipLODBias; + samplerDesc.MaxAnisotropy = maxAnisotropy; + samplerDesc.ComparisonFunc = comparisonFunc; + samplerDesc.BorderColor = borderColor; + samplerDesc.MinLOD = minLOD; + samplerDesc.MaxLOD = maxLOD; + samplerDesc.ShaderVisibility = shaderVisibility; + samplerDesc.RegisterSpace = registerSpace; + } + inline void Init( + UINT shaderRegister, + D3D12_FILTER filter = D3D12_FILTER_ANISOTROPIC, + D3D12_TEXTURE_ADDRESS_MODE addressU = D3D12_TEXTURE_ADDRESS_MODE_WRAP, + D3D12_TEXTURE_ADDRESS_MODE addressV = D3D12_TEXTURE_ADDRESS_MODE_WRAP, + D3D12_TEXTURE_ADDRESS_MODE addressW = D3D12_TEXTURE_ADDRESS_MODE_WRAP, + FLOAT mipLODBias = 0, + UINT maxAnisotropy = 16, + D3D12_COMPARISON_FUNC comparisonFunc = D3D12_COMPARISON_FUNC_LESS_EQUAL, + D3D12_STATIC_BORDER_COLOR borderColor = D3D12_STATIC_BORDER_COLOR_OPAQUE_WHITE, + FLOAT minLOD = 0.f, + FLOAT maxLOD = D3D12_FLOAT32_MAX, + D3D12_SHADER_VISIBILITY shaderVisibility = D3D12_SHADER_VISIBILITY_ALL, + UINT registerSpace = 0) + { + Init( + *this, + shaderRegister, + filter, + addressU, + addressV, + addressW, + mipLODBias, + maxAnisotropy, + comparisonFunc, + borderColor, + minLOD, + maxLOD, + shaderVisibility, + registerSpace); + } + +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_ROOT_SIGNATURE_DESC : public D3D12_ROOT_SIGNATURE_DESC +{ + CD3DX12_ROOT_SIGNATURE_DESC() = default; + explicit CD3DX12_ROOT_SIGNATURE_DESC(const D3D12_ROOT_SIGNATURE_DESC &o) : + D3D12_ROOT_SIGNATURE_DESC(o) + {} + CD3DX12_ROOT_SIGNATURE_DESC( + UINT numParameters, + _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, + UINT numStaticSamplers = 0, + _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, + D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) + { + Init(numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); + } + CD3DX12_ROOT_SIGNATURE_DESC(CD3DX12_DEFAULT) + { + Init(0, nullptr, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_NONE); + } + + inline void Init( + UINT numParameters, + _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, + UINT numStaticSamplers = 0, + _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, + D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) + { + Init(*this, numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); + } + + static inline void Init( + _Out_ D3D12_ROOT_SIGNATURE_DESC &desc, + UINT numParameters, + _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, + UINT numStaticSamplers = 0, + _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, + D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) + { + desc.NumParameters = numParameters; + desc.pParameters = _pParameters; + desc.NumStaticSamplers = numStaticSamplers; + desc.pStaticSamplers = _pStaticSamplers; + desc.Flags = flags; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_DESCRIPTOR_RANGE1 : public D3D12_DESCRIPTOR_RANGE1 +{ + CD3DX12_DESCRIPTOR_RANGE1() = default; + explicit CD3DX12_DESCRIPTOR_RANGE1(const D3D12_DESCRIPTOR_RANGE1 &o) : + D3D12_DESCRIPTOR_RANGE1(o) + {} + CD3DX12_DESCRIPTOR_RANGE1( + D3D12_DESCRIPTOR_RANGE_TYPE rangeType, + UINT numDescriptors, + UINT baseShaderRegister, + UINT registerSpace = 0, + D3D12_DESCRIPTOR_RANGE_FLAGS flags = D3D12_DESCRIPTOR_RANGE_FLAG_NONE, + UINT offsetInDescriptorsFromTableStart = + D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) + { + Init(rangeType, numDescriptors, baseShaderRegister, registerSpace, flags, offsetInDescriptorsFromTableStart); + } + + inline void Init( + D3D12_DESCRIPTOR_RANGE_TYPE rangeType, + UINT numDescriptors, + UINT baseShaderRegister, + UINT registerSpace = 0, + D3D12_DESCRIPTOR_RANGE_FLAGS flags = D3D12_DESCRIPTOR_RANGE_FLAG_NONE, + UINT offsetInDescriptorsFromTableStart = + D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) + { + Init(*this, rangeType, numDescriptors, baseShaderRegister, registerSpace, flags, offsetInDescriptorsFromTableStart); + } + + static inline void Init( + _Out_ D3D12_DESCRIPTOR_RANGE1 &range, + D3D12_DESCRIPTOR_RANGE_TYPE rangeType, + UINT numDescriptors, + UINT baseShaderRegister, + UINT registerSpace = 0, + D3D12_DESCRIPTOR_RANGE_FLAGS flags = D3D12_DESCRIPTOR_RANGE_FLAG_NONE, + UINT offsetInDescriptorsFromTableStart = + D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND) + { + range.RangeType = rangeType; + range.NumDescriptors = numDescriptors; + range.BaseShaderRegister = baseShaderRegister; + range.RegisterSpace = registerSpace; + range.Flags = flags; + range.OffsetInDescriptorsFromTableStart = offsetInDescriptorsFromTableStart; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_ROOT_DESCRIPTOR_TABLE1 : public D3D12_ROOT_DESCRIPTOR_TABLE1 +{ + CD3DX12_ROOT_DESCRIPTOR_TABLE1() = default; + explicit CD3DX12_ROOT_DESCRIPTOR_TABLE1(const D3D12_ROOT_DESCRIPTOR_TABLE1 &o) : + D3D12_ROOT_DESCRIPTOR_TABLE1(o) + {} + CD3DX12_ROOT_DESCRIPTOR_TABLE1( + UINT numDescriptorRanges, + _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE1* _pDescriptorRanges) + { + Init(numDescriptorRanges, _pDescriptorRanges); + } + + inline void Init( + UINT numDescriptorRanges, + _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE1* _pDescriptorRanges) + { + Init(*this, numDescriptorRanges, _pDescriptorRanges); + } + + static inline void Init( + _Out_ D3D12_ROOT_DESCRIPTOR_TABLE1 &rootDescriptorTable, + UINT numDescriptorRanges, + _In_reads_opt_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE1* _pDescriptorRanges) + { + rootDescriptorTable.NumDescriptorRanges = numDescriptorRanges; + rootDescriptorTable.pDescriptorRanges = _pDescriptorRanges; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_ROOT_DESCRIPTOR1 : public D3D12_ROOT_DESCRIPTOR1 +{ + CD3DX12_ROOT_DESCRIPTOR1() = default; + explicit CD3DX12_ROOT_DESCRIPTOR1(const D3D12_ROOT_DESCRIPTOR1 &o) : + D3D12_ROOT_DESCRIPTOR1(o) + {} + CD3DX12_ROOT_DESCRIPTOR1( + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE) + { + Init(shaderRegister, registerSpace, flags); + } + + inline void Init( + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE) + { + Init(*this, shaderRegister, registerSpace, flags); + } + + static inline void Init( + _Out_ D3D12_ROOT_DESCRIPTOR1 &table, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE) + { + table.ShaderRegister = shaderRegister; + table.RegisterSpace = registerSpace; + table.Flags = flags; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_ROOT_PARAMETER1 : public D3D12_ROOT_PARAMETER1 +{ + CD3DX12_ROOT_PARAMETER1() = default; + explicit CD3DX12_ROOT_PARAMETER1(const D3D12_ROOT_PARAMETER1 &o) : + D3D12_ROOT_PARAMETER1(o) + {} + + static inline void InitAsDescriptorTable( + _Out_ D3D12_ROOT_PARAMETER1 &rootParam, + UINT numDescriptorRanges, + _In_reads_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE1* pDescriptorRanges, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE; + rootParam.ShaderVisibility = visibility; + CD3DX12_ROOT_DESCRIPTOR_TABLE1::Init(rootParam.DescriptorTable, numDescriptorRanges, pDescriptorRanges); + } + + static inline void InitAsConstants( + _Out_ D3D12_ROOT_PARAMETER1 &rootParam, + UINT num32BitValues, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS; + rootParam.ShaderVisibility = visibility; + CD3DX12_ROOT_CONSTANTS::Init(rootParam.Constants, num32BitValues, shaderRegister, registerSpace); + } + + static inline void InitAsConstantBufferView( + _Out_ D3D12_ROOT_PARAMETER1 &rootParam, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_CBV; + rootParam.ShaderVisibility = visibility; + CD3DX12_ROOT_DESCRIPTOR1::Init(rootParam.Descriptor, shaderRegister, registerSpace, flags); + } + + static inline void InitAsShaderResourceView( + _Out_ D3D12_ROOT_PARAMETER1 &rootParam, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_SRV; + rootParam.ShaderVisibility = visibility; + CD3DX12_ROOT_DESCRIPTOR1::Init(rootParam.Descriptor, shaderRegister, registerSpace, flags); + } + + static inline void InitAsUnorderedAccessView( + _Out_ D3D12_ROOT_PARAMETER1 &rootParam, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + rootParam.ParameterType = D3D12_ROOT_PARAMETER_TYPE_UAV; + rootParam.ShaderVisibility = visibility; + CD3DX12_ROOT_DESCRIPTOR1::Init(rootParam.Descriptor, shaderRegister, registerSpace, flags); + } + + inline void InitAsDescriptorTable( + UINT numDescriptorRanges, + _In_reads_(numDescriptorRanges) const D3D12_DESCRIPTOR_RANGE1* pDescriptorRanges, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + InitAsDescriptorTable(*this, numDescriptorRanges, pDescriptorRanges, visibility); + } + + inline void InitAsConstants( + UINT num32BitValues, + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + InitAsConstants(*this, num32BitValues, shaderRegister, registerSpace, visibility); + } + + inline void InitAsConstantBufferView( + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + InitAsConstantBufferView(*this, shaderRegister, registerSpace, flags, visibility); + } + + inline void InitAsShaderResourceView( + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + InitAsShaderResourceView(*this, shaderRegister, registerSpace, flags, visibility); + } + + inline void InitAsUnorderedAccessView( + UINT shaderRegister, + UINT registerSpace = 0, + D3D12_ROOT_DESCRIPTOR_FLAGS flags = D3D12_ROOT_DESCRIPTOR_FLAG_NONE, + D3D12_SHADER_VISIBILITY visibility = D3D12_SHADER_VISIBILITY_ALL) + { + InitAsUnorderedAccessView(*this, shaderRegister, registerSpace, flags, visibility); + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC : public D3D12_VERSIONED_ROOT_SIGNATURE_DESC +{ + CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC() = default; + explicit CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC(const D3D12_VERSIONED_ROOT_SIGNATURE_DESC &o) : + D3D12_VERSIONED_ROOT_SIGNATURE_DESC(o) + {} + explicit CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC(const D3D12_ROOT_SIGNATURE_DESC &o) + { + Version = D3D_ROOT_SIGNATURE_VERSION_1_0; + Desc_1_0 = o; + } + explicit CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC(const D3D12_ROOT_SIGNATURE_DESC1 &o) + { + Version = D3D_ROOT_SIGNATURE_VERSION_1_1; + Desc_1_1 = o; + } + CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC( + UINT numParameters, + _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, + UINT numStaticSamplers = 0, + _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, + D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) + { + Init_1_0(numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); + } + CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC( + UINT numParameters, + _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER1* _pParameters, + UINT numStaticSamplers = 0, + _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, + D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) + { + Init_1_1(numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); + } + CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC(CD3DX12_DEFAULT) + { + Init_1_1(0, nullptr, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_NONE); + } + + inline void Init_1_0( + UINT numParameters, + _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, + UINT numStaticSamplers = 0, + _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, + D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) + { + Init_1_0(*this, numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); + } + + static inline void Init_1_0( + _Out_ D3D12_VERSIONED_ROOT_SIGNATURE_DESC &desc, + UINT numParameters, + _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER* _pParameters, + UINT numStaticSamplers = 0, + _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, + D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) + { + desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_0; + desc.Desc_1_0.NumParameters = numParameters; + desc.Desc_1_0.pParameters = _pParameters; + desc.Desc_1_0.NumStaticSamplers = numStaticSamplers; + desc.Desc_1_0.pStaticSamplers = _pStaticSamplers; + desc.Desc_1_0.Flags = flags; + } + + inline void Init_1_1( + UINT numParameters, + _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER1* _pParameters, + UINT numStaticSamplers = 0, + _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, + D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) + { + Init_1_1(*this, numParameters, _pParameters, numStaticSamplers, _pStaticSamplers, flags); + } + + static inline void Init_1_1( + _Out_ D3D12_VERSIONED_ROOT_SIGNATURE_DESC &desc, + UINT numParameters, + _In_reads_opt_(numParameters) const D3D12_ROOT_PARAMETER1* _pParameters, + UINT numStaticSamplers = 0, + _In_reads_opt_(numStaticSamplers) const D3D12_STATIC_SAMPLER_DESC* _pStaticSamplers = nullptr, + D3D12_ROOT_SIGNATURE_FLAGS flags = D3D12_ROOT_SIGNATURE_FLAG_NONE) + { + desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_1; + desc.Desc_1_1.NumParameters = numParameters; + desc.Desc_1_1.pParameters = _pParameters; + desc.Desc_1_1.NumStaticSamplers = numStaticSamplers; + desc.Desc_1_1.pStaticSamplers = _pStaticSamplers; + desc.Desc_1_1.Flags = flags; + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_CPU_DESCRIPTOR_HANDLE : public D3D12_CPU_DESCRIPTOR_HANDLE +{ + CD3DX12_CPU_DESCRIPTOR_HANDLE() = default; + explicit CD3DX12_CPU_DESCRIPTOR_HANDLE(const D3D12_CPU_DESCRIPTOR_HANDLE &o) : + D3D12_CPU_DESCRIPTOR_HANDLE(o) + {} + CD3DX12_CPU_DESCRIPTOR_HANDLE(CD3DX12_DEFAULT) { ptr = 0; } + CD3DX12_CPU_DESCRIPTOR_HANDLE(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE &other, INT offsetScaledByIncrementSize) + { + InitOffsetted(other, offsetScaledByIncrementSize); + } + CD3DX12_CPU_DESCRIPTOR_HANDLE(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE &other, INT offsetInDescriptors, UINT descriptorIncrementSize) + { + InitOffsetted(other, offsetInDescriptors, descriptorIncrementSize); + } + CD3DX12_CPU_DESCRIPTOR_HANDLE& Offset(INT offsetInDescriptors, UINT descriptorIncrementSize) + { + ptr += INT64(offsetInDescriptors) * UINT64(descriptorIncrementSize); + return *this; + } + CD3DX12_CPU_DESCRIPTOR_HANDLE& Offset(INT offsetScaledByIncrementSize) + { + ptr += offsetScaledByIncrementSize; + return *this; + } + bool operator==(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE& other) const + { + return (ptr == other.ptr); + } + bool operator!=(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE& other) const + { + return (ptr != other.ptr); + } + CD3DX12_CPU_DESCRIPTOR_HANDLE &operator=(const D3D12_CPU_DESCRIPTOR_HANDLE &other) + { + ptr = other.ptr; + return *this; + } + + inline void InitOffsetted(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE &base, INT offsetScaledByIncrementSize) + { + InitOffsetted(*this, base, offsetScaledByIncrementSize); + } + + inline void InitOffsetted(_In_ const D3D12_CPU_DESCRIPTOR_HANDLE &base, INT offsetInDescriptors, UINT descriptorIncrementSize) + { + InitOffsetted(*this, base, offsetInDescriptors, descriptorIncrementSize); + } + + static inline void InitOffsetted(_Out_ D3D12_CPU_DESCRIPTOR_HANDLE &handle, _In_ const D3D12_CPU_DESCRIPTOR_HANDLE &base, INT offsetScaledByIncrementSize) + { + handle.ptr = base.ptr + offsetScaledByIncrementSize; + } + + static inline void InitOffsetted(_Out_ D3D12_CPU_DESCRIPTOR_HANDLE &handle, _In_ const D3D12_CPU_DESCRIPTOR_HANDLE &base, INT offsetInDescriptors, UINT descriptorIncrementSize) + { + handle.ptr = static_cast(base.ptr + INT64(offsetInDescriptors) * UINT64(descriptorIncrementSize)); + } +}; + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_GPU_DESCRIPTOR_HANDLE : public D3D12_GPU_DESCRIPTOR_HANDLE +{ + CD3DX12_GPU_DESCRIPTOR_HANDLE() = default; + explicit CD3DX12_GPU_DESCRIPTOR_HANDLE(const D3D12_GPU_DESCRIPTOR_HANDLE &o) : + D3D12_GPU_DESCRIPTOR_HANDLE(o) + {} + CD3DX12_GPU_DESCRIPTOR_HANDLE(CD3DX12_DEFAULT) { ptr = 0; } + CD3DX12_GPU_DESCRIPTOR_HANDLE(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE &other, INT offsetScaledByIncrementSize) + { + InitOffsetted(other, offsetScaledByIncrementSize); + } + CD3DX12_GPU_DESCRIPTOR_HANDLE(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE &other, INT offsetInDescriptors, UINT descriptorIncrementSize) + { + InitOffsetted(other, offsetInDescriptors, descriptorIncrementSize); + } + CD3DX12_GPU_DESCRIPTOR_HANDLE& Offset(INT offsetInDescriptors, UINT descriptorIncrementSize) + { + ptr += INT64(offsetInDescriptors) * UINT64(descriptorIncrementSize); + return *this; + } + CD3DX12_GPU_DESCRIPTOR_HANDLE& Offset(INT offsetScaledByIncrementSize) + { + ptr += offsetScaledByIncrementSize; + return *this; + } + inline bool operator==(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE& other) const + { + return (ptr == other.ptr); + } + inline bool operator!=(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE& other) const + { + return (ptr != other.ptr); + } + CD3DX12_GPU_DESCRIPTOR_HANDLE &operator=(const D3D12_GPU_DESCRIPTOR_HANDLE &other) + { + ptr = other.ptr; + return *this; + } + + inline void InitOffsetted(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE &base, INT offsetScaledByIncrementSize) + { + InitOffsetted(*this, base, offsetScaledByIncrementSize); + } + + inline void InitOffsetted(_In_ const D3D12_GPU_DESCRIPTOR_HANDLE &base, INT offsetInDescriptors, UINT descriptorIncrementSize) + { + InitOffsetted(*this, base, offsetInDescriptors, descriptorIncrementSize); + } + + static inline void InitOffsetted(_Out_ D3D12_GPU_DESCRIPTOR_HANDLE &handle, _In_ const D3D12_GPU_DESCRIPTOR_HANDLE &base, INT offsetScaledByIncrementSize) + { + handle.ptr = base.ptr + offsetScaledByIncrementSize; + } + + static inline void InitOffsetted(_Out_ D3D12_GPU_DESCRIPTOR_HANDLE &handle, _In_ const D3D12_GPU_DESCRIPTOR_HANDLE &base, INT offsetInDescriptors, UINT descriptorIncrementSize) + { + handle.ptr = static_cast(base.ptr + INT64(offsetInDescriptors) * UINT64(descriptorIncrementSize)); + } +}; + +//------------------------------------------------------------------------------------------------ +inline UINT D3D12CalcSubresource( UINT MipSlice, UINT ArraySlice, UINT PlaneSlice, UINT MipLevels, UINT ArraySize ) +{ + return MipSlice + ArraySlice * MipLevels + PlaneSlice * MipLevels * ArraySize; +} + +//------------------------------------------------------------------------------------------------ +template +inline void D3D12DecomposeSubresource( UINT Subresource, UINT MipLevels, UINT ArraySize, _Out_ T& MipSlice, _Out_ U& ArraySlice, _Out_ V& PlaneSlice ) +{ + MipSlice = static_cast(Subresource % MipLevels); + ArraySlice = static_cast((Subresource / MipLevels) % ArraySize); + PlaneSlice = static_cast(Subresource / (MipLevels * ArraySize)); +} + +//------------------------------------------------------------------------------------------------ +inline UINT8 D3D12GetFormatPlaneCount( + _In_ ID3D12Device* pDevice, + DXGI_FORMAT Format + ) +{ + D3D12_FEATURE_DATA_FORMAT_INFO formatInfo = { Format, 0 }; + if (FAILED(pDevice->CheckFeatureSupport(D3D12_FEATURE_FORMAT_INFO, &formatInfo, sizeof(formatInfo)))) + { + return 0; + } + return formatInfo.PlaneCount; +} + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_RESOURCE_DESC : public D3D12_RESOURCE_DESC +{ + CD3DX12_RESOURCE_DESC() = default; + explicit CD3DX12_RESOURCE_DESC( const D3D12_RESOURCE_DESC& o ) : + D3D12_RESOURCE_DESC( o ) + {} + CD3DX12_RESOURCE_DESC( + D3D12_RESOURCE_DIMENSION dimension, + UINT64 alignment, + UINT64 width, + UINT height, + UINT16 depthOrArraySize, + UINT16 mipLevels, + DXGI_FORMAT format, + UINT sampleCount, + UINT sampleQuality, + D3D12_TEXTURE_LAYOUT layout, + D3D12_RESOURCE_FLAGS flags ) + { + Dimension = dimension; + Alignment = alignment; + Width = width; + Height = height; + DepthOrArraySize = depthOrArraySize; + MipLevels = mipLevels; + Format = format; + SampleDesc.Count = sampleCount; + SampleDesc.Quality = sampleQuality; + Layout = layout; + Flags = flags; + } + static inline CD3DX12_RESOURCE_DESC Buffer( + const D3D12_RESOURCE_ALLOCATION_INFO& resAllocInfo, + D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE ) + { + return CD3DX12_RESOURCE_DESC( D3D12_RESOURCE_DIMENSION_BUFFER, resAllocInfo.Alignment, resAllocInfo.SizeInBytes, + 1, 1, 1, DXGI_FORMAT_UNKNOWN, 1, 0, D3D12_TEXTURE_LAYOUT_ROW_MAJOR, flags ); + } + static inline CD3DX12_RESOURCE_DESC Buffer( + UINT64 width, + D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE, + UINT64 alignment = 0 ) + { + return CD3DX12_RESOURCE_DESC( D3D12_RESOURCE_DIMENSION_BUFFER, alignment, width, 1, 1, 1, + DXGI_FORMAT_UNKNOWN, 1, 0, D3D12_TEXTURE_LAYOUT_ROW_MAJOR, flags ); + } + static inline CD3DX12_RESOURCE_DESC Tex1D( + DXGI_FORMAT format, + UINT64 width, + UINT16 arraySize = 1, + UINT16 mipLevels = 0, + D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE, + D3D12_TEXTURE_LAYOUT layout = D3D12_TEXTURE_LAYOUT_UNKNOWN, + UINT64 alignment = 0 ) + { + return CD3DX12_RESOURCE_DESC( D3D12_RESOURCE_DIMENSION_TEXTURE1D, alignment, width, 1, arraySize, + mipLevels, format, 1, 0, layout, flags ); + } + static inline CD3DX12_RESOURCE_DESC Tex2D( + DXGI_FORMAT format, + UINT64 width, + UINT height, + UINT16 arraySize = 1, + UINT16 mipLevels = 0, + UINT sampleCount = 1, + UINT sampleQuality = 0, + D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE, + D3D12_TEXTURE_LAYOUT layout = D3D12_TEXTURE_LAYOUT_UNKNOWN, + UINT64 alignment = 0 ) + { + return CD3DX12_RESOURCE_DESC( D3D12_RESOURCE_DIMENSION_TEXTURE2D, alignment, width, height, arraySize, + mipLevels, format, sampleCount, sampleQuality, layout, flags ); + } + static inline CD3DX12_RESOURCE_DESC Tex3D( + DXGI_FORMAT format, + UINT64 width, + UINT height, + UINT16 depth, + UINT16 mipLevels = 0, + D3D12_RESOURCE_FLAGS flags = D3D12_RESOURCE_FLAG_NONE, + D3D12_TEXTURE_LAYOUT layout = D3D12_TEXTURE_LAYOUT_UNKNOWN, + UINT64 alignment = 0 ) + { + return CD3DX12_RESOURCE_DESC( D3D12_RESOURCE_DIMENSION_TEXTURE3D, alignment, width, height, depth, + mipLevels, format, 1, 0, layout, flags ); + } + inline UINT16 Depth() const + { return (Dimension == D3D12_RESOURCE_DIMENSION_TEXTURE3D ? DepthOrArraySize : 1); } + inline UINT16 ArraySize() const + { return (Dimension != D3D12_RESOURCE_DIMENSION_TEXTURE3D ? DepthOrArraySize : 1); } + inline UINT8 PlaneCount(_In_ ID3D12Device* pDevice) const + { return D3D12GetFormatPlaneCount(pDevice, Format); } + inline UINT Subresources(_In_ ID3D12Device* pDevice) const + { return MipLevels * ArraySize() * PlaneCount(pDevice); } + inline UINT CalcSubresource(UINT MipSlice, UINT ArraySlice, UINT PlaneSlice) + { return D3D12CalcSubresource(MipSlice, ArraySlice, PlaneSlice, MipLevels, ArraySize()); } +}; +inline bool operator==( const D3D12_RESOURCE_DESC& l, const D3D12_RESOURCE_DESC& r ) +{ + return l.Dimension == r.Dimension && + l.Alignment == r.Alignment && + l.Width == r.Width && + l.Height == r.Height && + l.DepthOrArraySize == r.DepthOrArraySize && + l.MipLevels == r.MipLevels && + l.Format == r.Format && + l.SampleDesc.Count == r.SampleDesc.Count && + l.SampleDesc.Quality == r.SampleDesc.Quality && + l.Layout == r.Layout && + l.Flags == r.Flags; +} +inline bool operator!=( const D3D12_RESOURCE_DESC& l, const D3D12_RESOURCE_DESC& r ) +{ return !( l == r ); } + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_VIEW_INSTANCING_DESC : public D3D12_VIEW_INSTANCING_DESC +{ + CD3DX12_VIEW_INSTANCING_DESC() = default; + explicit CD3DX12_VIEW_INSTANCING_DESC( const D3D12_VIEW_INSTANCING_DESC& o ) : + D3D12_VIEW_INSTANCING_DESC( o ) + {} + explicit CD3DX12_VIEW_INSTANCING_DESC( CD3DX12_DEFAULT ) + { + ViewInstanceCount = 0; + pViewInstanceLocations = nullptr; + Flags = D3D12_VIEW_INSTANCING_FLAG_NONE; + } + explicit CD3DX12_VIEW_INSTANCING_DESC( + UINT InViewInstanceCount, + const D3D12_VIEW_INSTANCE_LOCATION* InViewInstanceLocations, + D3D12_VIEW_INSTANCING_FLAGS InFlags) + { + ViewInstanceCount = InViewInstanceCount; + pViewInstanceLocations = InViewInstanceLocations; + Flags = InFlags; + } +}; + +//------------------------------------------------------------------------------------------------ +// Row-by-row memcpy +inline void MemcpySubresource( + _In_ const D3D12_MEMCPY_DEST* pDest, + _In_ const D3D12_SUBRESOURCE_DATA* pSrc, + SIZE_T RowSizeInBytes, + UINT NumRows, + UINT NumSlices) +{ + for (UINT z = 0; z < NumSlices; ++z) + { + BYTE* pDestSlice = reinterpret_cast(pDest->pData) + pDest->SlicePitch * z; + const BYTE* pSrcSlice = reinterpret_cast(pSrc->pData) + pSrc->SlicePitch * z; + for (UINT y = 0; y < NumRows; ++y) + { + memcpy(pDestSlice + pDest->RowPitch * y, + pSrcSlice + pSrc->RowPitch * y, + RowSizeInBytes); + } + } +} + +//------------------------------------------------------------------------------------------------ +// Returns required size of a buffer to be used for data upload +inline UINT64 GetRequiredIntermediateSize( + _In_ ID3D12Resource* pDestinationResource, + _In_range_(0,D3D12_REQ_SUBRESOURCES) UINT FirstSubresource, + _In_range_(0,D3D12_REQ_SUBRESOURCES-FirstSubresource) UINT NumSubresources) +{ + auto Desc = pDestinationResource->GetDesc(); + UINT64 RequiredSize = 0; + + ID3D12Device* pDevice = nullptr; + pDestinationResource->GetDevice(__uuidof(*pDevice), reinterpret_cast(&pDevice)); + pDevice->GetCopyableFootprints(&Desc, FirstSubresource, NumSubresources, 0, nullptr, nullptr, nullptr, &RequiredSize); + pDevice->Release(); + + return RequiredSize; +} + +//------------------------------------------------------------------------------------------------ +// All arrays must be populated (e.g. by calling GetCopyableFootprints) +inline UINT64 UpdateSubresources( + _In_ ID3D12GraphicsCommandList* pCmdList, + _In_ ID3D12Resource* pDestinationResource, + _In_ ID3D12Resource* pIntermediate, + _In_range_(0,D3D12_REQ_SUBRESOURCES) UINT FirstSubresource, + _In_range_(0,D3D12_REQ_SUBRESOURCES-FirstSubresource) UINT NumSubresources, + UINT64 RequiredSize, + _In_reads_(NumSubresources) const D3D12_PLACED_SUBRESOURCE_FOOTPRINT* pLayouts, + _In_reads_(NumSubresources) const UINT* pNumRows, + _In_reads_(NumSubresources) const UINT64* pRowSizesInBytes, + _In_reads_(NumSubresources) const D3D12_SUBRESOURCE_DATA* pSrcData) +{ + // Minor validation + auto IntermediateDesc = pIntermediate->GetDesc(); + auto DestinationDesc = pDestinationResource->GetDesc(); + if (IntermediateDesc.Dimension != D3D12_RESOURCE_DIMENSION_BUFFER || + IntermediateDesc.Width < RequiredSize + pLayouts[0].Offset || + RequiredSize > SIZE_T(-1) || + (DestinationDesc.Dimension == D3D12_RESOURCE_DIMENSION_BUFFER && + (FirstSubresource != 0 || NumSubresources != 1))) + { + return 0; + } + + BYTE* pData; + HRESULT hr = pIntermediate->Map(0, nullptr, reinterpret_cast(&pData)); + if (FAILED(hr)) + { + return 0; + } + + for (UINT i = 0; i < NumSubresources; ++i) + { + if (pRowSizesInBytes[i] > SIZE_T(-1)) return 0; + D3D12_MEMCPY_DEST DestData = { pData + pLayouts[i].Offset, pLayouts[i].Footprint.RowPitch, SIZE_T(pLayouts[i].Footprint.RowPitch) * SIZE_T(pNumRows[i]) }; + MemcpySubresource(&DestData, &pSrcData[i], static_cast(pRowSizesInBytes[i]), pNumRows[i], pLayouts[i].Footprint.Depth); + } + pIntermediate->Unmap(0, nullptr); + + if (DestinationDesc.Dimension == D3D12_RESOURCE_DIMENSION_BUFFER) + { + pCmdList->CopyBufferRegion( + pDestinationResource, 0, pIntermediate, pLayouts[0].Offset, pLayouts[0].Footprint.Width); + } + else + { + for (UINT i = 0; i < NumSubresources; ++i) + { + CD3DX12_TEXTURE_COPY_LOCATION Dst(pDestinationResource, i + FirstSubresource); + CD3DX12_TEXTURE_COPY_LOCATION Src(pIntermediate, pLayouts[i]); + pCmdList->CopyTextureRegion(&Dst, 0, 0, 0, &Src, nullptr); + } + } + return RequiredSize; +} + +//------------------------------------------------------------------------------------------------ +// Heap-allocating UpdateSubresources implementation +inline UINT64 UpdateSubresources( + _In_ ID3D12GraphicsCommandList* pCmdList, + _In_ ID3D12Resource* pDestinationResource, + _In_ ID3D12Resource* pIntermediate, + UINT64 IntermediateOffset, + _In_range_(0,D3D12_REQ_SUBRESOURCES) UINT FirstSubresource, + _In_range_(0,D3D12_REQ_SUBRESOURCES-FirstSubresource) UINT NumSubresources, + _In_reads_(NumSubresources) D3D12_SUBRESOURCE_DATA* pSrcData) +{ + UINT64 RequiredSize = 0; + UINT64 MemToAlloc = static_cast(sizeof(D3D12_PLACED_SUBRESOURCE_FOOTPRINT) + sizeof(UINT) + sizeof(UINT64)) * NumSubresources; + if (MemToAlloc > SIZE_MAX) + { + return 0; + } + void* pMem = HeapAlloc(GetProcessHeap(), 0, static_cast(MemToAlloc)); + if (pMem == nullptr) + { + return 0; + } + auto pLayouts = reinterpret_cast(pMem); + UINT64* pRowSizesInBytes = reinterpret_cast(pLayouts + NumSubresources); + UINT* pNumRows = reinterpret_cast(pRowSizesInBytes + NumSubresources); + + auto Desc = pDestinationResource->GetDesc(); + ID3D12Device* pDevice = nullptr; + pDestinationResource->GetDevice(__uuidof(*pDevice), reinterpret_cast(&pDevice)); + pDevice->GetCopyableFootprints(&Desc, FirstSubresource, NumSubresources, IntermediateOffset, pLayouts, pNumRows, pRowSizesInBytes, &RequiredSize); + pDevice->Release(); + + UINT64 Result = UpdateSubresources(pCmdList, pDestinationResource, pIntermediate, FirstSubresource, NumSubresources, RequiredSize, pLayouts, pNumRows, pRowSizesInBytes, pSrcData); + HeapFree(GetProcessHeap(), 0, pMem); + return Result; +} + +//------------------------------------------------------------------------------------------------ +// Stack-allocating UpdateSubresources implementation +template +inline UINT64 UpdateSubresources( + _In_ ID3D12GraphicsCommandList* pCmdList, + _In_ ID3D12Resource* pDestinationResource, + _In_ ID3D12Resource* pIntermediate, + UINT64 IntermediateOffset, + _In_range_(0, MaxSubresources) UINT FirstSubresource, + _In_range_(1, MaxSubresources - FirstSubresource) UINT NumSubresources, + _In_reads_(NumSubresources) D3D12_SUBRESOURCE_DATA* pSrcData) +{ + UINT64 RequiredSize = 0; + D3D12_PLACED_SUBRESOURCE_FOOTPRINT Layouts[MaxSubresources]; + UINT NumRows[MaxSubresources]; + UINT64 RowSizesInBytes[MaxSubresources]; + + auto Desc = pDestinationResource->GetDesc(); + ID3D12Device* pDevice = nullptr; + pDestinationResource->GetDevice(__uuidof(*pDevice), reinterpret_cast(&pDevice)); + pDevice->GetCopyableFootprints(&Desc, FirstSubresource, NumSubresources, IntermediateOffset, Layouts, NumRows, RowSizesInBytes, &RequiredSize); + pDevice->Release(); + + return UpdateSubresources(pCmdList, pDestinationResource, pIntermediate, FirstSubresource, NumSubresources, RequiredSize, Layouts, NumRows, RowSizesInBytes, pSrcData); +} + +//------------------------------------------------------------------------------------------------ +inline bool D3D12IsLayoutOpaque( D3D12_TEXTURE_LAYOUT Layout ) +{ return Layout == D3D12_TEXTURE_LAYOUT_UNKNOWN || Layout == D3D12_TEXTURE_LAYOUT_64KB_UNDEFINED_SWIZZLE; } + +//------------------------------------------------------------------------------------------------ +template +inline ID3D12CommandList * const * CommandListCast(t_CommandListType * const * pp) +{ + // This cast is useful for passing strongly typed command list pointers into + // ExecuteCommandLists. + // This cast is valid as long as the const-ness is respected. D3D12 APIs do + // respect the const-ness of their arguments. + return reinterpret_cast(pp); +} + +//------------------------------------------------------------------------------------------------ +// D3D12 exports a new method for serializing root signatures in the Windows 10 Anniversary Update. +// To help enable root signature 1.1 features when they are available and not require maintaining +// two code paths for building root signatures, this helper method reconstructs a 1.0 signature when +// 1.1 is not supported. +inline HRESULT D3DX12SerializeVersionedRootSignature( + _In_ const D3D12_VERSIONED_ROOT_SIGNATURE_DESC* pRootSignatureDesc, + D3D_ROOT_SIGNATURE_VERSION MaxVersion, + _Outptr_ ID3DBlob** ppBlob, + _Always_(_Outptr_opt_result_maybenull_) ID3DBlob** ppErrorBlob) +{ + if (ppErrorBlob != nullptr) + { + *ppErrorBlob = nullptr; + } + + switch (MaxVersion) + { + case D3D_ROOT_SIGNATURE_VERSION_1_0: + switch (pRootSignatureDesc->Version) + { + case D3D_ROOT_SIGNATURE_VERSION_1_0: + return D3D12SerializeRootSignature(&pRootSignatureDesc->Desc_1_0, D3D_ROOT_SIGNATURE_VERSION_1, ppBlob, ppErrorBlob); + + case D3D_ROOT_SIGNATURE_VERSION_1_1: + { + HRESULT hr = S_OK; + const D3D12_ROOT_SIGNATURE_DESC1& desc_1_1 = pRootSignatureDesc->Desc_1_1; + + const SIZE_T ParametersSize = sizeof(D3D12_ROOT_PARAMETER) * desc_1_1.NumParameters; + void* pParameters = (ParametersSize > 0) ? HeapAlloc(GetProcessHeap(), 0, ParametersSize) : nullptr; + if (ParametersSize > 0 && pParameters == nullptr) + { + hr = E_OUTOFMEMORY; + } + auto pParameters_1_0 = reinterpret_cast(pParameters); + + if (SUCCEEDED(hr)) + { + for (UINT n = 0; n < desc_1_1.NumParameters; n++) + { + __analysis_assume(ParametersSize == sizeof(D3D12_ROOT_PARAMETER) * desc_1_1.NumParameters); + pParameters_1_0[n].ParameterType = desc_1_1.pParameters[n].ParameterType; + pParameters_1_0[n].ShaderVisibility = desc_1_1.pParameters[n].ShaderVisibility; + + switch (desc_1_1.pParameters[n].ParameterType) + { + case D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS: + pParameters_1_0[n].Constants.Num32BitValues = desc_1_1.pParameters[n].Constants.Num32BitValues; + pParameters_1_0[n].Constants.RegisterSpace = desc_1_1.pParameters[n].Constants.RegisterSpace; + pParameters_1_0[n].Constants.ShaderRegister = desc_1_1.pParameters[n].Constants.ShaderRegister; + break; + + case D3D12_ROOT_PARAMETER_TYPE_CBV: + case D3D12_ROOT_PARAMETER_TYPE_SRV: + case D3D12_ROOT_PARAMETER_TYPE_UAV: + pParameters_1_0[n].Descriptor.RegisterSpace = desc_1_1.pParameters[n].Descriptor.RegisterSpace; + pParameters_1_0[n].Descriptor.ShaderRegister = desc_1_1.pParameters[n].Descriptor.ShaderRegister; + break; + + case D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE: + const D3D12_ROOT_DESCRIPTOR_TABLE1& table_1_1 = desc_1_1.pParameters[n].DescriptorTable; + + const SIZE_T DescriptorRangesSize = sizeof(D3D12_DESCRIPTOR_RANGE) * table_1_1.NumDescriptorRanges; + void* pDescriptorRanges = (DescriptorRangesSize > 0 && SUCCEEDED(hr)) ? HeapAlloc(GetProcessHeap(), 0, DescriptorRangesSize) : nullptr; + if (DescriptorRangesSize > 0 && pDescriptorRanges == nullptr) + { + hr = E_OUTOFMEMORY; + } + auto pDescriptorRanges_1_0 = reinterpret_cast(pDescriptorRanges); + + if (SUCCEEDED(hr)) + { + for (UINT x = 0; x < table_1_1.NumDescriptorRanges; x++) + { + __analysis_assume(DescriptorRangesSize == sizeof(D3D12_DESCRIPTOR_RANGE) * table_1_1.NumDescriptorRanges); + pDescriptorRanges_1_0[x].BaseShaderRegister = table_1_1.pDescriptorRanges[x].BaseShaderRegister; + pDescriptorRanges_1_0[x].NumDescriptors = table_1_1.pDescriptorRanges[x].NumDescriptors; + pDescriptorRanges_1_0[x].OffsetInDescriptorsFromTableStart = table_1_1.pDescriptorRanges[x].OffsetInDescriptorsFromTableStart; + pDescriptorRanges_1_0[x].RangeType = table_1_1.pDescriptorRanges[x].RangeType; + pDescriptorRanges_1_0[x].RegisterSpace = table_1_1.pDescriptorRanges[x].RegisterSpace; + } + } + + D3D12_ROOT_DESCRIPTOR_TABLE& table_1_0 = pParameters_1_0[n].DescriptorTable; + table_1_0.NumDescriptorRanges = table_1_1.NumDescriptorRanges; + table_1_0.pDescriptorRanges = pDescriptorRanges_1_0; + } + } + } + + if (SUCCEEDED(hr)) + { + CD3DX12_ROOT_SIGNATURE_DESC desc_1_0(desc_1_1.NumParameters, pParameters_1_0, desc_1_1.NumStaticSamplers, desc_1_1.pStaticSamplers, desc_1_1.Flags); + hr = D3D12SerializeRootSignature(&desc_1_0, D3D_ROOT_SIGNATURE_VERSION_1, ppBlob, ppErrorBlob); + } + + if (pParameters) + { + for (UINT n = 0; n < desc_1_1.NumParameters; n++) + { + if (desc_1_1.pParameters[n].ParameterType == D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE) + { + HeapFree(GetProcessHeap(), 0, reinterpret_cast(const_cast(pParameters_1_0[n].DescriptorTable.pDescriptorRanges))); + } + } + HeapFree(GetProcessHeap(), 0, pParameters); + } + return hr; + } + } + break; + + case D3D_ROOT_SIGNATURE_VERSION_1_1: + return D3D12SerializeVersionedRootSignature(pRootSignatureDesc, ppBlob, ppErrorBlob); + } + + return E_INVALIDARG; +} + +//------------------------------------------------------------------------------------------------ +struct CD3DX12_RT_FORMAT_ARRAY : public D3D12_RT_FORMAT_ARRAY +{ + CD3DX12_RT_FORMAT_ARRAY() = default; + explicit CD3DX12_RT_FORMAT_ARRAY(const D3D12_RT_FORMAT_ARRAY& o) + : D3D12_RT_FORMAT_ARRAY(o) + {} + explicit CD3DX12_RT_FORMAT_ARRAY(_In_reads_(NumFormats) const DXGI_FORMAT* pFormats, UINT NumFormats) + { + NumRenderTargets = NumFormats; + memcpy(RTFormats, pFormats, sizeof(RTFormats)); + // assumes ARRAY_SIZE(pFormats) == ARRAY_SIZE(RTFormats) + } +}; + +//------------------------------------------------------------------------------------------------ +// Pipeline State Stream Helpers +//------------------------------------------------------------------------------------------------ + +//------------------------------------------------------------------------------------------------ +// Stream Subobjects, i.e. elements of a stream + +struct DefaultSampleMask { operator UINT() { return UINT_MAX; } }; +struct DefaultSampleDesc { operator DXGI_SAMPLE_DESC() { return DXGI_SAMPLE_DESC{1, 0}; } }; + +#pragma warning(push) +#pragma warning(disable : 4324) +template +class alignas(void*) CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT +{ +private: + D3D12_PIPELINE_STATE_SUBOBJECT_TYPE _Type; + InnerStructType _Inner; +public: + CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT() noexcept : _Type(Type), _Inner(DefaultArg()) {} + CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT(InnerStructType const& i) : _Type(Type), _Inner(i) {} + CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT& operator=(InnerStructType const& i) { _Inner = i; return *this; } + operator InnerStructType() const { return _Inner; } + operator InnerStructType&() { return _Inner; } +}; +#pragma warning(pop) +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_PIPELINE_STATE_FLAGS, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_FLAGS> CD3DX12_PIPELINE_STATE_STREAM_FLAGS; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< UINT, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_NODE_MASK> CD3DX12_PIPELINE_STATE_STREAM_NODE_MASK; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< ID3D12RootSignature*, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_ROOT_SIGNATURE> CD3DX12_PIPELINE_STATE_STREAM_ROOT_SIGNATURE; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_INPUT_LAYOUT_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_INPUT_LAYOUT> CD3DX12_PIPELINE_STATE_STREAM_INPUT_LAYOUT; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_INDEX_BUFFER_STRIP_CUT_VALUE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_IB_STRIP_CUT_VALUE> CD3DX12_PIPELINE_STATE_STREAM_IB_STRIP_CUT_VALUE; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_PRIMITIVE_TOPOLOGY_TYPE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PRIMITIVE_TOPOLOGY> CD3DX12_PIPELINE_STATE_STREAM_PRIMITIVE_TOPOLOGY; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VS> CD3DX12_PIPELINE_STATE_STREAM_VS; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_GS> CD3DX12_PIPELINE_STATE_STREAM_GS; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_STREAM_OUTPUT_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_STREAM_OUTPUT> CD3DX12_PIPELINE_STATE_STREAM_STREAM_OUTPUT; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_HS> CD3DX12_PIPELINE_STATE_STREAM_HS; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DS> CD3DX12_PIPELINE_STATE_STREAM_DS; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PS> CD3DX12_PIPELINE_STATE_STREAM_PS; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_SHADER_BYTECODE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CS> CD3DX12_PIPELINE_STATE_STREAM_CS; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< CD3DX12_BLEND_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_BLEND, CD3DX12_DEFAULT> CD3DX12_PIPELINE_STATE_STREAM_BLEND_DESC; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< CD3DX12_DEPTH_STENCIL_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL, CD3DX12_DEFAULT> CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< CD3DX12_DEPTH_STENCIL_DESC1, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL1, CD3DX12_DEFAULT> CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL1; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< DXGI_FORMAT, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL_FORMAT> CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL_FORMAT; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< CD3DX12_RASTERIZER_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER, CD3DX12_DEFAULT> CD3DX12_PIPELINE_STATE_STREAM_RASTERIZER; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_RT_FORMAT_ARRAY, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RENDER_TARGET_FORMATS> CD3DX12_PIPELINE_STATE_STREAM_RENDER_TARGET_FORMATS; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< DXGI_SAMPLE_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_DESC, DefaultSampleDesc> CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_DESC; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< UINT, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_MASK, DefaultSampleMask> CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_MASK; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< D3D12_CACHED_PIPELINE_STATE, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CACHED_PSO> CD3DX12_PIPELINE_STATE_STREAM_CACHED_PSO; +typedef CD3DX12_PIPELINE_STATE_STREAM_SUBOBJECT< CD3DX12_VIEW_INSTANCING_DESC, D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VIEW_INSTANCING, CD3DX12_DEFAULT> CD3DX12_PIPELINE_STATE_STREAM_VIEW_INSTANCING; + +//------------------------------------------------------------------------------------------------ +// Stream Parser Helpers + +struct ID3DX12PipelineParserCallbacks +{ + // Subobject Callbacks + virtual void FlagsCb(D3D12_PIPELINE_STATE_FLAGS) {} + virtual void NodeMaskCb(UINT) {} + virtual void RootSignatureCb(ID3D12RootSignature*) {} + virtual void InputLayoutCb(const D3D12_INPUT_LAYOUT_DESC&) {} + virtual void IBStripCutValueCb(D3D12_INDEX_BUFFER_STRIP_CUT_VALUE) {} + virtual void PrimitiveTopologyTypeCb(D3D12_PRIMITIVE_TOPOLOGY_TYPE) {} + virtual void VSCb(const D3D12_SHADER_BYTECODE&) {} + virtual void GSCb(const D3D12_SHADER_BYTECODE&) {} + virtual void StreamOutputCb(const D3D12_STREAM_OUTPUT_DESC&) {} + virtual void HSCb(const D3D12_SHADER_BYTECODE&) {} + virtual void DSCb(const D3D12_SHADER_BYTECODE&) {} + virtual void PSCb(const D3D12_SHADER_BYTECODE&) {} + virtual void CSCb(const D3D12_SHADER_BYTECODE&) {} + virtual void BlendStateCb(const D3D12_BLEND_DESC&) {} + virtual void DepthStencilStateCb(const D3D12_DEPTH_STENCIL_DESC&) {} + virtual void DepthStencilState1Cb(const D3D12_DEPTH_STENCIL_DESC1&) {} + virtual void DSVFormatCb(DXGI_FORMAT) {} + virtual void RasterizerStateCb(const D3D12_RASTERIZER_DESC&) {} + virtual void RTVFormatsCb(const D3D12_RT_FORMAT_ARRAY&) {} + virtual void SampleDescCb(const DXGI_SAMPLE_DESC&) {} + virtual void SampleMaskCb(UINT) {} + virtual void ViewInstancingCb(const D3D12_VIEW_INSTANCING_DESC&) {} + virtual void CachedPSOCb(const D3D12_CACHED_PIPELINE_STATE&) {} + + // Error Callbacks + virtual void ErrorBadInputParameter(UINT /*ParameterIndex*/) {} + virtual void ErrorDuplicateSubobject(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE /*DuplicateType*/) {} + virtual void ErrorUnknownSubobject(UINT /*UnknownTypeValue*/) {} + + virtual ~ID3DX12PipelineParserCallbacks() = default; +}; + +// CD3DX12_PIPELINE_STATE_STREAM1 Works on RS3+ (where there is a new view instancing subobject). +// Use CD3DX12_PIPELINE_STATE_STREAM for RS2+ support. +struct CD3DX12_PIPELINE_STATE_STREAM1 +{ + CD3DX12_PIPELINE_STATE_STREAM1() = default; + CD3DX12_PIPELINE_STATE_STREAM1(const D3D12_GRAPHICS_PIPELINE_STATE_DESC& Desc) + : Flags(Desc.Flags) + , NodeMask(Desc.NodeMask) + , pRootSignature(Desc.pRootSignature) + , InputLayout(Desc.InputLayout) + , IBStripCutValue(Desc.IBStripCutValue) + , PrimitiveTopologyType(Desc.PrimitiveTopologyType) + , VS(Desc.VS) + , GS(Desc.GS) + , StreamOutput(Desc.StreamOutput) + , HS(Desc.HS) + , DS(Desc.DS) + , PS(Desc.PS) + , BlendState(CD3DX12_BLEND_DESC(Desc.BlendState)) + , DepthStencilState(CD3DX12_DEPTH_STENCIL_DESC1(Desc.DepthStencilState)) + , DSVFormat(Desc.DSVFormat) + , RasterizerState(CD3DX12_RASTERIZER_DESC(Desc.RasterizerState)) + , RTVFormats(CD3DX12_RT_FORMAT_ARRAY(Desc.RTVFormats, Desc.NumRenderTargets)) + , SampleDesc(Desc.SampleDesc) + , SampleMask(Desc.SampleMask) + , CachedPSO(Desc.CachedPSO) + , ViewInstancingDesc(CD3DX12_VIEW_INSTANCING_DESC(CD3DX12_DEFAULT())) + {} + CD3DX12_PIPELINE_STATE_STREAM1(const D3D12_COMPUTE_PIPELINE_STATE_DESC& Desc) + : Flags(Desc.Flags) + , NodeMask(Desc.NodeMask) + , pRootSignature(Desc.pRootSignature) + , CS(CD3DX12_SHADER_BYTECODE(Desc.CS)) + , CachedPSO(Desc.CachedPSO) + { + static_cast(DepthStencilState).DepthEnable = false; + } + CD3DX12_PIPELINE_STATE_STREAM_FLAGS Flags; + CD3DX12_PIPELINE_STATE_STREAM_NODE_MASK NodeMask; + CD3DX12_PIPELINE_STATE_STREAM_ROOT_SIGNATURE pRootSignature; + CD3DX12_PIPELINE_STATE_STREAM_INPUT_LAYOUT InputLayout; + CD3DX12_PIPELINE_STATE_STREAM_IB_STRIP_CUT_VALUE IBStripCutValue; + CD3DX12_PIPELINE_STATE_STREAM_PRIMITIVE_TOPOLOGY PrimitiveTopologyType; + CD3DX12_PIPELINE_STATE_STREAM_VS VS; + CD3DX12_PIPELINE_STATE_STREAM_GS GS; + CD3DX12_PIPELINE_STATE_STREAM_STREAM_OUTPUT StreamOutput; + CD3DX12_PIPELINE_STATE_STREAM_HS HS; + CD3DX12_PIPELINE_STATE_STREAM_DS DS; + CD3DX12_PIPELINE_STATE_STREAM_PS PS; + CD3DX12_PIPELINE_STATE_STREAM_CS CS; + CD3DX12_PIPELINE_STATE_STREAM_BLEND_DESC BlendState; + CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL1 DepthStencilState; + CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL_FORMAT DSVFormat; + CD3DX12_PIPELINE_STATE_STREAM_RASTERIZER RasterizerState; + CD3DX12_PIPELINE_STATE_STREAM_RENDER_TARGET_FORMATS RTVFormats; + CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_DESC SampleDesc; + CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_MASK SampleMask; + CD3DX12_PIPELINE_STATE_STREAM_CACHED_PSO CachedPSO; + CD3DX12_PIPELINE_STATE_STREAM_VIEW_INSTANCING ViewInstancingDesc; + D3D12_GRAPHICS_PIPELINE_STATE_DESC GraphicsDescV0() const + { + D3D12_GRAPHICS_PIPELINE_STATE_DESC D; + D.Flags = this->Flags; + D.NodeMask = this->NodeMask; + D.pRootSignature = this->pRootSignature; + D.InputLayout = this->InputLayout; + D.IBStripCutValue = this->IBStripCutValue; + D.PrimitiveTopologyType = this->PrimitiveTopologyType; + D.VS = this->VS; + D.GS = this->GS; + D.StreamOutput = this->StreamOutput; + D.HS = this->HS; + D.DS = this->DS; + D.PS = this->PS; + D.BlendState = this->BlendState; + D.DepthStencilState = CD3DX12_DEPTH_STENCIL_DESC1(D3D12_DEPTH_STENCIL_DESC1(this->DepthStencilState)); + D.DSVFormat = this->DSVFormat; + D.RasterizerState = this->RasterizerState; + D.NumRenderTargets = D3D12_RT_FORMAT_ARRAY(this->RTVFormats).NumRenderTargets; + memcpy(D.RTVFormats, D3D12_RT_FORMAT_ARRAY(this->RTVFormats).RTFormats, sizeof(D.RTVFormats)); + D.SampleDesc = this->SampleDesc; + D.SampleMask = this->SampleMask; + D.CachedPSO = this->CachedPSO; + return D; + } + D3D12_COMPUTE_PIPELINE_STATE_DESC ComputeDescV0() const + { + D3D12_COMPUTE_PIPELINE_STATE_DESC D; + D.Flags = this->Flags; + D.NodeMask = this->NodeMask; + D.pRootSignature = this->pRootSignature; + D.CS = this->CS; + D.CachedPSO = this->CachedPSO; + return D; + } +}; + +// CD3DX12_PIPELINE_STATE_STREAM works on RS2+ but does not support new subobject(s) added in RS3+. +// See CD3DX12_PIPELINE_STATE_STREAM1 for instance. +struct CD3DX12_PIPELINE_STATE_STREAM +{ + CD3DX12_PIPELINE_STATE_STREAM() = default; + CD3DX12_PIPELINE_STATE_STREAM(const D3D12_GRAPHICS_PIPELINE_STATE_DESC& Desc) + : Flags(Desc.Flags) + , NodeMask(Desc.NodeMask) + , pRootSignature(Desc.pRootSignature) + , InputLayout(Desc.InputLayout) + , IBStripCutValue(Desc.IBStripCutValue) + , PrimitiveTopologyType(Desc.PrimitiveTopologyType) + , VS(Desc.VS) + , GS(Desc.GS) + , StreamOutput(Desc.StreamOutput) + , HS(Desc.HS) + , DS(Desc.DS) + , PS(Desc.PS) + , BlendState(CD3DX12_BLEND_DESC(Desc.BlendState)) + , DepthStencilState(CD3DX12_DEPTH_STENCIL_DESC1(Desc.DepthStencilState)) + , DSVFormat(Desc.DSVFormat) + , RasterizerState(CD3DX12_RASTERIZER_DESC(Desc.RasterizerState)) + , RTVFormats(CD3DX12_RT_FORMAT_ARRAY(Desc.RTVFormats, Desc.NumRenderTargets)) + , SampleDesc(Desc.SampleDesc) + , SampleMask(Desc.SampleMask) + , CachedPSO(Desc.CachedPSO) + {} + CD3DX12_PIPELINE_STATE_STREAM(const D3D12_COMPUTE_PIPELINE_STATE_DESC& Desc) + : Flags(Desc.Flags) + , NodeMask(Desc.NodeMask) + , pRootSignature(Desc.pRootSignature) + , CS(CD3DX12_SHADER_BYTECODE(Desc.CS)) + , CachedPSO(Desc.CachedPSO) + {} + CD3DX12_PIPELINE_STATE_STREAM_FLAGS Flags; + CD3DX12_PIPELINE_STATE_STREAM_NODE_MASK NodeMask; + CD3DX12_PIPELINE_STATE_STREAM_ROOT_SIGNATURE pRootSignature; + CD3DX12_PIPELINE_STATE_STREAM_INPUT_LAYOUT InputLayout; + CD3DX12_PIPELINE_STATE_STREAM_IB_STRIP_CUT_VALUE IBStripCutValue; + CD3DX12_PIPELINE_STATE_STREAM_PRIMITIVE_TOPOLOGY PrimitiveTopologyType; + CD3DX12_PIPELINE_STATE_STREAM_VS VS; + CD3DX12_PIPELINE_STATE_STREAM_GS GS; + CD3DX12_PIPELINE_STATE_STREAM_STREAM_OUTPUT StreamOutput; + CD3DX12_PIPELINE_STATE_STREAM_HS HS; + CD3DX12_PIPELINE_STATE_STREAM_DS DS; + CD3DX12_PIPELINE_STATE_STREAM_PS PS; + CD3DX12_PIPELINE_STATE_STREAM_CS CS; + CD3DX12_PIPELINE_STATE_STREAM_BLEND_DESC BlendState; + CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL1 DepthStencilState; + CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL_FORMAT DSVFormat; + CD3DX12_PIPELINE_STATE_STREAM_RASTERIZER RasterizerState; + CD3DX12_PIPELINE_STATE_STREAM_RENDER_TARGET_FORMATS RTVFormats; + CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_DESC SampleDesc; + CD3DX12_PIPELINE_STATE_STREAM_SAMPLE_MASK SampleMask; + CD3DX12_PIPELINE_STATE_STREAM_CACHED_PSO CachedPSO; + D3D12_GRAPHICS_PIPELINE_STATE_DESC GraphicsDescV0() const + { + D3D12_GRAPHICS_PIPELINE_STATE_DESC D; + D.Flags = this->Flags; + D.NodeMask = this->NodeMask; + D.pRootSignature = this->pRootSignature; + D.InputLayout = this->InputLayout; + D.IBStripCutValue = this->IBStripCutValue; + D.PrimitiveTopologyType = this->PrimitiveTopologyType; + D.VS = this->VS; + D.GS = this->GS; + D.StreamOutput = this->StreamOutput; + D.HS = this->HS; + D.DS = this->DS; + D.PS = this->PS; + D.BlendState = this->BlendState; + D.DepthStencilState = CD3DX12_DEPTH_STENCIL_DESC1(D3D12_DEPTH_STENCIL_DESC1(this->DepthStencilState)); + D.DSVFormat = this->DSVFormat; + D.RasterizerState = this->RasterizerState; + D.NumRenderTargets = D3D12_RT_FORMAT_ARRAY(this->RTVFormats).NumRenderTargets; + memcpy(D.RTVFormats, D3D12_RT_FORMAT_ARRAY(this->RTVFormats).RTFormats, sizeof(D.RTVFormats)); + D.SampleDesc = this->SampleDesc; + D.SampleMask = this->SampleMask; + D.CachedPSO = this->CachedPSO; + return D; + } + D3D12_COMPUTE_PIPELINE_STATE_DESC ComputeDescV0() const + { + D3D12_COMPUTE_PIPELINE_STATE_DESC D; + D.Flags = this->Flags; + D.NodeMask = this->NodeMask; + D.pRootSignature = this->pRootSignature; + D.CS = this->CS; + D.CachedPSO = this->CachedPSO; + return D; + } +}; + +struct CD3DX12_PIPELINE_STATE_STREAM_PARSE_HELPER : public ID3DX12PipelineParserCallbacks +{ + CD3DX12_PIPELINE_STATE_STREAM1 PipelineStream; + CD3DX12_PIPELINE_STATE_STREAM_PARSE_HELPER() noexcept + : SeenDSS(false) + { + // Adjust defaults to account for absent members. + PipelineStream.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE; + + // Depth disabled if no DSV format specified. + static_cast(PipelineStream.DepthStencilState).DepthEnable = false; + } + + // ID3DX12PipelineParserCallbacks + void FlagsCb(D3D12_PIPELINE_STATE_FLAGS Flags) override {PipelineStream.Flags = Flags;} + void NodeMaskCb(UINT NodeMask) override {PipelineStream.NodeMask = NodeMask;} + void RootSignatureCb(ID3D12RootSignature* pRootSignature) override {PipelineStream.pRootSignature = pRootSignature;} + void InputLayoutCb(const D3D12_INPUT_LAYOUT_DESC& InputLayout) override {PipelineStream.InputLayout = InputLayout;} + void IBStripCutValueCb(D3D12_INDEX_BUFFER_STRIP_CUT_VALUE IBStripCutValue) override {PipelineStream.IBStripCutValue = IBStripCutValue;} + void PrimitiveTopologyTypeCb(D3D12_PRIMITIVE_TOPOLOGY_TYPE PrimitiveTopologyType) override {PipelineStream.PrimitiveTopologyType = PrimitiveTopologyType;} + void VSCb(const D3D12_SHADER_BYTECODE& VS) override {PipelineStream.VS = VS;} + void GSCb(const D3D12_SHADER_BYTECODE& GS) override {PipelineStream.GS = GS;} + void StreamOutputCb(const D3D12_STREAM_OUTPUT_DESC& StreamOutput) override {PipelineStream.StreamOutput = StreamOutput;} + void HSCb(const D3D12_SHADER_BYTECODE& HS) override {PipelineStream.HS = HS;} + void DSCb(const D3D12_SHADER_BYTECODE& DS) override {PipelineStream.DS = DS;} + void PSCb(const D3D12_SHADER_BYTECODE& PS) override {PipelineStream.PS = PS;} + void CSCb(const D3D12_SHADER_BYTECODE& CS) override {PipelineStream.CS = CS;} + void BlendStateCb(const D3D12_BLEND_DESC& BlendState) override {PipelineStream.BlendState = CD3DX12_BLEND_DESC(BlendState);} + void DepthStencilStateCb(const D3D12_DEPTH_STENCIL_DESC& DepthStencilState) override + { + PipelineStream.DepthStencilState = CD3DX12_DEPTH_STENCIL_DESC1(DepthStencilState); + SeenDSS = true; + } + void DepthStencilState1Cb(const D3D12_DEPTH_STENCIL_DESC1& DepthStencilState) override + { + PipelineStream.DepthStencilState = CD3DX12_DEPTH_STENCIL_DESC1(DepthStencilState); + SeenDSS = true; + } + void DSVFormatCb(DXGI_FORMAT DSVFormat) override + { + PipelineStream.DSVFormat = DSVFormat; + if (!SeenDSS && DSVFormat != DXGI_FORMAT_UNKNOWN) + { + // Re-enable depth for the default state. + static_cast(PipelineStream.DepthStencilState).DepthEnable = true; + } + } + void RasterizerStateCb(const D3D12_RASTERIZER_DESC& RasterizerState) override {PipelineStream.RasterizerState = CD3DX12_RASTERIZER_DESC(RasterizerState);} + void RTVFormatsCb(const D3D12_RT_FORMAT_ARRAY& RTVFormats) override {PipelineStream.RTVFormats = RTVFormats;} + void SampleDescCb(const DXGI_SAMPLE_DESC& SampleDesc) override {PipelineStream.SampleDesc = SampleDesc;} + void SampleMaskCb(UINT SampleMask) override {PipelineStream.SampleMask = SampleMask;} + void ViewInstancingCb(const D3D12_VIEW_INSTANCING_DESC& ViewInstancingDesc) override {PipelineStream.ViewInstancingDesc = CD3DX12_VIEW_INSTANCING_DESC(ViewInstancingDesc);} + void CachedPSOCb(const D3D12_CACHED_PIPELINE_STATE& CachedPSO) override {PipelineStream.CachedPSO = CachedPSO;} + +private: + bool SeenDSS; +}; + +inline D3D12_PIPELINE_STATE_SUBOBJECT_TYPE D3DX12GetBaseSubobjectType(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE SubobjectType) +{ + switch (SubobjectType) + { + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL1: + return D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL; + default: + return SubobjectType; + } +} + +inline HRESULT D3DX12ParsePipelineStream(const D3D12_PIPELINE_STATE_STREAM_DESC& Desc, ID3DX12PipelineParserCallbacks* pCallbacks) +{ + if (pCallbacks == nullptr) + { + return E_INVALIDARG; + } + + if (Desc.SizeInBytes == 0 || Desc.pPipelineStateSubobjectStream == nullptr) + { + pCallbacks->ErrorBadInputParameter(1); // first parameter issue + return E_INVALIDARG; + } + + bool SubobjectSeen[D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_MAX_VALID] = {}; + for (SIZE_T CurOffset = 0, SizeOfSubobject = 0; CurOffset < Desc.SizeInBytes; CurOffset += SizeOfSubobject) + { + BYTE* pStream = static_cast(Desc.pPipelineStateSubobjectStream)+CurOffset; + auto SubobjectType = *reinterpret_cast(pStream); + if (SubobjectType < 0 || SubobjectType >= D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_MAX_VALID) + { + pCallbacks->ErrorUnknownSubobject(SubobjectType); + return E_INVALIDARG; + } + if (SubobjectSeen[D3DX12GetBaseSubobjectType(SubobjectType)]) + { + pCallbacks->ErrorDuplicateSubobject(SubobjectType); + return E_INVALIDARG; // disallow subobject duplicates in a stream + } + SubobjectSeen[SubobjectType] = true; + switch (SubobjectType) + { + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_ROOT_SIGNATURE: + pCallbacks->RootSignatureCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::pRootSignature); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VS: + pCallbacks->VSCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::VS); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PS: + pCallbacks->PSCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::PS); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DS: + pCallbacks->DSCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::DS); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_HS: + pCallbacks->HSCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::HS); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_GS: + pCallbacks->GSCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::GS); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CS: + pCallbacks->CSCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::CS); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_STREAM_OUTPUT: + pCallbacks->StreamOutputCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::StreamOutput); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_BLEND: + pCallbacks->BlendStateCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::BlendState); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_MASK: + pCallbacks->SampleMaskCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::SampleMask); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER: + pCallbacks->RasterizerStateCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::RasterizerState); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL: + pCallbacks->DepthStencilStateCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM_DEPTH_STENCIL); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL1: + pCallbacks->DepthStencilState1Cb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::DepthStencilState); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_INPUT_LAYOUT: + pCallbacks->InputLayoutCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::InputLayout); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_IB_STRIP_CUT_VALUE: + pCallbacks->IBStripCutValueCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::IBStripCutValue); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PRIMITIVE_TOPOLOGY: + pCallbacks->PrimitiveTopologyTypeCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::PrimitiveTopologyType); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RENDER_TARGET_FORMATS: + pCallbacks->RTVFormatsCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::RTVFormats); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL_FORMAT: + pCallbacks->DSVFormatCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::DSVFormat); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_DESC: + pCallbacks->SampleDescCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::SampleDesc); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_NODE_MASK: + pCallbacks->NodeMaskCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::NodeMask); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CACHED_PSO: + pCallbacks->CachedPSOCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::CachedPSO); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_FLAGS: + pCallbacks->FlagsCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM::Flags); + break; + case D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VIEW_INSTANCING: + pCallbacks->ViewInstancingCb(*reinterpret_cast(pStream)); + SizeOfSubobject = sizeof(CD3DX12_PIPELINE_STATE_STREAM1::ViewInstancingDesc); + break; + default: + pCallbacks->ErrorUnknownSubobject(SubobjectType); + return E_INVALIDARG; + break; + } + } + + return S_OK; +} + +//------------------------------------------------------------------------------------------------ +inline bool operator==( const D3D12_CLEAR_VALUE &a, const D3D12_CLEAR_VALUE &b) +{ + if (a.Format != b.Format) return false; + if (a.Format == DXGI_FORMAT_D24_UNORM_S8_UINT + || a.Format == DXGI_FORMAT_D16_UNORM + || a.Format == DXGI_FORMAT_D32_FLOAT + || a.Format == DXGI_FORMAT_D32_FLOAT_S8X24_UINT) + { + return (a.DepthStencil.Depth == b.DepthStencil.Depth) && + (a.DepthStencil.Stencil == b.DepthStencil.Stencil); + } else { + return (a.Color[0] == b.Color[0]) && + (a.Color[1] == b.Color[1]) && + (a.Color[2] == b.Color[2]) && + (a.Color[3] == b.Color[3]); + } +} +inline bool operator==( const D3D12_RENDER_PASS_BEGINNING_ACCESS_CLEAR_PARAMETERS &a, const D3D12_RENDER_PASS_BEGINNING_ACCESS_CLEAR_PARAMETERS &b) +{ + return a.ClearValue == b.ClearValue; +} +inline bool operator==( const D3D12_RENDER_PASS_ENDING_ACCESS_RESOLVE_PARAMETERS &a, const D3D12_RENDER_PASS_ENDING_ACCESS_RESOLVE_PARAMETERS &b) +{ + if (a.pSrcResource != b.pSrcResource) return false; + if (a.pDstResource != b.pDstResource) return false; + if (a.SubresourceCount != b.SubresourceCount) return false; + if (a.Format != b.Format) return false; + if (a.ResolveMode != b.ResolveMode) return false; + if (a.PreserveResolveSource != b.PreserveResolveSource) return false; + return true; +} +inline bool operator==( const D3D12_RENDER_PASS_BEGINNING_ACCESS &a, const D3D12_RENDER_PASS_BEGINNING_ACCESS &b) +{ + if (a.Type != b.Type) return false; + if (a.Type == D3D12_RENDER_PASS_BEGINNING_ACCESS_TYPE_CLEAR && !(a.Clear == b.Clear)) return false; + return true; +} +inline bool operator==( const D3D12_RENDER_PASS_ENDING_ACCESS &a, const D3D12_RENDER_PASS_ENDING_ACCESS &b) +{ + if (a.Type != b.Type) return false; + if (a.Type == D3D12_RENDER_PASS_ENDING_ACCESS_TYPE_RESOLVE && !(a.Resolve == b.Resolve)) return false; + return true; +} +inline bool operator==( const D3D12_RENDER_PASS_RENDER_TARGET_DESC &a, const D3D12_RENDER_PASS_RENDER_TARGET_DESC &b) +{ + if (a.cpuDescriptor.ptr != b.cpuDescriptor.ptr) return false; + if (!(a.BeginningAccess == b.BeginningAccess)) return false; + if (!(a.EndingAccess == b.EndingAccess)) return false; + return true; +} +inline bool operator==( const D3D12_RENDER_PASS_DEPTH_STENCIL_DESC &a, const D3D12_RENDER_PASS_DEPTH_STENCIL_DESC &b) +{ + if (a.cpuDescriptor.ptr != b.cpuDescriptor.ptr) return false; + if (!(a.DepthBeginningAccess == b.DepthBeginningAccess)) return false; + if (!(a.StencilBeginningAccess == b.StencilBeginningAccess)) return false; + if (!(a.DepthEndingAccess == b.DepthEndingAccess)) return false; + if (!(a.StencilEndingAccess == b.StencilEndingAccess)) return false; + return true; +} + + +#ifndef D3DX12_NO_STATE_OBJECT_HELPERS + +//================================================================================================ +// D3DX12 State Object Creation Helpers +// +// Helper classes for creating new style state objects out of an arbitrary set of subobjects. +// Uses STL +// +// Start by instantiating CD3DX12_STATE_OBJECT_DESC (see it's public methods). +// One of its methods is CreateSubobject(), which has a comment showing a couple of options for +// defining subobjects using the helper classes for each subobject (CD3DX12_DXIL_LIBRARY_SUBOBJECT +// etc.). The subobject helpers each have methods specific to the subobject for configuring it's +// contents. +// +//================================================================================================ +#include +#include +#include +#include +#include + +//------------------------------------------------------------------------------------------------ +class CD3DX12_STATE_OBJECT_DESC +{ +public: + CD3DX12_STATE_OBJECT_DESC() + { + Init(D3D12_STATE_OBJECT_TYPE_COLLECTION); + } + CD3DX12_STATE_OBJECT_DESC(D3D12_STATE_OBJECT_TYPE Type) + { + Init(Type); + } + void SetStateObjectType(D3D12_STATE_OBJECT_TYPE Type) { m_Desc.Type = Type; } + operator const D3D12_STATE_OBJECT_DESC&() + { + // Do final preparation work + m_RepointedAssociations.clear(); + m_SubobjectArray.clear(); + m_SubobjectArray.reserve(m_Desc.NumSubobjects); + // Flatten subobjects into an array (each flattened subobject still has a + // member that's a pointer to it's desc that's not flattened) + for (auto Iter = m_SubobjectList.begin(); + Iter != m_SubobjectList.end(); Iter++) + { + m_SubobjectArray.push_back(*Iter); + // Store new location in array so we can redirect pointers contained in subobjects + Iter->pSubobjectArrayLocation = &m_SubobjectArray.back(); + } + // For subobjects with pointer fields, create a new copy of those subobject definitions + // with fixed pointers + for (UINT i = 0; i < m_Desc.NumSubobjects; i++) + { + if (m_SubobjectArray[i].Type == D3D12_STATE_SUBOBJECT_TYPE_SUBOBJECT_TO_EXPORTS_ASSOCIATION) + { + auto pOriginalSubobjectAssociation = + reinterpret_cast(m_SubobjectArray[i].pDesc); + D3D12_SUBOBJECT_TO_EXPORTS_ASSOCIATION Repointed = *pOriginalSubobjectAssociation; + auto pWrapper = + static_cast(pOriginalSubobjectAssociation->pSubobjectToAssociate); + Repointed.pSubobjectToAssociate = pWrapper->pSubobjectArrayLocation; + m_RepointedAssociations.push_back(Repointed); + m_SubobjectArray[i].pDesc = &m_RepointedAssociations.back(); + } + } + // Below: using ugly way to get pointer in case .data() is not defined + m_Desc.pSubobjects = m_Desc.NumSubobjects ? &m_SubobjectArray[0] : nullptr; + return m_Desc; + } + operator const D3D12_STATE_OBJECT_DESC*() + { + // Cast calls the above final preparation work + return &static_cast(*this); + } + + // CreateSubobject creates a sububject helper (e.g. CD3DX12_HIT_GROUP_SUBOBJECT) + // whose lifetime is owned by this class. + // e.g. + // + // CD3DX12_STATE_OBJECT_DESC Collection1(D3D12_STATE_OBJECT_TYPE_COLLECTION); + // auto Lib0 = Collection1.CreateSubobject(); + // Lib0->SetDXILLibrary(&pMyAppDxilLibs[0]); + // Lib0->DefineExport(L"rayGenShader0"); // in practice these export listings might be + // // data/engine driven + // etc. + // + // Alternatively, users can instantiate sububject helpers explicitly, such as via local + // variables instead, passing the state object desc that should point to it into the helper + // constructor (or call mySubobjectHelper.AddToStateObject(Collection1)). + // In this alternative scenario, the user must keep the subobject alive as long as the state + // object it is associated with is alive, else it's pointer references will be stale. + // e.g. + // + // CD3DX12_STATE_OBJECT_DESC RaytracingState2(D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE); + // CD3DX12_DXIL_LIBRARY_SUBOBJECT LibA(RaytracingState2); + // LibA.SetDXILLibrary(&pMyAppDxilLibs[4]); // not manually specifying exports + // // - meaning all exports in the libraries + // // are exported + // etc. + + template + T* CreateSubobject() + { + T* pSubobject = new T(*this); + m_OwnedSubobjectHelpers.emplace_back(pSubobject); + return pSubobject; + } + +private: + D3D12_STATE_SUBOBJECT* TrackSubobject(D3D12_STATE_SUBOBJECT_TYPE Type, void* pDesc) + { + SUBOBJECT_WRAPPER Subobject; + Subobject.pSubobjectArrayLocation = nullptr; + Subobject.Type = Type; + Subobject.pDesc = pDesc; + m_SubobjectList.push_back(Subobject); + m_Desc.NumSubobjects++; + return &m_SubobjectList.back(); + } + void Init(D3D12_STATE_OBJECT_TYPE Type) + { + SetStateObjectType(Type); + m_Desc.pSubobjects = nullptr; + m_Desc.NumSubobjects = 0; + m_SubobjectList.clear(); + m_SubobjectArray.clear(); + m_RepointedAssociations.clear(); + } + typedef struct SUBOBJECT_WRAPPER : public D3D12_STATE_SUBOBJECT + { + D3D12_STATE_SUBOBJECT* pSubobjectArrayLocation; // new location when flattened into array + // for repointing pointers in subobjects + } SUBOBJECT_WRAPPER; + D3D12_STATE_OBJECT_DESC m_Desc; + std::list m_SubobjectList; // Pointers to list nodes handed out so + // these can be edited live + std::vector m_SubobjectArray; // Built at the end, copying list contents + + std::list + m_RepointedAssociations; // subobject type that contains pointers to other subobjects, + // repointed to flattened array + + class StringContainer + { + public: + LPCWSTR LocalCopy(LPCWSTR string, bool bSingleString = false) + { + if (string) + { + if (bSingleString) + { + m_Strings.clear(); + m_Strings.push_back(string); + } + else + { + m_Strings.push_back(string); + } + return m_Strings.back().c_str(); + } + else + { + return nullptr; + } + } + void clear() { m_Strings.clear(); } + private: + std::list m_Strings; + }; + + class SUBOBJECT_HELPER_BASE + { + public: + SUBOBJECT_HELPER_BASE() { Init(); }; + virtual ~SUBOBJECT_HELPER_BASE() {}; + virtual D3D12_STATE_SUBOBJECT_TYPE Type() const = 0; + void AddToStateObject(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + m_pSubobject = ContainingStateObject.TrackSubobject(Type(), Data()); + } + protected: + virtual void* Data() = 0; + void Init() { m_pSubobject = nullptr; } + D3D12_STATE_SUBOBJECT* m_pSubobject; + }; + +#if(__cplusplus >= 201103L) + std::list> m_OwnedSubobjectHelpers; +#else + class OWNED_HELPER + { + public: + OWNED_HELPER(const SUBOBJECT_HELPER_BASE* pHelper) { m_pHelper = pHelper; } + ~OWNED_HELPER() { delete m_pHelper; } + const SUBOBJECT_HELPER_BASE* m_pHelper; + }; + + std::list m_OwnedSubobjectHelpers; +#endif + + friend class CD3DX12_DXIL_LIBRARY_SUBOBJECT; + friend class CD3DX12_EXISTING_COLLECTION_SUBOBJECT; + friend class CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT; + friend class CD3DX12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION; + friend class CD3DX12_HIT_GROUP_SUBOBJECT; + friend class CD3DX12_RAYTRACING_SHADER_CONFIG_SUBOBJECT; + friend class CD3DX12_RAYTRACING_PIPELINE_CONFIG_SUBOBJECT; + friend class CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT; + friend class CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT; + friend class CD3DX12_STATE_OBJECT_CONFIG_SUBOBJECT; + friend class CD3DX12_NODE_MASK_SUBOBJECT; +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_DXIL_LIBRARY_SUBOBJECT + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_DXIL_LIBRARY_SUBOBJECT() + { + Init(); + } + CD3DX12_DXIL_LIBRARY_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void SetDXILLibrary(D3D12_SHADER_BYTECODE*pCode) + { + static const D3D12_SHADER_BYTECODE Default = {}; + m_Desc.DXILLibrary = pCode ? *pCode : Default; + } + void DefineExport( + LPCWSTR Name, + LPCWSTR ExportToRename = nullptr, + D3D12_EXPORT_FLAGS Flags = D3D12_EXPORT_FLAG_NONE) + { + D3D12_EXPORT_DESC Export; + Export.Name = m_Strings.LocalCopy(Name); + Export.ExportToRename = m_Strings.LocalCopy(ExportToRename); + Export.Flags = Flags; + m_Exports.push_back(Export); + m_Desc.pExports = &m_Exports[0]; // using ugly way to get pointer in case .data() is not defined + m_Desc.NumExports = static_cast(m_Exports.size()); + } + template + void DefineExports(LPCWSTR(&Exports)[N]) + { + for (UINT i = 0; i < N; i++) + { + DefineExport(Exports[i]); + } + } + void DefineExports(LPCWSTR* Exports, UINT N) + { + for (UINT i = 0; i < N; i++) + { + DefineExport(Exports[i]); + } + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator const D3D12_DXIL_LIBRARY_DESC&() const { return m_Desc; } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_Desc = {}; + m_Strings.clear(); + m_Exports.clear(); + } + void* Data() { return &m_Desc; } + D3D12_DXIL_LIBRARY_DESC m_Desc; + CD3DX12_STATE_OBJECT_DESC::StringContainer m_Strings; + std::vector m_Exports; +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_EXISTING_COLLECTION_SUBOBJECT + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_EXISTING_COLLECTION_SUBOBJECT() + { + Init(); + } + CD3DX12_EXISTING_COLLECTION_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void SetExistingCollection(ID3D12StateObject*pExistingCollection) + { + m_Desc.pExistingCollection = pExistingCollection; + m_CollectionRef = pExistingCollection; + } + void DefineExport( + LPCWSTR Name, + LPCWSTR ExportToRename = nullptr, + D3D12_EXPORT_FLAGS Flags = D3D12_EXPORT_FLAG_NONE) + { + D3D12_EXPORT_DESC Export; + Export.Name = m_Strings.LocalCopy(Name); + Export.ExportToRename = m_Strings.LocalCopy(ExportToRename); + Export.Flags = Flags; + m_Exports.push_back(Export); + m_Desc.pExports = &m_Exports[0]; // using ugly way to get pointer in case .data() is not defined + m_Desc.NumExports = static_cast(m_Exports.size()); + } + template + void DefineExports(LPCWSTR(&Exports)[N]) + { + for (UINT i = 0; i < N; i++) + { + DefineExport(Exports[i]); + } + } + void DefineExports(LPCWSTR* Exports, UINT N) + { + for (UINT i = 0; i < N; i++) + { + DefineExport(Exports[i]); + } + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_EXISTING_COLLECTION; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator const D3D12_EXISTING_COLLECTION_DESC&() const { return m_Desc; } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_Desc = {}; + m_CollectionRef = nullptr; + m_Strings.clear(); + m_Exports.clear(); + } + void* Data() { return &m_Desc; } + D3D12_EXISTING_COLLECTION_DESC m_Desc; + Microsoft::WRL::ComPtr m_CollectionRef; + CD3DX12_STATE_OBJECT_DESC::StringContainer m_Strings; + std::vector m_Exports; +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT() + { + Init(); + } + CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void SetSubobjectToAssociate(const D3D12_STATE_SUBOBJECT& SubobjectToAssociate) + { + m_Desc.pSubobjectToAssociate = &SubobjectToAssociate; + } + void AddExport(LPCWSTR Export) + { + m_Desc.NumExports++; + m_Exports.push_back(m_Strings.LocalCopy(Export)); + m_Desc.pExports = &m_Exports[0]; // using ugly way to get pointer in case .data() is not defined + } + template + void AddExports(LPCWSTR (&Exports)[N]) + { + for (UINT i = 0; i < N; i++) + { + AddExport(Exports[i]); + } + } + void AddExports(LPCWSTR* Exports, UINT N) + { + for (UINT i = 0; i < N; i++) + { + AddExport(Exports[i]); + } + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_SUBOBJECT_TO_EXPORTS_ASSOCIATION; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator const D3D12_SUBOBJECT_TO_EXPORTS_ASSOCIATION&() const { return m_Desc; } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_Desc = {}; + m_Strings.clear(); + m_Exports.clear(); + } + void* Data() { return &m_Desc; } + D3D12_SUBOBJECT_TO_EXPORTS_ASSOCIATION m_Desc; + CD3DX12_STATE_OBJECT_DESC::StringContainer m_Strings; + std::vector m_Exports; +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION() + { + Init(); + } + CD3DX12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void SetSubobjectNameToAssociate(LPCWSTR SubobjectToAssociate) + { + m_Desc.SubobjectToAssociate = m_SubobjectName.LocalCopy(SubobjectToAssociate, true); + } + void AddExport(LPCWSTR Export) + { + m_Desc.NumExports++; + m_Exports.push_back(m_Strings.LocalCopy(Export)); + m_Desc.pExports = &m_Exports[0]; // using ugly way to get pointer in case .data() is not defined + } + template + void AddExports(LPCWSTR (&Exports)[N]) + { + for (UINT i = 0; i < N; i++) + { + AddExport(Exports[i]); + } + } + void AddExports(LPCWSTR* Exports, UINT N) + { + for (UINT i = 0; i < N; i++) + { + AddExport(Exports[i]); + } + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator const D3D12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION&() const { return m_Desc; } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_Desc = {}; + m_Strings.clear(); + m_SubobjectName.clear(); + m_Exports.clear(); + } + void* Data() { return &m_Desc; } + D3D12_DXIL_SUBOBJECT_TO_EXPORTS_ASSOCIATION m_Desc; + CD3DX12_STATE_OBJECT_DESC::StringContainer m_Strings; + CD3DX12_STATE_OBJECT_DESC::StringContainer m_SubobjectName; + std::vector m_Exports; +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_HIT_GROUP_SUBOBJECT + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_HIT_GROUP_SUBOBJECT() + { + Init(); + } + CD3DX12_HIT_GROUP_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void SetHitGroupExport(LPCWSTR exportName) + { + m_Desc.HitGroupExport = m_Strings[0].LocalCopy(exportName, true); + } + void SetHitGroupType(D3D12_HIT_GROUP_TYPE Type) { m_Desc.Type = Type; } + void SetAnyHitShaderImport(LPCWSTR importName) + { + m_Desc.AnyHitShaderImport = m_Strings[1].LocalCopy(importName, true); + } + void SetClosestHitShaderImport(LPCWSTR importName) + { + m_Desc.ClosestHitShaderImport = m_Strings[2].LocalCopy(importName, true); + } + void SetIntersectionShaderImport(LPCWSTR importName) + { + m_Desc.IntersectionShaderImport = m_Strings[3].LocalCopy(importName, true); + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator const D3D12_HIT_GROUP_DESC&() const { return m_Desc; } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_Desc = {}; + for (UINT i = 0; i < m_NumStrings; i++) + { + m_Strings[i].clear(); + } + } + void* Data() { return &m_Desc; } + D3D12_HIT_GROUP_DESC m_Desc; + static const UINT m_NumStrings = 4; + CD3DX12_STATE_OBJECT_DESC::StringContainer + m_Strings[m_NumStrings]; // one string for every entrypoint name +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_RAYTRACING_SHADER_CONFIG_SUBOBJECT + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_RAYTRACING_SHADER_CONFIG_SUBOBJECT() + { + Init(); + } + CD3DX12_RAYTRACING_SHADER_CONFIG_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void Config(UINT MaxPayloadSizeInBytes, UINT MaxAttributeSizeInBytes) + { + m_Desc.MaxPayloadSizeInBytes = MaxPayloadSizeInBytes; + m_Desc.MaxAttributeSizeInBytes = MaxAttributeSizeInBytes; + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator const D3D12_RAYTRACING_SHADER_CONFIG&() const { return m_Desc; } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_Desc = {}; + } + void* Data() { return &m_Desc; } + D3D12_RAYTRACING_SHADER_CONFIG m_Desc; +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_RAYTRACING_PIPELINE_CONFIG_SUBOBJECT + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_RAYTRACING_PIPELINE_CONFIG_SUBOBJECT() + { + Init(); + } + CD3DX12_RAYTRACING_PIPELINE_CONFIG_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void Config(UINT MaxTraceRecursionDepth) + { + m_Desc.MaxTraceRecursionDepth = MaxTraceRecursionDepth; + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_PIPELINE_CONFIG; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator const D3D12_RAYTRACING_PIPELINE_CONFIG&() const { return m_Desc; } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_Desc = {}; + } + void* Data() { return &m_Desc; } + D3D12_RAYTRACING_PIPELINE_CONFIG m_Desc; +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT() + { + Init(); + } + CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void SetRootSignature(ID3D12RootSignature* pRootSig) + { + m_pRootSig = pRootSig; + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_GLOBAL_ROOT_SIGNATURE; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator ID3D12RootSignature*() const { return m_pRootSig.Get(); } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_pRootSig = nullptr; + } + void* Data() { return m_pRootSig.GetAddressOf(); } + Microsoft::WRL::ComPtr m_pRootSig; +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT() + { + Init(); + } + CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void SetRootSignature(ID3D12RootSignature* pRootSig) + { + m_pRootSig = pRootSig; + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_LOCAL_ROOT_SIGNATURE; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator ID3D12RootSignature*() const { return m_pRootSig.Get(); } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_pRootSig = nullptr; + } + void* Data() { return m_pRootSig.GetAddressOf(); } + Microsoft::WRL::ComPtr m_pRootSig; +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_STATE_OBJECT_CONFIG_SUBOBJECT + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_STATE_OBJECT_CONFIG_SUBOBJECT() + { + Init(); + } + CD3DX12_STATE_OBJECT_CONFIG_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void SetFlags(D3D12_STATE_OBJECT_FLAGS Flags) + { + m_Desc.Flags = Flags; + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_STATE_OBJECT_CONFIG; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator const D3D12_STATE_OBJECT_CONFIG&() const { return m_Desc; } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_Desc = {}; + } + void* Data() { return &m_Desc; } + D3D12_STATE_OBJECT_CONFIG m_Desc; +}; + +//------------------------------------------------------------------------------------------------ +class CD3DX12_NODE_MASK_SUBOBJECT + : public CD3DX12_STATE_OBJECT_DESC::SUBOBJECT_HELPER_BASE +{ +public: + CD3DX12_NODE_MASK_SUBOBJECT() + { + Init(); + } + CD3DX12_NODE_MASK_SUBOBJECT(CD3DX12_STATE_OBJECT_DESC& ContainingStateObject) + { + Init(); + AddToStateObject(ContainingStateObject); + } + void SetNodeMask(UINT NodeMask) + { + m_Desc.NodeMask = NodeMask; + } + D3D12_STATE_SUBOBJECT_TYPE Type() const + { + return D3D12_STATE_SUBOBJECT_TYPE_NODE_MASK; + } + operator const D3D12_STATE_SUBOBJECT&() const { return *m_pSubobject; } + operator const D3D12_NODE_MASK&() const { return m_Desc; } +private: + void Init() + { + SUBOBJECT_HELPER_BASE::Init(); + m_Desc = {}; + } + void* Data() { return &m_Desc; } + D3D12_NODE_MASK m_Desc; +}; + +#endif // #ifndef D3DX12_NO_STATE_OBJECT_HELPERS + +#endif // defined( __cplusplus ) + +#endif //__D3DX12_H__ + + + diff --git a/Samples/DirectMLConv/DirectMLConv/main.cpp b/Samples/DirectMLConv/DirectMLConv/main.cpp new file mode 100644 index 00000000..59d7a601 --- /dev/null +++ b/Samples/DirectMLConv/DirectMLConv/main.cpp @@ -0,0 +1,596 @@ +#include "pch.h" + +#pragma warning(disable : 4238) // References to temporary classes are okay because they are only used as function parameters. + +using Microsoft::WRL::ComPtr; + +void InitializeDirect3D12( + ComPtr& d3D12Device, + ComPtr& commandQueue, + ComPtr& commandAllocator, + ComPtr& commandList) +{ + HRESULT hr{}; + + ComPtr dxgiFactory; + hr = CreateDXGIFactory1(IID_PPV_ARGS(dxgiFactory.GetAddressOf())); + + if (hr != S_OK) + std::cout << "failed to create dxgi factory"; + + ComPtr dxgiAdapter; + UINT adapterIndex{}; + + do + { + dxgiAdapter = nullptr; + THROW_IF_FAILED(dxgiFactory->EnumAdapters(adapterIndex, dxgiAdapter.ReleaseAndGetAddressOf())); + ++adapterIndex; + + hr = ::D3D12CreateDevice( + dxgiAdapter.Get(), + D3D_FEATURE_LEVEL_11_0, + IID_PPV_ARGS(d3D12Device.ReleaseAndGetAddressOf())); + if (hr == DXGI_ERROR_UNSUPPORTED) continue; + // THROW_IF_FAILED(hr); + + if (hr != S_OK) + std::cout << "failed to init adapter"; + + } while (hr != S_OK); + + D3D12_COMMAND_QUEUE_DESC commandQueueDesc{}; + commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; + commandQueueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE; + + THROW_IF_FAILED(d3D12Device->CreateCommandQueue( + &commandQueueDesc, + IID_GRAPHICS_PPV_ARGS(commandQueue.ReleaseAndGetAddressOf()))); + + THROW_IF_FAILED(d3D12Device->CreateCommandAllocator( + D3D12_COMMAND_LIST_TYPE_DIRECT, + IID_GRAPHICS_PPV_ARGS(commandAllocator.ReleaseAndGetAddressOf()))); + + THROW_IF_FAILED(d3D12Device->CreateCommandList( + 0, + D3D12_COMMAND_LIST_TYPE_DIRECT, + commandAllocator.Get(), + nullptr, + IID_GRAPHICS_PPV_ARGS(commandList.ReleaseAndGetAddressOf()))); +} + +void CloseExecuteResetWait( + ComPtr d3D12Device, + ComPtr commandQueue, + ComPtr commandAllocator, + ComPtr commandList) +{ + THROW_IF_FAILED(commandList->Close()); + + ID3D12CommandList* commandLists[] = { commandList.Get() }; + commandQueue->ExecuteCommandLists(ARRAYSIZE(commandLists), commandLists); + + ComPtr d3D12Fence; + THROW_IF_FAILED(d3D12Device->CreateFence( + 0, + D3D12_FENCE_FLAG_NONE, + IID_GRAPHICS_PPV_ARGS(d3D12Fence.GetAddressOf()))); + + wil::unique_handle fenceEventHandle(::CreateEvent(nullptr, true, false, nullptr)); + THROW_LAST_ERROR_IF_NULL(fenceEventHandle); + + THROW_IF_FAILED(commandQueue->Signal(d3D12Fence.Get(), 1)); + THROW_IF_FAILED(d3D12Fence->SetEventOnCompletion(1, fenceEventHandle.get())); + + ::WaitForSingleObjectEx(fenceEventHandle.get(), INFINITE, FALSE); + + THROW_IF_FAILED(commandAllocator->Reset()); + THROW_IF_FAILED(commandList->Reset(commandAllocator.Get(), nullptr)); +} + +// stride for NCHW layout +void SetStrides(const UINT sizes[4], UINT stridesOut[4]) { + stridesOut[0] = sizes[1] * sizes[2] * sizes[3]; + stridesOut[1] = sizes[2] * sizes[3]; + stridesOut[2] = sizes[3]; + stridesOut[3] = 1; +} + +int main() { + + // initialize D3D12 related resources + ComPtr d3D12Device; + ComPtr commandQueue; + ComPtr commandAllocator; + ComPtr commandList; + + // Set up Direct3D 12. + InitializeDirect3D12(d3D12Device, commandQueue, commandAllocator, commandList); + + + // Create the DirectML device. + DML_CREATE_DEVICE_FLAGS dmlCreateDeviceFlags = DML_CREATE_DEVICE_FLAG_NONE; + ComPtr dmlDevice; + THROW_IF_FAILED(DMLCreateDevice( + d3D12Device.Get(), + dmlCreateDeviceFlags, + IID_PPV_ARGS(dmlDevice.GetAddressOf()))); + + + // ************ set input tensor related params ************ + + constexpr UINT inputSizes[4] = { 1,1,3,3 }; + constexpr UINT inputTensorElementCount = inputSizes[0] * inputSizes[1] * inputSizes[2] * inputSizes[3]; + + UINT inputStrides[4]; + SetStrides(inputSizes, inputStrides); + + DML_BUFFER_TENSOR_DESC inputBufferTensorDesc = {}; + inputBufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_FLOAT32; + inputBufferTensorDesc.Flags = DML_TENSOR_FLAG_NONE; + inputBufferTensorDesc.DimensionCount = ARRAYSIZE(inputSizes); + inputBufferTensorDesc.Sizes = inputSizes; + inputBufferTensorDesc.Strides = inputStrides; + inputBufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize( + inputBufferTensorDesc.DataType, + inputBufferTensorDesc.DimensionCount, + inputBufferTensorDesc.Sizes, + inputBufferTensorDesc.Strides); + + DML_TENSOR_DESC inputTensorDesc{}; + inputTensorDesc.Type = DML_TENSOR_TYPE_BUFFER; + inputTensorDesc.Desc = &inputBufferTensorDesc; + + + // ************ set input tensor related params ************ + + constexpr UINT filterSizes[4] = { 1,1,1,1 }; + constexpr UINT filterTensorElementCount = filterSizes[0] * filterSizes[1] * filterSizes[2] * filterSizes[3]; + + UINT filterStrides[4]; + SetStrides(filterSizes, filterStrides); + + DML_BUFFER_TENSOR_DESC filterBufferTensorDesc = {}; + filterBufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_FLOAT32; + filterBufferTensorDesc.Flags = DML_TENSOR_FLAG_NONE; + filterBufferTensorDesc.DimensionCount = ARRAYSIZE(filterSizes); + filterBufferTensorDesc.Sizes = filterSizes; + filterBufferTensorDesc.Strides = filterStrides; + filterBufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize( + filterBufferTensorDesc.DataType, + filterBufferTensorDesc.DimensionCount, + filterBufferTensorDesc.Sizes, + filterBufferTensorDesc.Strides); + + DML_TENSOR_DESC filterTensorDesc{}; + filterTensorDesc.Type = DML_TENSOR_TYPE_BUFFER; + filterTensorDesc.Desc = &filterBufferTensorDesc; + + + // ************ set output tensor related params ************ + + UINT outputSizes[4]; + outputSizes[0] = inputSizes[0]; + outputSizes[1] = filterSizes[0]; + outputSizes[2] = inputSizes[2]; + outputSizes[3] = inputSizes[3]; + + UINT outputTensorElementCount = outputSizes[0] * outputSizes[1] * outputSizes[2] * outputSizes[3]; + + UINT outputStrides[4]; + SetStrides(outputSizes, outputStrides); + + DML_BUFFER_TENSOR_DESC outputBufferTensorDesc = {}; + outputBufferTensorDesc.DataType = DML_TENSOR_DATA_TYPE_FLOAT32; + outputBufferTensorDesc.Flags = DML_TENSOR_FLAG_NONE; + outputBufferTensorDesc.DimensionCount = ARRAYSIZE(outputSizes); + outputBufferTensorDesc.Sizes = outputSizes; + outputBufferTensorDesc.Strides = outputStrides; + outputBufferTensorDesc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize( + outputBufferTensorDesc.DataType, + outputBufferTensorDesc.DimensionCount, + outputBufferTensorDesc.Sizes, + outputBufferTensorDesc.Strides); + + DML_TENSOR_DESC outputTensorDesc{}; + outputTensorDesc.Type = DML_TENSOR_TYPE_BUFFER; + outputTensorDesc.Desc = &outputBufferTensorDesc; + + + // ************ define convolution operator ************ + + // The output size of a convolution operation is given by: + // height = (inputHeight - filterHeight + 2*paddingHeight) / filterStride + 1 + // width = (inputWidth - filterWidth + 2*paddingWidth ) / filterStride + 1 + // + // We want to preserve the height and width, so assuming stride is 1, we get: + // paddingHeight = (filterHeight - 1) / 2 + // paddingWidth = (filterWidth - 1) / 2 + // If padding is fractional, we pad unevenly with ceil/floor. + UINT paddingHeightTop = static_cast(ceil((filterSizes[2] - 1) / 2.0f)); + UINT paddingHeightBottom = static_cast(floor((filterSizes[2] - 1) / 2.0f)); + UINT paddingWidthLeft = static_cast(ceil((filterSizes[3] - 1) / 2.0f)); + UINT paddingWidthRight = static_cast(floor((filterSizes[3] - 1) / 2.0f)); + + UINT strides[] = { 1, 1 }; + UINT dilations[] = { 1, 1 }; + UINT startPadding[] = { paddingHeightTop, paddingWidthLeft }; + UINT endPadding[] = { paddingHeightBottom, paddingWidthRight }; + UINT outputPadding[] = { 0, 0 }; + + + // create conv DML operator descriptions + + DML_CONVOLUTION_OPERATOR_DESC dmlConvOperatorDesc{}; + dmlConvOperatorDesc.BiasTensor = nullptr; + dmlConvOperatorDesc.InputTensor = &inputTensorDesc; + dmlConvOperatorDesc.FilterTensor = &filterTensorDesc; + dmlConvOperatorDesc.OutputTensor = &outputTensorDesc; // Input and output tensors have same size/type. + dmlConvOperatorDesc.Mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION; + dmlConvOperatorDesc.Direction = DML_CONVOLUTION_DIRECTION_FORWARD; + dmlConvOperatorDesc.DimensionCount = 2; + dmlConvOperatorDesc.GroupCount = 1; + + dmlConvOperatorDesc.StartPadding = startPadding; + dmlConvOperatorDesc.EndPadding = endPadding; + dmlConvOperatorDesc.OutputPadding = outputPadding; + dmlConvOperatorDesc.Strides = strides; + dmlConvOperatorDesc.Dilations = dilations; + + DML_OPERATOR_DESC dmlOperatorDesc{}; + dmlOperatorDesc.Type = DML_OPERATOR_CONVOLUTION; + dmlOperatorDesc.Desc = &dmlConvOperatorDesc; + + + // Create and compile the conv DML operator + + ComPtr dmlOperator; + THROW_IF_FAILED(dmlDevice->CreateOperator( + &dmlOperatorDesc, + IID_PPV_ARGS(dmlOperator.GetAddressOf()))); + + ComPtr dmlCompiledOperator; + THROW_IF_FAILED(dmlDevice->CompileOperator( + dmlOperator.Get(), + DML_EXECUTION_FLAG_NONE, + IID_PPV_ARGS(dmlCompiledOperator.GetAddressOf()))); + + + ComPtr dmlOperatorInitializer; + IDMLCompiledOperator* dmlCompiledOperators[] = { dmlCompiledOperator.Get() }; + THROW_IF_FAILED(dmlDevice->CreateOperatorInitializer( + ARRAYSIZE(dmlCompiledOperators), + dmlCompiledOperators, + IID_PPV_ARGS(dmlOperatorInitializer.GetAddressOf()))); + + + // Query the operator for the required size (in descriptors) of its binding table. + // You need to initialize an operator exactly once before it can be executed, and + // the two stages require different numbers of descriptors for binding. For simplicity, + // we create a single descriptor heap that's large enough to satisfy them both. + DML_BINDING_PROPERTIES initializeBindingProperties = dmlOperatorInitializer->GetBindingProperties(); + DML_BINDING_PROPERTIES executeBindingProperties = dmlCompiledOperator->GetBindingProperties(); + UINT descriptorCount = std::max( + initializeBindingProperties.RequiredDescriptorCount, + executeBindingProperties.RequiredDescriptorCount); + + // Create descriptor heaps. + ComPtr descriptorHeap; + + D3D12_DESCRIPTOR_HEAP_DESC descriptorHeapDesc{}; + descriptorHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + descriptorHeapDesc.NumDescriptors = descriptorCount; + descriptorHeapDesc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + THROW_IF_FAILED(d3D12Device->CreateDescriptorHeap( + &descriptorHeapDesc, + IID_GRAPHICS_PPV_ARGS(descriptorHeap.GetAddressOf()))); + + // Set the descriptor heap(s). + ID3D12DescriptorHeap* d3D12DescriptorHeaps[] = { descriptorHeap.Get() }; + commandList->SetDescriptorHeaps(ARRAYSIZE(d3D12DescriptorHeaps), d3D12DescriptorHeaps); + + + // Create a binding table over the descriptor heap we just created. + DML_BINDING_TABLE_DESC dmlBindingTableDesc{}; + dmlBindingTableDesc.Dispatchable = dmlOperatorInitializer.Get(); + dmlBindingTableDesc.CPUDescriptorHandle = descriptorHeap->GetCPUDescriptorHandleForHeapStart(); + dmlBindingTableDesc.GPUDescriptorHandle = descriptorHeap->GetGPUDescriptorHandleForHeapStart(); + dmlBindingTableDesc.SizeInDescriptors = descriptorCount; + + ComPtr dmlBindingTable; + THROW_IF_FAILED(dmlDevice->CreateBindingTable( + &dmlBindingTableDesc, + IID_PPV_ARGS(dmlBindingTable.GetAddressOf()))); + + + // Create the temporary and persistent resources that are necessary for executing an operator. + + // The temporary resource is scratch memory (used internally by DirectML), whose contents you don't need to define. + // The persistent resource is long-lived, and you need to initialize it using the IDMLOperatorInitializer. + + UINT64 temporaryResourceSize = std::max( + initializeBindingProperties.TemporaryResourceSize, + executeBindingProperties.TemporaryResourceSize); + UINT64 persistentResourceSize = executeBindingProperties.PersistentResourceSize; + + + // Bind and initialize the operator on the GPU. + + ComPtr temporaryBuffer; + if (temporaryResourceSize != 0) + { + THROW_IF_FAILED(d3D12Device->CreateCommittedResource( + &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), + D3D12_HEAP_FLAG_NONE, + &CD3DX12_RESOURCE_DESC::Buffer(temporaryResourceSize, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS), + D3D12_RESOURCE_STATE_COMMON, + nullptr, + IID_GRAPHICS_PPV_ARGS(temporaryBuffer.GetAddressOf()))); + + if (initializeBindingProperties.TemporaryResourceSize != 0) + { + DML_BUFFER_BINDING bufferBinding{ temporaryBuffer.Get(), 0, temporaryResourceSize }; + DML_BINDING_DESC bindingDesc{ DML_BINDING_TYPE_BUFFER, &bufferBinding }; + dmlBindingTable->BindTemporaryResource(&bindingDesc); + } + } + + ComPtr persistentBuffer; + if (persistentResourceSize != 0) + { + THROW_IF_FAILED(d3D12Device->CreateCommittedResource( + &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), + D3D12_HEAP_FLAG_NONE, + &CD3DX12_RESOURCE_DESC::Buffer(persistentResourceSize), + D3D12_RESOURCE_STATE_COMMON, + nullptr, + IID_GRAPHICS_PPV_ARGS(persistentBuffer.GetAddressOf()))); + + // The persistent resource should be bound as the output to the IDMLOperatorInitializer. + DML_BUFFER_BINDING bufferBinding{ persistentBuffer.Get(), 0, persistentResourceSize }; + DML_BINDING_DESC bindingDesc{ DML_BINDING_TYPE_BUFFER, &bufferBinding }; + dmlBindingTable->BindOutputs(1, &bindingDesc); + } + + // The command recorder is a stateless object that records Dispatches into an existing Direct3D 12 command list. + ComPtr dmlCommandRecorder; + THROW_IF_FAILED(dmlDevice->CreateCommandRecorder( + IID_PPV_ARGS(dmlCommandRecorder.GetAddressOf()))); + + // Record execution of the operator initializer. + dmlCommandRecorder->RecordDispatch( + commandList.Get(), + dmlOperatorInitializer.Get(), + dmlBindingTable.Get()); + + + // Close the Direct3D 12 command list, and submit it for execution as you would any other command list. You could + // in principle record the execution into the same command list as the initialization, but you need only to Initialize + // once, and typically you want to Execute an operator more frequently than that. + CloseExecuteResetWait(d3D12Device, commandQueue, commandAllocator, commandList); + + // + // Bind and execute the operator on the GPU. + // + commandList->SetDescriptorHeaps(ARRAYSIZE(d3D12DescriptorHeaps), d3D12DescriptorHeaps); + + + // Reset the binding table to bind for the operator we want to execute (it was previously used to bind for the + // initializer). + + dmlBindingTableDesc.Dispatchable = dmlCompiledOperator.Get(); + THROW_IF_FAILED(dmlBindingTable->Reset(&dmlBindingTableDesc)); + + if (temporaryResourceSize != 0) + { + DML_BUFFER_BINDING bufferBinding{ temporaryBuffer.Get(), 0, temporaryResourceSize }; + DML_BINDING_DESC bindingDesc{ DML_BINDING_TYPE_BUFFER, &bufferBinding }; + dmlBindingTable->BindTemporaryResource(&bindingDesc); + } + + if (persistentResourceSize != 0) + { + DML_BUFFER_BINDING bufferBinding{ persistentBuffer.Get(), 0, persistentResourceSize }; + DML_BINDING_DESC bindingDesc{ DML_BINDING_TYPE_BUFFER, &bufferBinding }; + dmlBindingTable->BindPersistentResource(&bindingDesc); + } + + + // Create tensor buffers for upload/input/output/readback of the tensor elements. + + // *************** filter tensor *************** + + // 24 elements * 4 == 96 bytes. + UINT64 filterTensorBufferSize{ filterBufferTensorDesc.TotalTensorSizeInBytes }; + + ComPtr filterUploadBuffer; + THROW_IF_FAILED(d3D12Device->CreateCommittedResource( + &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD), + D3D12_HEAP_FLAG_NONE, + &CD3DX12_RESOURCE_DESC::Buffer(filterTensorBufferSize), + D3D12_RESOURCE_STATE_GENERIC_READ, + nullptr, + IID_GRAPHICS_PPV_ARGS(filterUploadBuffer.GetAddressOf()))); + + ComPtr filterInputBuffer; + THROW_IF_FAILED(d3D12Device->CreateCommittedResource( + &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), + D3D12_HEAP_FLAG_NONE, + &CD3DX12_RESOURCE_DESC::Buffer(filterTensorBufferSize, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS), + D3D12_RESOURCE_STATE_COPY_DEST, + nullptr, + IID_GRAPHICS_PPV_ARGS(filterInputBuffer.GetAddressOf()))); + + std::wcout << std::fixed; std::wcout.precision(4); + std::array filterTensorElementArray; + { + std::wcout << L"filter tensor: \n"; + for (auto& element : filterTensorElementArray) + { + element = 5.6f; + std::wcout << element << L' '; + }; + std::wcout << std::endl; + + D3D12_SUBRESOURCE_DATA filterTensorSubresourceData{}; + filterTensorSubresourceData.pData = filterTensorElementArray.data(); + filterTensorSubresourceData.RowPitch = static_cast(filterTensorBufferSize); + filterTensorSubresourceData.SlicePitch = filterTensorSubresourceData.RowPitch; + + // Upload the input tensor to the GPU. + ::UpdateSubresources( + commandList.Get(), + filterInputBuffer.Get(), + filterUploadBuffer.Get(), + 0, + 0, + 1, + &filterTensorSubresourceData); + + commandList->ResourceBarrier( + 1, + &CD3DX12_RESOURCE_BARRIER::Transition( + filterInputBuffer.Get(), + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS + ) + ); + } + + DML_BUFFER_BINDING filterInputBufferBinding{ filterInputBuffer.Get(), 0, filterTensorBufferSize }; + DML_BINDING_DESC filterInputBindingDesc{ DML_BINDING_TYPE_BUFFER, &filterInputBufferBinding }; + // dmlBindingTable->BindInputs(1, &filterInputBindingDesc); + + + // *************** input tensor *************** + // 24 elements * 4 == 96 bytes. + UINT64 tensorBufferSize{ inputBufferTensorDesc.TotalTensorSizeInBytes }; + + ComPtr uploadBuffer; + THROW_IF_FAILED(d3D12Device->CreateCommittedResource( + &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD), + D3D12_HEAP_FLAG_NONE, + &CD3DX12_RESOURCE_DESC::Buffer(tensorBufferSize), + D3D12_RESOURCE_STATE_GENERIC_READ, + nullptr, + IID_GRAPHICS_PPV_ARGS(uploadBuffer.GetAddressOf()))); + + ComPtr inputBuffer; + THROW_IF_FAILED(d3D12Device->CreateCommittedResource( + &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), + D3D12_HEAP_FLAG_NONE, + &CD3DX12_RESOURCE_DESC::Buffer(tensorBufferSize, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS), + D3D12_RESOURCE_STATE_COPY_DEST, + nullptr, + IID_GRAPHICS_PPV_ARGS(inputBuffer.GetAddressOf()))); + + std::wcout << std::fixed; std::wcout.precision(4); + std::array inputTensorElementArray; + { + std::wcout << L"input tensor: \n"; + for (auto& element : inputTensorElementArray) + { + element = -2.0f; + std::wcout << element << L' '; + }; + std::wcout << std::endl; + + D3D12_SUBRESOURCE_DATA tensorSubresourceData{}; + tensorSubresourceData.pData = inputTensorElementArray.data(); + tensorSubresourceData.RowPitch = static_cast(tensorBufferSize); + tensorSubresourceData.SlicePitch = tensorSubresourceData.RowPitch; + + // Upload the input tensor to the GPU. + ::UpdateSubresources( + commandList.Get(), + inputBuffer.Get(), + uploadBuffer.Get(), + 0, + 0, + 1, + &tensorSubresourceData); + + commandList->ResourceBarrier( + 1, + &CD3DX12_RESOURCE_BARRIER::Transition( + inputBuffer.Get(), + D3D12_RESOURCE_STATE_COPY_DEST, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS + ) + ); + } + + DML_BUFFER_BINDING inputBufferBinding{ inputBuffer.Get(), 0, tensorBufferSize }; + DML_BINDING_DESC inputBindingDesc{ DML_BINDING_TYPE_BUFFER, &inputBufferBinding }; + + DML_BINDING_DESC bindings[3]; + bindings[0] = inputBindingDesc; + bindings[1] = filterInputBindingDesc; + bindings[2].Type = DML_BINDING_TYPE_NONE; + bindings[2].Desc = nullptr; + dmlBindingTable->BindInputs(3, bindings); + // dmlBindingTable->BindInputs(1, &inputBindingDesc); + + + // *************** output tensor *************** + + ComPtr outputBuffer; + THROW_IF_FAILED(d3D12Device->CreateCommittedResource( + &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), + D3D12_HEAP_FLAG_NONE, + &CD3DX12_RESOURCE_DESC::Buffer(tensorBufferSize, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + nullptr, + IID_GRAPHICS_PPV_ARGS(outputBuffer.GetAddressOf()))); + + DML_BUFFER_BINDING outputBufferBinding{ outputBuffer.Get(), 0, tensorBufferSize }; + DML_BINDING_DESC outputBindingDesc{ DML_BINDING_TYPE_BUFFER, &outputBufferBinding }; + dmlBindingTable->BindOutputs(1, &outputBindingDesc); + + // Record execution of the compiled operator. + dmlCommandRecorder->RecordDispatch(commandList.Get(), dmlCompiledOperator.Get(), dmlBindingTable.Get()); + + CloseExecuteResetWait(d3D12Device, commandQueue, commandAllocator, commandList); + + // The output buffer now contains the result of the identity operator, + // so read it back if you want the CPU to access it. + + ComPtr readbackBuffer; + THROW_IF_FAILED(d3D12Device->CreateCommittedResource( + &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_READBACK), + D3D12_HEAP_FLAG_NONE, + &CD3DX12_RESOURCE_DESC::Buffer(tensorBufferSize), + D3D12_RESOURCE_STATE_COPY_DEST, + nullptr, + IID_GRAPHICS_PPV_ARGS(readbackBuffer.GetAddressOf()))); + + commandList->ResourceBarrier( + 1, + &CD3DX12_RESOURCE_BARRIER::Transition( + outputBuffer.Get(), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COPY_SOURCE + ) + ); + + commandList->CopyResource(readbackBuffer.Get(), outputBuffer.Get()); + + CloseExecuteResetWait(d3D12Device, commandQueue, commandAllocator, commandList); + + D3D12_RANGE tensorBufferRange{ 0, static_cast(tensorBufferSize) }; + FLOAT* outputBufferData{}; + THROW_IF_FAILED(readbackBuffer->Map(0, &tensorBufferRange, reinterpret_cast(&outputBufferData))); + + std::wstring outputString = L"\noutput tensor: \n"; + for (size_t tensorElementIndex{ 0 }; tensorElementIndex < outputTensorElementCount; ++tensorElementIndex, ++outputBufferData) + { + outputString += std::to_wstring(*outputBufferData) + L' '; + } + + std::wcout << outputString << std::endl; + OutputDebugStringW(outputString.c_str()); + + D3D12_RANGE emptyRange{ 0, 0 }; + readbackBuffer->Unmap(0, &emptyRange); + +} \ No newline at end of file diff --git a/Samples/DirectMLConv/DirectMLConv/packages.config b/Samples/DirectMLConv/DirectMLConv/packages.config new file mode 100644 index 00000000..026d48d6 --- /dev/null +++ b/Samples/DirectMLConv/DirectMLConv/packages.config @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/Samples/DirectMLConv/DirectMLConv/pch.cpp b/Samples/DirectMLConv/DirectMLConv/pch.cpp new file mode 100644 index 00000000..1da170eb --- /dev/null +++ b/Samples/DirectMLConv/DirectMLConv/pch.cpp @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "pch.h" diff --git a/Samples/DirectMLConv/DirectMLConv/pch.h b/Samples/DirectMLConv/DirectMLConv/pch.h new file mode 100644 index 00000000..fee295cd --- /dev/null +++ b/Samples/DirectMLConv/DirectMLConv/pch.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#define NOMINMAX + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _GAMING_XBOX_SCARLETT +#include +#include +#else +#include "d3dx12.h" // The D3D12 Helper Library that you downloaded. +#include +#define IID_GRAPHICS_PPV_ARGS IID_PPV_ARGS +#endif + +#define DML_TARGET_VERSION_USE_LATEST +#include // The DirectML header from the Windows SDK. +#include "DirectMLX.h" + +#include \ No newline at end of file diff --git a/Samples/DirectMLConv/README.md b/Samples/DirectMLConv/README.md new file mode 100644 index 00000000..dfcc1daf --- /dev/null +++ b/Samples/DirectMLConv/README.md @@ -0,0 +1,14 @@ +# DirectMLConv + +This application illustrates how to use DirectML APIs to implement and execute a simple convolution layer. + +The project uses the following tensor configuration for the input and filter +- input tensor: 1x1x3x3 +- filter tensor: 1x1x1x1 +- output tensor: 1x1x3x3 + +## How to use? +- Install Visual Studio 2022 +- Open the project using Visual Studio 2022 +- Build the solution which downloads all dependencies +- Run/debug the project \ No newline at end of file From 0a300c567644e74d8008b2ca6a4414968c1de291 Mon Sep 17 00:00:00 2001 From: Vishal Agarwal Date: Fri, 9 Feb 2024 15:45:01 +0530 Subject: [PATCH 2/2] update readme --- Samples/DirectMLConv/README.md | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/Samples/DirectMLConv/README.md b/Samples/DirectMLConv/README.md index dfcc1daf..b5d512f0 100644 --- a/Samples/DirectMLConv/README.md +++ b/Samples/DirectMLConv/README.md @@ -1,12 +1,23 @@ # DirectMLConv -This application illustrates how to use DirectML APIs to implement and execute a simple convolution layer. +This application illustrates how to use D3D12 and DirectML APIs to implement and execute a simple convolution layer. -The project uses the following tensor configuration for the input and filter +The sample uses the following tensor configuration for the input and filter - input tensor: 1x1x3x3 - filter tensor: 1x1x1x1 - output tensor: 1x1x3x3 +example input and output +```sh +filter tensor: +5.6000 +input tensor: +-2.0000 -2.0000 -2.0000 -2.0000 -2.0000 -2.0000 -2.0000 -2.0000 -2.0000 + +output tensor: +-11.200000 -11.200000 -11.200000 -11.200000 -11.200000 -11.200000 -11.200000 -11.200000 -11.200000 +``` + ## How to use? - Install Visual Studio 2022 - Open the project using Visual Studio 2022