diff --git a/Samples/DirectMLConv/DirectMLConv.sln b/Samples/DirectMLConv/DirectMLConv.sln
new file mode 100644
index 00000000..99e81ad2
--- /dev/null
+++ b/Samples/DirectMLConv/DirectMLConv.sln
@@ -0,0 +1,31 @@
+
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio Version 17
+VisualStudioVersion = 17.8.34525.116
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "DirectMLConv", "DirectMLConv\DirectMLConv.vcxproj", "{A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}"
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|x64 = Debug|x64
+ Debug|x86 = Debug|x86
+ Release|x64 = Release|x64
+ Release|x86 = Release|x86
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Debug|x64.ActiveCfg = Debug|x64
+ {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Debug|x64.Build.0 = Debug|x64
+ {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Debug|x86.ActiveCfg = Debug|Win32
+ {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Debug|x86.Build.0 = Debug|Win32
+ {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Release|x64.ActiveCfg = Release|x64
+ {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Release|x64.Build.0 = Release|x64
+ {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Release|x86.ActiveCfg = Release|Win32
+ {A3AA531B-DAF1-4D6C-8389-EA88FA2082E6}.Release|x86.Build.0 = Release|Win32
+ EndGlobalSection
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
+ EndGlobalSection
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {72CE4CBB-7C71-40C5-956F-D5A9DCCE48A1}
+ EndGlobalSection
+EndGlobal
diff --git a/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj b/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj
new file mode 100644
index 00000000..e5f1791c
--- /dev/null
+++ b/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj
@@ -0,0 +1,160 @@
+
+
+
+
+
+ Debug
+ Win32
+
+
+ Release
+ Win32
+
+
+ Debug
+ x64
+
+
+ Release
+ x64
+
+
+
+ 17.0
+ Win32Proj
+ {a3aa531b-daf1-4d6c-8389-ea88fa2082e6}
+ DirectMLConv
+ 10.0
+
+
+
+ Application
+ true
+ v143
+ Unicode
+
+
+ Application
+ false
+ v143
+ true
+ Unicode
+
+
+ Application
+ true
+ v143
+ Unicode
+
+
+ Application
+ false
+ v143
+ true
+ Unicode
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Level3
+ true
+ WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)
+ true
+
+
+ Console
+ true
+
+
+
+
+ Level3
+ true
+ true
+ true
+ WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)
+ true
+
+
+ Console
+ true
+ true
+ true
+
+
+
+
+ Level3
+ true
+ _DEBUG;_CONSOLE;%(PreprocessorDefinitions)
+ false
+ stdcpp17
+ Create
+ pch.h
+
+
+ Console
+ true
+ WindowsApp.lib;dxgi.lib;d3d12.lib;$(CoreLibraryDependencies);%(AdditionalDependencies)
+
+
+
+
+ Level3
+ true
+ true
+ true
+ NDEBUG;_CONSOLE;%(PreprocessorDefinitions)
+ true
+
+
+ Console
+ true
+ true
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj.filters b/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj.filters
new file mode 100644
index 00000000..dd8b06c4
--- /dev/null
+++ b/Samples/DirectMLConv/DirectMLConv/DirectMLConv.vcxproj.filters
@@ -0,0 +1,40 @@
+
+
+
+
+ {4FC737F1-C7A5-4376-A066-2A32D752A2FF}
+ cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx
+
+
+ {93995380-89BD-4b04-88EB-625FBE52EBFB}
+ h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd
+
+
+ {67DA6AB6-F800-4c08-8B7A-83BB121AAD01}
+ rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms
+
+
+
+
+ Header Files
+
+
+ Header Files
+
+
+ Header Files
+
+
+
+
+ Source Files
+
+
+ Source Files
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Samples/DirectMLConv/DirectMLConv/DirectMLX.h b/Samples/DirectMLConv/DirectMLConv/DirectMLX.h
new file mode 100644
index 00000000..5a1308c3
--- /dev/null
+++ b/Samples/DirectMLConv/DirectMLConv/DirectMLX.h
@@ -0,0 +1,4367 @@
+//*********************************************************
+//
+// Copyright (c) Microsoft. All rights reserved.
+// This code is licensed under the MIT License (MIT).
+// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF
+// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY
+// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR
+// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT.
+//
+//*********************************************************
+// clang-format off
+
+#pragma once
+#include "DirectML.h"
+
+#include <cstdint>
+#include <cassert>
+#include <initializer_list>
+#include <vector>
+#include <deque>
+#include <stack>
+#include <string>
+#include <functional>
+#include <utility>
+#include <memory>
+
+#include <wrl/client.h> // For Microsoft::WRL::ComPtr
+
+#if DMLX_USE_ABSEIL
+ #if __cpp_lib_span
+ #include <span>
+ #endif
+#elif __cplusplus >= 201703L && __has_include(<optional>)
+ // stl optional is only available in cpp17 and above.
+ #include <optional>
+#elif __has_include("dml_optional_extensions.h")
+ #include "dml_optional_extensions.h"
+ #define DMLX_OPTIONAL_EXTENDED
+#endif
+
+#if __cpp_exceptions
+ #include <stdexcept>
+#endif
+
+#if __cplusplus >= 201703L && __has_include(<string_view>)
+ #include <string_view>
+#endif
+
+/** Calculates the minimum number of bytes required to store a buffer tensor with the specified type, sizes, and
+ strides. The formula can be expressed as the following:
+
+ IndexOfLastElement = dot(Sizes - 1, Strides);
+ MinimumImpliedSizeInBytes = roundup((IndexOfLastElement + 1) * ElementSizeInBytes, 4)
+
+ In other words, the minimum size of a tensor is the index of the one-past-the-end element, multiplied by the
+ element size (e.g. 2 bytes for a FLOAT16 tensor). Additionally DirectML requires that all buffers bound must have
+ a total size which is DWORD-aligned, and hence the minimum implied size in bytes must be rounded up to the nearest
+ 4-byte boundary.
+ */
+
+inline UINT64 DMLCalcBufferTensorSize(
+ DML_TENSOR_DATA_TYPE dataType,
+ UINT dimensionCount,
+ _In_reads_(dimensionCount) const UINT* sizes,
+ _In_reads_opt_(dimensionCount) const UINT* strides)
+{
+ UINT elementSizeInBytes = 0;
+ switch (dataType)
+ {
+ case DML_TENSOR_DATA_TYPE_FLOAT32:
+ case DML_TENSOR_DATA_TYPE_UINT32:
+ case DML_TENSOR_DATA_TYPE_INT32:
+ elementSizeInBytes = 4;
+ break;
+
+ case DML_TENSOR_DATA_TYPE_FLOAT16:
+ case DML_TENSOR_DATA_TYPE_UINT16:
+ case DML_TENSOR_DATA_TYPE_INT16:
+ elementSizeInBytes = 2;
+ break;
+
+ case DML_TENSOR_DATA_TYPE_UINT8:
+ case DML_TENSOR_DATA_TYPE_INT8:
+ elementSizeInBytes = 1;
+ break;
+
+ case DML_TENSOR_DATA_TYPE_FLOAT64:
+ case DML_TENSOR_DATA_TYPE_UINT64:
+ case DML_TENSOR_DATA_TYPE_INT64:
+ elementSizeInBytes = 8;
+ break;
+
+ default:
+ return 0; // Invalid data type
+ }
+
+ UINT64 minimumImpliedSizeInBytes = 0;
+ if (!strides)
+ {
+ minimumImpliedSizeInBytes = sizes[0];
+ for (UINT i = 1; i < dimensionCount; ++i)
+ {
+ minimumImpliedSizeInBytes *= sizes[i];
+ }
+ minimumImpliedSizeInBytes *= elementSizeInBytes;
+ }
+ else
+ {
+ UINT indexOfLastElement = 0;
+ for (UINT i = 0; i < dimensionCount; ++i)
+ {
+ indexOfLastElement += (sizes[i] - 1) * strides[i];
+ }
+
+ minimumImpliedSizeInBytes = (static_cast<UINT64>(indexOfLastElement) + 1) * elementSizeInBytes;
+ }
+
+ // Round up to the nearest 4 bytes.
+ minimumImpliedSizeInBytes = (minimumImpliedSizeInBytes + 3) & ~3ull;
+
+ return minimumImpliedSizeInBytes;
+}
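+
+// Illustrative example (not part of the original header): a 1x3x4x4 FLOAT32 tensor with packed (null) strides
+// implies 1*3*4*4 elements * 4 bytes = 192 bytes, which is already DWORD-aligned:
+//
+//   const UINT sizes[] = { 1, 3, 4, 4 };
+//   UINT64 bytes = DMLCalcBufferTensorSize(DML_TENSOR_DATA_TYPE_FLOAT32, 4, sizes, /*strides*/ nullptr); // 192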
+
+namespace dml
+{
+ namespace detail
+ {
+ // Provide non-member size() and data(). Defaults to standard library implementation (if available)
+#if __cpp_lib_nonmember_container_access
+ template <typename C>
+ constexpr auto size(const C& c) -> decltype(c.size())
+ {
+ return std::size(c);
+ }
+
+ template <typename T, std::size_t N>
+ constexpr std::size_t size(const T(&array)[N]) noexcept
+ {
+ return std::size(array);
+ }
+
+ template <typename C>
+ constexpr auto data(C& c) -> decltype(c.data())
+ {
+ return std::data(c);
+ }
+
+ template <typename T, std::size_t N>
+ constexpr T* data(T(&array)[N]) noexcept
+ {
+ return std::data(array);
+ }
+#else
+ template <typename C>
+ constexpr auto size(const C& c) -> decltype(c.size())
+ {
+ return c.size();
+ }
+
+ template <typename T, std::size_t N>
+ constexpr std::size_t size(const T(&array)[N]) noexcept
+ {
+ return N;
+ }
+
+ template <typename C>
+ constexpr auto data(C& c) -> decltype(c.data())
+ {
+ return c.data();
+ }
+
+ template <typename T, std::size_t N>
+ constexpr T* data(T(&array)[N]) noexcept
+ {
+ return array;
+ }
+#endif
+
+ template <typename T>
+ class span
+ {
+ public:
+ span() = default;
+
+ constexpr span(std::initializer_list<T> i) : m_begin(i.begin()), m_end(i.end()) {}
+ constexpr span(T* begin, T* end) : m_begin(begin), m_end(end) {}
+ constexpr span(T* begin, size_t elementCount) : m_begin(begin), m_end(begin + elementCount) {}
+
+ template <typename ContiguousContainer>
+ constexpr span(ContiguousContainer&& container)
+ : m_begin(dml::detail::data(container)), m_end(m_begin + dml::detail::size(container)) {}
+
+ template <size_t N>
+ constexpr span(T(&a)[N]) noexcept : span(a, N) {}
+
+ T* data() noexcept { return m_begin; }
+ T* begin() noexcept { return m_begin; }
+ T* end() noexcept { return m_end; }
+ T const* data() const noexcept { return m_begin; }
+ T const* begin() const noexcept { return m_begin; }
+ T const* end() const noexcept { return m_end; }
+ bool empty() const noexcept { return m_end == m_begin; }
+ size_t size() const noexcept { return m_end - m_begin; }
+ size_t size_bytes() const noexcept { return sizeof(T) * size(); }
+ T& operator[](size_t index) const noexcept { return m_begin[index]; }
+ span subspan(size_t index, size_t count) { return span(m_begin + index, m_begin + index + count); }
+
+ protected:
+ T* m_begin = nullptr;
+ T* m_end = nullptr;
+ };
+ }
+
+#if DMLX_USE_ABSEIL
+ template <typename T>
+ using Optional = absl::optional<T>;
+
+ constexpr absl::nullopt_t NullOpt = absl::nullopt;
+
+ template <typename T, size_t N>
+ using SmallVector = absl::InlinedVector<T, N>;
+
+ template <typename T>
+ using Span = absl::Span<T>;
+
+ using absl::make_unique;
+#else
+ #ifndef DMLX_OPTIONAL_EXTENDED
+ template <typename T>
+ using Optional = std::optional<T>;
+ constexpr std::nullopt_t NullOpt = std::nullopt;
+ #endif
+
+ template <typename T, size_t N>
+ using SmallVector = std::vector<T>;
+
+ #if __cpp_lib_span
+ template <typename T>
+ using Span = std::span<T>;
+ #elif DMLX_USE_GSL
+ template <typename T>
+ using Span = gsl::span<T>;
+ #else
+ template <typename T>
+ using Span = dml::detail::span<T>;
+ #endif
+
+ using std::make_unique;
+#endif
+
+#if __cplusplus >= 201703L && __has_include(<string_view>)
+ using StringView = std::string_view;
+#else
+ using StringView = const std::string&;
+#endif
+
+#if __cpp_exceptions
+ #if DMLX_USE_WIL
+ #define DMLX_THROW_IF_FAILED(_hr) THROW_IF_FAILED(_hr)
+ #define DMLX_THROW(_hr) THROW_HR(_hr)
+ #else
+ #define DMLX_THROW_IF_FAILED(_hr) if (FAILED(_hr)) { throw std::runtime_error(#_hr); }
+ #define DMLX_THROW(_hr) throw std::runtime_error(#_hr);
+ #endif
+#else
+ #define DMLX_THROW_IF_FAILED(_hr) if (FAILED(_hr)) { std::abort(); }
+ #define DMLX_THROW(_hr) { std::abort(); }
+#endif
+
+ class Graph;
+ class Expression;
+
+ using TensorDimensions = SmallVector<uint32_t, 4>;
+ using TensorStrides = SmallVector<uint32_t, 4>;
+
+ // The custom properties returned by a TensorPolicy.
+ struct TensorProperties
+ {
+ Optional<TensorStrides> strides;
+ uint64_t totalTensorSizeInBytes;
+ uint32_t guaranteedBaseOffsetAlignment;
+ };
+
+ // Provides a way to customize the properties that DMLX automatically sets on tensors. Callers may provide their
+ // own TensorPolicy implementation to provide custom strides, total tensor sizes, and alignment. TensorPolicy
+ // objects can be set using Graph::SetTensorPolicy().
+ class TensorPolicy
+ {
+ public:
+ // A function type that returns a TensorProperties object given a tensor data type, flags, and sizes.
+ using Func = std::function<
+ TensorProperties (DML_TENSOR_DATA_TYPE dataType, DML_TENSOR_FLAGS flags, Span<const uint32_t> sizes)
+ >;
+
+ TensorPolicy() = default;
+ /*implicit*/ TensorPolicy(Func impl)
+ : m_impl(impl)
+ {}
+
+ TensorProperties Get(
+ DML_TENSOR_DATA_TYPE dataType,
+ DML_TENSOR_FLAGS flags,
+ Span<const uint32_t> sizes) const
+ {
+ // Empty/uninitialized policy falls back to default.
+ if (!m_impl)
+ {
+ return ComputeDefault(dataType, flags, sizes);
+ }
+
+ return m_impl(dataType, flags, sizes);
+ }
+
+ // Returns the default tensor policy, which doesn't produce any changes to tensor layout, has no guaranteed
+ // alignment, and which uses DMLCalcBufferTensorSize to compute the total tensor size.
+ static TensorPolicy Default()
+ {
+ return TensorPolicy();
+ }
+
+ // A tensor policy that returns strides which produce tensors with a layout transposed to dimension order
+ // (0, 2, ..., n, 1). This is often referred to as "NHWC" or "interleaved channel" layout. This is useful,
+ // for example, when applied to 2D Convolution to produce outputs in an NHWC layout (as opposed to NCHW, which
+ // is the DirectML default for 2D Convolution).
+ //
+ // Examples of the transposes produced by this policy:
+ // NCW -> NWC
+ // NCHW -> NHWC
+ // NCDHW -> NDHWC
+ static TensorPolicy InterleavedChannel()
+ {
+ return TensorPolicy(&ComputeInterleavedChannel);
+ }
+
+ private:
+ static TensorProperties ComputeDefault(
+ DML_TENSOR_DATA_TYPE dataType,
+ DML_TENSOR_FLAGS /*flags*/,
+ Span<const uint32_t> sizes)
+ {
+ uint32_t dimensionCount = static_cast<uint32_t>(sizes.size());
+ TensorProperties props;
+ props.strides = NullOpt; // no strides
+ props.totalTensorSizeInBytes = DMLCalcBufferTensorSize(dataType, dimensionCount, sizes.data(), nullptr);
+ props.guaranteedBaseOffsetAlignment = 0;
+ return props;
+ }
+
+ static TensorProperties ComputeInterleavedChannel(
+ DML_TENSOR_DATA_TYPE dataType,
+ DML_TENSOR_FLAGS /*flags*/,
+ Span<const uint32_t> sizes)
+ {
+ uint32_t dimensionCount = static_cast<uint32_t>(sizes.size());
+ TensorStrides strides(dimensionCount);
+
+ enum Axes { N, C, /* spatial dimensions ... */ };
+
+ // N dimension strides
+ if (dimensionCount >= 1)
+ {
+ strides[N] = 1;
+ for (uint32_t i = 1; i < dimensionCount; ++i)
+ {
+ strides[N] *= sizes[i];
+ }
+ }
+
+ // C dimension strides
+ if (dimensionCount >= 2)
+ {
+ strides[C] = 1;
+ }
+
+ // Spatial dimension strides
+ if (dimensionCount >= 3)
+ {
+ uint32_t stride = sizes[C];
+ for (uint32_t i = dimensionCount - 1; i >= 2; --i)
+ {
+ strides[i] = stride;
+ stride *= sizes[i];
+ }
+ }
+
+ TensorProperties props;
+ props.strides = std::move(strides);
+ props.totalTensorSizeInBytes = DMLCalcBufferTensorSize(dataType, dimensionCount, sizes.data(), props.strides->data());
+ props.guaranteedBaseOffsetAlignment = 0;
+ return props;
+ }
+
+ Func m_impl;
+ };
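+
+ // Illustrative usage sketch (not part of the original header): constructing a TensorDesc (defined below) with
+ // the interleaved-channel policy yields NHWC-style strides for an NCHW-shaped tensor.
+ //
+ //   dml::TensorDesc desc(DML_TENSOR_DATA_TYPE_FLOAT32, { 1, 3, 4, 4 }, dml::TensorPolicy::InterleavedChannel());
+ //   // desc.strides == { 48, 1, 12, 3 }; desc.totalTensorSizeInBytes == 192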
+
+ struct TensorDesc
+ {
+ public:
+ using Dimensions = TensorDimensions;
+ using Strides = TensorStrides;
+
+ DML_TENSOR_DATA_TYPE dataType = DML_TENSOR_DATA_TYPE_UNKNOWN;
+ DML_TENSOR_FLAGS flags = DML_TENSOR_FLAG_NONE;
+ Dimensions sizes;
+ Optional<Strides> strides;
+ uint64_t totalTensorSizeInBytes = 0;
+ uint32_t guaranteedBaseOffsetAlignment = 0;
+
+ TensorDesc() = default;
+
+ TensorDesc(DML_TENSOR_DATA_TYPE dataType, Dimensions sizes, const TensorPolicy& policy = {})
+ : TensorDesc(dataType, DML_TENSOR_FLAG_NONE, sizes, policy)
+ {}
+
+ TensorDesc(DML_TENSOR_DATA_TYPE dataType, DML_TENSOR_FLAGS flags, Dimensions sizes, const TensorPolicy& policy = {})
+ {
+ TensorProperties props = policy.Get(dataType, flags, sizes);
+ Initialize(
+ dataType,
+ flags,
+ std::move(sizes),
+ std::move(props.strides),
+ props.totalTensorSizeInBytes,
+ props.guaranteedBaseOffsetAlignment);
+ }
+
+ TensorDesc(
+ DML_TENSOR_DATA_TYPE dataType,
+ DML_TENSOR_FLAGS flags,
+ Dimensions sizes,
+ Optional<Strides> strides,
+ uint64_t totalTensorSizeInBytes,
+ uint32_t guaranteedBaseOffsetAlignment)
+ {
+ Initialize(dataType, flags, std::move(sizes), std::move(strides), totalTensorSizeInBytes, guaranteedBaseOffsetAlignment);
+ }
+
+ /* implicit */ TensorDesc(const DML_TENSOR_DESC& desc)
+ : TensorDesc(*static_cast<const DML_BUFFER_TENSOR_DESC*>(desc.Desc))
+ {
+ assert(desc.Type == DML_TENSOR_TYPE_BUFFER);
+ assert(desc.Desc != nullptr);
+ }
+
+ /* implicit */ TensorDesc(const DML_BUFFER_TENSOR_DESC& desc)
+ {
+ this->dataType = desc.DataType;
+ this->flags = desc.Flags;
+ this->sizes.assign(desc.Sizes, desc.Sizes + desc.DimensionCount);
+ if (desc.Strides)
+ {
+ this->strides.emplace();
+ this->strides->assign(desc.Strides, desc.Strides + desc.DimensionCount);
+ }
+ this->totalTensorSizeInBytes = desc.TotalTensorSizeInBytes;
+ this->guaranteedBaseOffsetAlignment = desc.GuaranteedBaseOffsetAlignment;
+ }
+
+ // Returns an equivalent DML_TENSOR_DESC or DML_BUFFER_TENSOR_DESC. The returned object contains pointers
+ // into the TensorDesc, so it is only valid as long as the TensorDesc itself is alive.
+ template <typename T>
+ T* AsPtr()
+ {
+ // "sizeof(T) == -1" is always false; this is just to make the static_assert dependent on the template
+ // parameter and therefore not evaluated until template instantiation
+ static_assert(sizeof(T) == -1, "Invalid type");
+ }
+
+ template <>
+ DML_BUFFER_TENSOR_DESC* AsPtr<DML_BUFFER_TENSOR_DESC>()
+ {
+ assert(!strides || sizes.size() == strides->size());
+
+ m_bufferDesc.DataType = this->dataType;
+ m_bufferDesc.Flags = this->flags;
+ m_bufferDesc.DimensionCount = static_cast<UINT>(sizes.size());
+ m_bufferDesc.Sizes = this->sizes.data();
+ m_bufferDesc.Strides = this->strides ? this->strides->data() : nullptr;
+ m_bufferDesc.TotalTensorSizeInBytes = this->totalTensorSizeInBytes;
+ m_bufferDesc.GuaranteedBaseOffsetAlignment = this->guaranteedBaseOffsetAlignment;
+ return &m_bufferDesc;
+ }
+
+ template <>
+ DML_TENSOR_DESC* AsPtr<DML_TENSOR_DESC>()
+ {
+ m_tensorDesc = DML_TENSOR_DESC{ DML_TENSOR_TYPE_BUFFER, AsPtr<DML_BUFFER_TENSOR_DESC>() };
+ return &m_tensorDesc;
+ }
+
+ private:
+ DML_BUFFER_TENSOR_DESC m_bufferDesc;
+ DML_TENSOR_DESC m_tensorDesc;
+
+ void Initialize(
+ DML_TENSOR_DATA_TYPE tensorDataType,
+ DML_TENSOR_FLAGS tensorFlags,
+ Dimensions tensorSizes,
+ Optional<Strides> tensorStrides,
+ uint64_t totalTensorSizeInBytesVal,
+ uint32_t guaranteedBaseOffsetAlignmentVal)
+ {
+ assert(!tensorStrides || tensorStrides->size() == static_cast<size_t>(tensorSizes.size()));
+
+ this->dataType = tensorDataType;
+ this->flags = tensorFlags;
+ this->sizes = std::move(tensorSizes);
+ this->strides = std::move(tensorStrides);
+ this->totalTensorSizeInBytes = totalTensorSizeInBytesVal;
+ this->guaranteedBaseOffsetAlignment = guaranteedBaseOffsetAlignmentVal;
+ }
+ };
+
+ namespace detail
+ {
+ class GraphBuilder;
+ class NodeOutput;
+
+ // A node in the graph which represents a graph input.
+ struct InputNode
+ {
+ uint32_t inputIndex;
+ };
+
+ // A node in the graph which represents a DML operator.
+ struct OperatorNode
+ {
+ Microsoft::WRL::ComPtr<IDMLOperator> op;
+
+ // The inputs to this node
+ std::vector<NodeOutput*> inputs;
+
+ std::string name;
+ };
+
+ // Used for representing reshapes and type punning
+ struct ReinterpretNode
+ {
+ NodeOutput* input;
+ };
+
+ enum class NodeType
+ {
+ Invalid,
+ Input,
+ Operator,
+ Reinterpret,
+ };
+
+ // Identifies a node in the graph.
+ struct NodeID
+ {
+ NodeType type;
+ uint32_t index; // The index of this node in the GraphBuilder
+ };
+
+ // Represents one of the outputs of a node.
+ class NodeOutput
+ {
+ public:
+ NodeOutput(GraphBuilder* owner, NodeID node, uint32_t outputIndex, TensorDesc tensorDesc)
+ : m_owner(owner)
+ , m_node(node)
+ , m_outputIndex(outputIndex)
+ , m_tensorDesc(std::move(tensorDesc))
+ {}
+
+ // Retrieves the GraphBuilder that owns this object.
+ GraphBuilder* GetGraphBuilder() const { return m_owner; }
+
+ NodeID GetNode() const { return m_node; }
+ uint32_t GetOutputIndex() const { return m_outputIndex; }
+ const TensorDesc& GetOutputDesc() const { return m_tensorDesc; }
+
+ private:
+ GraphBuilder* m_owner;
+ NodeID m_node;
+
+ // An operator can have multiple outputs; this index identifies which one of the operator's outputs this
+ // NodeOutput represents.
+ uint32_t m_outputIndex;
+
+ TensorDesc m_tensorDesc;
+ };
+
+ struct GraphDesc
+ {
+ uint32_t inputCount;
+ uint32_t outputCount;
+ std::vector<DML_OPERATOR_GRAPH_NODE_DESC> nodes;
+ std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
+ std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
+ std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
+ };
+
+ class GraphBuilder
+ {
+ public:
+ GraphBuilder(IDMLDevice* device, TensorPolicy tensorPolicy = {})
+ : m_device(device)
+ , m_tensorPolicy(tensorPolicy)
+ {}
+
+ IDMLDevice* GetDevice() const
+ {
+ return m_device.Get();
+ }
+
+ void PushName(StringView name)
+ {
+ m_nameSubLengths.push(m_name.size());
+ if (!m_name.empty())
+ {
+ m_name += "_";
+ }
+ m_name += name;
+ }
+
+ void PopName()
+ {
+ if (!m_nameSubLengths.empty())
+ {
+ m_name.resize(m_nameSubLengths.top());
+ m_nameSubLengths.pop();
+ }
+ }
+
+ void SetTensorPolicy(TensorPolicy policy) { m_tensorPolicy = std::move(policy); }
+ const TensorPolicy& GetTensorPolicy() const { return m_tensorPolicy; }
+ TensorPolicy& GetTensorPolicy() { return m_tensorPolicy; }
+
+ // Creates a DML operator node owned by this graph builder and returns a NodeInfo identifier. The
+ // inputs to this node must be supplied in the correct order matching the DML operator.
+ NodeID CreateOperatorNode(DML_OPERATOR_TYPE type, const void* desc, Span<NodeOutput* const> inputs);
+ NodeID CreateInputNode(uint32_t inputIndex);
+ NodeID CreateReinterpretNode(NodeOutput* input);
+ NodeOutput* CreateNodeOutput(NodeID node, uint32_t outputIndex, TensorDesc tensorDesc);
+ GraphDesc GetGraphDesc(Span<const Expression> outputs) const;
+
+ private:
+ Microsoft::WRL::ComPtr<IDMLDevice> m_device;
+ TensorPolicy m_tensorPolicy;
+ std::vector<InputNode> m_inputNodes;
+ std::vector<OperatorNode> m_operatorNodes;
+ std::vector<ReinterpretNode> m_reinterpretNodes;
+ std::deque<NodeOutput> m_nodeOutputs; // deque doesn't invalidate references to elements when it resizes
+
+ std::string m_name;
+ std::stack<size_t> m_nameSubLengths;
+ };
+
+ } // namespace detail
+
+ class Expression
+ {
+ public:
+ /*implicit*/ Expression(detail::NodeOutput* nodeOutput = nullptr)
+ : m_nodeOutput(nodeOutput)
+ {}
+
+ // Returns a struct containing the required properties of the tensor to hold the output of this expression,
+ // once evaluated.
+ const TensorDesc& GetOutputDesc() const { return Impl()->GetOutputDesc(); }
+
+ // For internal use only
+ detail::NodeOutput* Impl() const { return m_nodeOutput; }
+
+ explicit operator bool() const
+ {
+ return m_nodeOutput != nullptr;
+ }
+
+ private:
+ detail::NodeOutput* m_nodeOutput; // weak; this is owned by the GraphBuilder
+ };
+
+ class NameScope
+ {
+ public:
+ detail::GraphBuilder* m_builder = nullptr;
+
+ NameScope(detail::GraphBuilder* builder, StringView name) : m_builder(builder)
+ {
+ if (m_builder) m_builder->PushName(name);
+ }
+
+ ~NameScope()
+ {
+ if (m_builder) m_builder->PopName();
+ }
+ };
+
+ class Graph
+ {
+ public:
+ explicit Graph(IDMLDevice* device, TensorPolicy tensorPolicy = {})
+ : m_graphBuilder(make_unique<detail::GraphBuilder>(device, tensorPolicy))
+ {}
+
+ // For internal use only
+ detail::GraphBuilder* Impl() { return m_graphBuilder.get(); }
+
+ // Sets/gets the tensor policy. If not set, defaults to TensorPolicy::Default(). Tensor policies can be used
+ // to control properties (such as strides) on output tensors produced by this Graph.
+ void SetTensorPolicy(TensorPolicy policy) { m_graphBuilder->SetTensorPolicy(std::move(policy)); }
+ const TensorPolicy& GetTensorPolicy() const { return m_graphBuilder->GetTensorPolicy(); }
+ TensorPolicy& GetTensorPolicy() { return m_graphBuilder->GetTensorPolicy(); }
+
+ NameScope CreateNameScope(StringView name) { return NameScope(m_graphBuilder.get(), name); }
+
+ void PushName(StringView name) { m_graphBuilder->PushName(name); }
+ void PopName() { m_graphBuilder->PopName(); }
+
+ Microsoft::WRL::ComPtr<IDMLCompiledOperator> Compile(
+ DML_EXECUTION_FLAGS flags,
+ Span<const Expression> outputs,
+ uint32_t inputCount = 0) const
+ {
+ detail::GraphDesc graph = m_graphBuilder->GetGraphDesc(outputs);
+
+ // If supplied, the requested number of inputs to the compiled operator can be larger than the actual
+ // number of input nodes on the graph (e.g. in the case of unused empty inputs), but never smaller.
+ assert(inputCount == 0 || inputCount >= graph.inputCount);
+
+ std::vector<DML_GRAPH_NODE_DESC> graphNodes(graph.nodes.size());
+ for (size_t i = 0; i < graphNodes.size(); ++i)
+ {
+ graphNodes[i] = { DML_GRAPH_NODE_TYPE_OPERATOR, &graph.nodes[i] };
+ }
+
+ std::vector<DML_GRAPH_EDGE_DESC> inputEdges(graph.inputEdges.size());
+ for (size_t i = 0; i < inputEdges.size(); ++i)
+ {
+ inputEdges[i] = { DML_GRAPH_EDGE_TYPE_INPUT, &graph.inputEdges[i] };
+ }
+
+ std::vector<DML_GRAPH_EDGE_DESC> outputEdges(graph.outputEdges.size());
+ for (size_t i = 0; i < outputEdges.size(); ++i)
+ {
+ outputEdges[i] = { DML_GRAPH_EDGE_TYPE_OUTPUT, &graph.outputEdges[i] };
+ }
+
+ std::vector<DML_GRAPH_EDGE_DESC> intermediateEdges(graph.intermediateEdges.size());
+ for (size_t i = 0; i < intermediateEdges.size(); ++i)
+ {
+ intermediateEdges[i] = { DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graph.intermediateEdges[i] };
+ }
+
+ DML_GRAPH_DESC graphDesc = {};
+ graphDesc.InputCount = inputCount ? inputCount : graph.inputCount;
+ graphDesc.OutputCount = graph.outputCount;
+ graphDesc.NodeCount = static_cast<UINT>(graphNodes.size());
+ graphDesc.Nodes = graphNodes.data();
+ graphDesc.InputEdgeCount = static_cast<UINT>(inputEdges.size());
+ graphDesc.InputEdges = inputEdges.data();
+ graphDesc.OutputEdgeCount = static_cast<UINT>(outputEdges.size());
+ graphDesc.OutputEdges = outputEdges.data();
+ graphDesc.IntermediateEdgeCount = static_cast<UINT>(intermediateEdges.size());
+ graphDesc.IntermediateEdges = intermediateEdges.data();
+
+ Microsoft::WRL::ComPtr<IDMLDevice1> device1;
+ DMLX_THROW_IF_FAILED(m_graphBuilder->GetDevice()->QueryInterface(IID_PPV_ARGS(&device1)));
+
+ Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledGraph;
+ DMLX_THROW_IF_FAILED(device1->CompileGraph(&graphDesc, flags, IID_PPV_ARGS(&compiledGraph)));
+
+ return compiledGraph;
+ }
+
+ private:
+ std::unique_ptr<detail::GraphBuilder> m_graphBuilder;
+ };
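+
+ // Illustrative usage sketch (not part of the original header); `device` is assumed to be an existing IDMLDevice*:
+ //
+ //   dml::Graph graph(device);
+ //   dml::Expression input = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, { 1, 1, 2, 2 }));
+ //   dml::Expression output = dml::Abs(input);
+ //   Microsoft::WRL::ComPtr<IDMLCompiledOperator> op = graph.Compile(DML_EXECUTION_FLAG_NONE, { output });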
+
+ // Represents an activation to be fused with an existing operator. The meaning of param1 and param2 depend on the
+ // activation to be fused.
+ //
+ // For HARD_SIGMOID, LINEAR, PARAMETRIC_SOFTPLUS, and SCALED_TANH: param1 = Alpha and param2 = Beta
+ // For ELU, LEAKY_RELU, THRESHOLDED_RELU, and CELU: param1 = Alpha. param2 is unused.
+ // For SCALED_ELU, param1 = Alpha and param2 = Gamma.
+ // For SHRINK, param1 = Bias and param2 = Threshold
+ // For SOFTPLUS, param1 = Steepness.
+ // For all other activations, both param1 and param2 are unused.
+ struct FusedActivation
+ {
+ DML_OPERATOR_TYPE activation = DML_OPERATOR_INVALID;
+ float param1 = 0.0f;
+ float param2 = 0.0f;
+
+ FusedActivation() = default;
+
+ explicit FusedActivation(DML_OPERATOR_TYPE activation, float param1 = 0.0f, float param2 = 0.0f)
+ : activation(activation), param1(param1), param2(param2)
+ {}
+
+ static FusedActivation None()
+ {
+ return FusedActivation();
+ }
+
+ static FusedActivation Elu(float alpha = 1.0f)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_ELU, alpha);
+ }
+
+ static FusedActivation HardSigmoid(float alpha = 0.2f, float beta = 0.5f)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_HARD_SIGMOID, alpha, beta);
+ }
+
+ static FusedActivation Identity()
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_IDENTITY);
+ }
+
+ static FusedActivation LeakyRelu(float alpha = 0.01f)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_LEAKY_RELU, alpha);
+ }
+
+ static FusedActivation Linear(float alpha, float beta)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_LINEAR, alpha, beta);
+ }
+
+ static FusedActivation ParametricSoftplus(float alpha, float beta)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS, alpha, beta);
+ }
+
+ static FusedActivation Relu()
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_RELU);
+ }
+
+ static FusedActivation ScaledElu(float alpha = 1.67326319217681884765625f, float gamma = 1.05070102214813232421875f)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_SCALED_ELU, alpha, gamma);
+ }
+
+ static FusedActivation ScaledTanh(float alpha = 1.0f, float beta = 0.5f)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_SCALED_TANH, alpha, beta);
+ }
+
+ static FusedActivation Sigmoid()
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_SIGMOID);
+ }
+
+ static FusedActivation Softplus(float steepness = 1.0f)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_SOFTPLUS, steepness);
+ }
+
+ static FusedActivation Softsign()
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_SOFTSIGN);
+ }
+
+ static FusedActivation Tanh()
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_TANH);
+ }
+
+ static FusedActivation ThresholdedRelu(float alpha = 1.0f)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU, alpha);
+ }
+
+ static FusedActivation Shrink(float bias = 0.0f, float threshold = 0.5f)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_SHRINK, bias, threshold);
+ }
+
+ static FusedActivation Celu(float alpha = 1.0f)
+ {
+ return FusedActivation(DML_OPERATOR_ACTIVATION_CELU, alpha);
+ }
+ };
+
+ // Implementation detail helper for determining if a list of expressions share the same GraphBuilder.
+ namespace detail
+ {
+ inline bool HasSameOwner(Span<const Expression> exprs)
+ {
+ if (exprs.size() == 0)
+ {
+ return true;
+ }
+
+ detail::GraphBuilder* owner = exprs.begin()->Impl()->GetGraphBuilder();
+ for (Expression expr : exprs)
+ {
+ if (expr.Impl()->GetGraphBuilder() != owner)
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ inline bool HasSameOwner(std::initializer_list exprs)
+ {
+ Span<const Expression> span(exprs.begin(), exprs.size());
+ return HasSameOwner(span);
+ }
+
+ inline bool HasSameDataType(Span<const Expression> exprs)
+ {
+ if (exprs.size() == 0)
+ {
+ return true;
+ }
+
+ DML_TENSOR_DATA_TYPE dataType = exprs.begin()->Impl()->GetOutputDesc().dataType;
+ for (Expression expr : exprs)
+ {
+ if (expr.Impl()->GetOutputDesc().dataType != dataType)
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ inline bool HasSameDataType(std::initializer_list exprs)
+ {
+ Span<const Expression> span(exprs.begin(), exprs.size());
+ return HasSameDataType(span);
+ }
+ } // namespace detail
+
+ // Expression implementation helpers
+ namespace detail
+ {
+ template <DML_OPERATOR_TYPE OperatorType, typename TDesc>
+ Expression ElementWiseUnary(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias)
+ {
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input
+
+ TDesc desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ScaleBias = scaleBias ? &scaleBias.value() : nullptr;
+
+ detail::NodeOutput* const inputs[] = { input.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(OperatorType, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ template <DML_OPERATOR_TYPE OperatorType, typename TDesc>
+ Expression ElementWiseUnary(Expression input, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UNKNOWN)
+ {
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+
+ if (outputDataType == DML_TENSOR_DATA_TYPE_UNKNOWN)
+ {
+ outputDataType = inputTensor.dataType;
+ }
+ TensorDesc outputTensor(outputDataType, inputTensor.sizes, builder->GetTensorPolicy());
+
+ TDesc desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+
+ detail::NodeOutput* const inputs[] = { input.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(OperatorType, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ template <DML_OPERATOR_TYPE OperatorType, typename TDesc>
+ Expression ElementWiseBinary(Expression a, Expression b)
+ {
+ assert(detail::HasSameOwner({ a, b }));
+ detail::GraphBuilder* builder = a.Impl()->GetGraphBuilder();
+
+ TensorDesc aTensor = a.Impl()->GetOutputDesc();
+ TensorDesc bTensor = b.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(aTensor.dataType, aTensor.sizes, builder->GetTensorPolicy()); // Same as input
+
+ TDesc desc = {};
+ desc.ATensor = aTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.BTensor = bTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+
+ detail::NodeOutput* const inputs[] = { a.Impl(), b.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(OperatorType, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ template <DML_OPERATOR_TYPE OperatorType, typename TDesc>
+ Expression ElementWiseComparison(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8)
+ {
+ assert(detail::HasSameOwner({ a, b }));
+ detail::GraphBuilder* builder = a.Impl()->GetGraphBuilder();
+
+ TensorDesc aTensor = a.Impl()->GetOutputDesc();
+ TensorDesc bTensor = b.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(outputDataType, aTensor.sizes, builder->GetTensorPolicy());
+
+ TDesc desc = {};
+ desc.ATensor = aTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.BTensor = bTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+
+ detail::NodeOutput* const inputs[] = { a.Impl(), b.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(OperatorType, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ // Used to reserve some space on the stack for setting up fused activation operator descs.
+ struct FusedActivationStorage
+ {
+ DML_OPERATOR_DESC opDesc;
+
+ // All fuseable activation descs have a common layout: two tensor desc pointers and up to 2 optional
+ // float parameters, so just use LINEAR as an archetype
+ DML_ACTIVATION_LINEAR_OPERATOR_DESC activationDesc;
+ };
+
+ // Returns the correct value for filling out fused activation fields in the DML API, e.g.
+ // DML_CONVOLUTION_OPERATOR_DESC::FusedActivation. The descs themselves are stored in the `storage` outptr.
+ inline const DML_OPERATOR_DESC* GetFusedActivationPtr(
+ FusedActivation fusedActivation,
+ _Out_ FusedActivationStorage* storage)
+ {
+ if (fusedActivation.activation == DML_OPERATOR_INVALID)
+ {
+ // No fused activation
+ return nullptr;
+ }
+
+ storage->activationDesc.InputTensor = nullptr;
+ storage->activationDesc.OutputTensor = nullptr;
+ storage->activationDesc.Alpha = fusedActivation.param1;
+ storage->activationDesc.Beta = fusedActivation.param2;
+
+ storage->opDesc.Type = fusedActivation.activation;
+ storage->opDesc.Desc = &storage->activationDesc;
+
+ return &storage->opDesc;
+ }
+
+ } // namespace detail
+
+ inline Expression InputTensor(Graph& graph, uint32_t inputIndex, TensorDesc desc)
+ {
+ detail::GraphBuilder* builder = graph.Impl();
+
+ detail::NodeID node = builder->CreateInputNode(inputIndex);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(desc));
+ return output;
+ }
+
+ inline Expression Identity(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_IDENTITY, DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Abs(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_ABS, DML_ELEMENT_WISE_ABS_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression ACos(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_ACOS, DML_ELEMENT_WISE_ACOS_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Add(Expression a, Expression b)
+ {
+ assert(detail::HasSameOwner({ a, b }));
+ detail::GraphBuilder* builder = a.Impl()->GetGraphBuilder();
+
+ TensorDesc aTensor = a.Impl()->GetOutputDesc();
+ TensorDesc bTensor = b.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(aTensor.dataType, aTensor.sizes, builder->GetTensorPolicy()); // Same as input
+
+ DML_ELEMENT_WISE_ADD_OPERATOR_DESC desc = {};
+ desc.ATensor = aTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.BTensor = bTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+
+ detail::NodeOutput* const inputs[] = { a.Impl(), b.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_ADD, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression Add(Expression a, Expression b, FusedActivation fusedActivation)
+ {
+ assert(detail::HasSameOwner({ a, b }));
+ detail::GraphBuilder* builder = a.Impl()->GetGraphBuilder();
+
+ TensorDesc aTensor = a.Impl()->GetOutputDesc();
+ TensorDesc bTensor = b.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(aTensor.dataType, aTensor.sizes, builder->GetTensorPolicy()); // Same as input
+ detail::FusedActivationStorage storage;
+
+ DML_ELEMENT_WISE_ADD1_OPERATOR_DESC desc = {};
+ desc.ATensor = aTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.BTensor = bTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.FusedActivation = detail::GetFusedActivationPtr(fusedActivation, &storage);
+
+ detail::NodeOutput* const inputs[] = { a.Impl(), b.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_ADD1, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
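+
+ // Illustrative example (not part of the original header): fusing a ReLU into the addition produces a single
+ // DML_OPERATOR_ELEMENT_WISE_ADD1 node, equivalent to applying ActivationRelu to the sum.
+ //
+ //   dml::Expression sum = dml::Add(a, b, dml::FusedActivation::Relu());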
+
+ inline Expression ASin(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_ASIN, DML_ELEMENT_WISE_ASIN_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression ATan(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_ATAN, DML_ELEMENT_WISE_ATAN_OPERATOR_DESC>(input, scaleBias);
+ }
+
+#if DML_TARGET_VERSION >= 0x3100
+
+ inline Expression ATanYX(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_ATAN_YX, DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC>(a, b);
+ }
+
+#endif // DML_TARGET_VERSION >= 0x3100
+
+ inline Expression Ceil(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_CEIL, DML_ELEMENT_WISE_CEIL_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Clip(Expression input, float min, float max, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input
+
+ DML_ELEMENT_WISE_CLIP_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ScaleBias = scaleBias ? &scaleBias.value() : nullptr;
+ desc.Min = min;
+ desc.Max = max;
+
+ detail::NodeOutput* const inputs[] = { input.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_CLIP, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+#if DML_TARGET_VERSION >= 0x3100
+
+ inline Expression ClipGrad(Expression input, Expression inputGradient, float min, float max)
+ {
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc inputGradientTensor = inputGradient.Impl()->GetOutputDesc();
+ TensorDesc outputGradientTensor(inputGradientTensor.dataType, inputGradientTensor.sizes, builder->GetTensorPolicy());
+
+ DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.InputGradientTensor = inputGradientTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputGradientTensor = outputGradientTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.Min = min;
+ desc.Max = max;
+
+ detail::NodeOutput* const inputs[] = { input.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputGradientTensor));
+
+ return output;
+ }
+
+#endif // DML_TARGET_VERSION >= 0x3100
+
+ inline Expression Cos(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_COS, DML_ELEMENT_WISE_COS_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Divide(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_DIVIDE, DML_ELEMENT_WISE_DIVIDE_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression Exp(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_EXP, DML_ELEMENT_WISE_EXP_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Floor(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_FLOOR, DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Log(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_LOG, DML_ELEMENT_WISE_LOG_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression LogicalAnd(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND, DML_ELEMENT_WISE_LOGICAL_AND_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression Equals(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8)
+ {
+ return detail::ElementWiseComparison<
+ DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS,
+ DML_ELEMENT_WISE_LOGICAL_EQUALS_OPERATOR_DESC>(a, b, outputDataType);
+ }
+
+ inline Expression GreaterThan(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8)
+ {
+ return detail::ElementWiseComparison<
+ DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN,
+ DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OPERATOR_DESC>(a, b, outputDataType);
+ }
+
+ inline Expression GreaterThanOrEqual(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8)
+ {
+ return detail::ElementWiseComparison<
+ DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL,
+ DML_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL_OPERATOR_DESC>(a, b, outputDataType);
+ }
+
+ inline Expression LessThan(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8)
+ {
+ return detail::ElementWiseComparison<
+ DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN,
+ DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OPERATOR_DESC>(a, b, outputDataType);
+ }
+
+ inline Expression LessThanOrEqual(Expression a, Expression b, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8)
+ {
+ return detail::ElementWiseComparison<
+ DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL,
+ DML_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL_OPERATOR_DESC>(a, b, outputDataType);
+ }
+
+ inline Expression LogicalNot(Expression input)
+ {
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input
+
+ DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+
+ detail::NodeOutput* const inputs[] = { input.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression LogicalOr(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR, DML_ELEMENT_WISE_LOGICAL_OR_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression LogicalXor(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR, DML_ELEMENT_WISE_LOGICAL_XOR_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression Max(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_MAX, DML_ELEMENT_WISE_MAX_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression Mean(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_MEAN, DML_ELEMENT_WISE_MEAN_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression Min(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_MIN, DML_ELEMENT_WISE_MIN_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression Multiply(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_MULTIPLY, DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression Pow(Expression input, Expression exponent, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ assert(detail::HasSameOwner({ input, exponent }));
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc exponentTensor = exponent.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input
+
+ DML_ELEMENT_WISE_POW_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ExponentTensor = exponentTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ScaleBias = scaleBias ? &scaleBias.value() : nullptr;
+
+ detail::NodeOutput* const inputs[] = { input.Impl(), exponent.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_POW, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression Pow(Expression input, float exponent, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input
+
+ DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ScaleBias = scaleBias ? &scaleBias.value() : nullptr;
+ desc.Exponent = exponent;
+
+ detail::NodeOutput* const inputs[] = { input.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression Recip(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_RECIP, DML_ELEMENT_WISE_RECIP_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Sin(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_SIN, DML_ELEMENT_WISE_SIN_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Sqrt(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_SQRT, DML_ELEMENT_WISE_SQRT_OPERATOR_DESC>(input, scaleBias);
+ }
+
+#if DML_TARGET_VERSION >= 0x3100
+
+ inline Expression DifferenceSquare(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE, DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC>(a, b);
+ }
+
+#endif // DML_TARGET_VERSION >= 0x3100
+
+ inline Expression Subtract(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_SUBTRACT, DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression Tan(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_TAN, DML_ELEMENT_WISE_TAN_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Threshold(Expression input, float min, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input
+
+ DML_ELEMENT_WISE_THRESHOLD_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ScaleBias = scaleBias ? &scaleBias.value() : nullptr;
+ desc.Min = min;
+
+ detail::NodeOutput* const inputs[] = { input.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_THRESHOLD, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression QuantizeLinear(Expression input, Expression scale, Expression zeroPoint, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8)
+ {
+ assert(detail::HasSameOwner({ input, scale, zeroPoint }));
+
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc scaleTensor = scale.Impl()->GetOutputDesc();
+ TensorDesc zeroPointTensor = zeroPoint.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(outputDataType, inputTensor.sizes, builder->GetTensorPolicy());
+
+ DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ScaleTensor = scaleTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ZeroPointTensor = zeroPointTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+
+ detail::NodeOutput* const inputs[] = { input.Impl(), scale.Impl(), zeroPoint.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression DequantizeLinear(Expression input, Expression scale, Expression zeroPoint)
+ {
+ assert(detail::HasSameOwner({ input, scale, zeroPoint }));
+
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc scaleTensor = scale.Impl()->GetOutputDesc();
+ TensorDesc zeroPointTensor = zeroPoint.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(scaleTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy());
+
+ DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ScaleTensor = scaleTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ZeroPointTensor = zeroPointTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+
+ detail::NodeOutput* const inputs[] = { input.Impl(), scale.Impl(), zeroPoint.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression Sign(Expression a)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_SIGN, DML_ELEMENT_WISE_SIGN_OPERATOR_DESC>(a);
+ }
+
+ inline Expression IsNaN(Expression input, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_IS_NAN, DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC>(input, outputDataType);
+ }
+
+ inline Expression Erf(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_ERF, DML_ELEMENT_WISE_ERF_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Sinh(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_SINH, DML_ELEMENT_WISE_SINH_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Cosh(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_COSH, DML_ELEMENT_WISE_COSH_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression Tanh(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_TANH, DML_ELEMENT_WISE_TANH_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression ASinh(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_ASINH, DML_ELEMENT_WISE_ASINH_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression ACosh(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_ACOSH, DML_ELEMENT_WISE_ACOSH_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression ATanh(Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_ATANH, DML_ELEMENT_WISE_ATANH_OPERATOR_DESC>(input, scaleBias);
+ }
+
+ inline Expression If(Expression condition, Expression a, Expression b)
+ {
+ assert(detail::HasSameOwner({ condition, a, b }));
+ assert(detail::HasSameDataType({ a, b }));
+
+ detail::GraphBuilder* builder = condition.Impl()->GetGraphBuilder();
+
+ TensorDesc conditionTensor = condition.Impl()->GetOutputDesc();
+ assert(conditionTensor.dataType == DML_TENSOR_DATA_TYPE_UINT8);
+
+ TensorDesc aTensor = a.Impl()->GetOutputDesc();
+ TensorDesc bTensor = b.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(aTensor.dataType, aTensor.sizes, builder->GetTensorPolicy());
+
+ DML_ELEMENT_WISE_IF_OPERATOR_DESC desc = {};
+ desc.ConditionTensor = conditionTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.ATensor = aTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.BTensor = bTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+
+ detail::NodeOutput* const inputs[] = { condition.Impl(), a.Impl(), b.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_IF, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression BitShiftLeft(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT, DML_ELEMENT_WISE_BIT_SHIFT_LEFT_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression BitShiftRight(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT, DML_ELEMENT_WISE_BIT_SHIFT_RIGHT_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression BitAnd(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_BIT_AND, DML_ELEMENT_WISE_BIT_AND_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression BitOr(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_BIT_OR, DML_ELEMENT_WISE_BIT_OR_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression BitXor(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_BIT_XOR, DML_ELEMENT_WISE_BIT_XOR_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression BitNot(Expression a)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_BIT_NOT, DML_ELEMENT_WISE_BIT_NOT_OPERATOR_DESC>(a);
+ }
+
+ inline Expression BitCount(Expression a, DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8)
+ {
+ return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_BIT_COUNT, DML_ELEMENT_WISE_BIT_COUNT_OPERATOR_DESC>(a, outputDataType);
+ }
+
+ inline Expression Round(Expression input, DML_ROUNDING_MODE roundingMode = DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN)
+ {
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); // Same as input
+
+ DML_ELEMENT_WISE_ROUND_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.RoundingMode = roundingMode;
+
+ detail::NodeOutput* const inputs[] = { input.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_ROUND, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression IsInfinity(
+ Expression input,
+ DML_IS_INFINITY_MODE infinityMode = DML_IS_INFINITY_MODE_EITHER,
+ DML_TENSOR_DATA_TYPE outputDataType = DML_TENSOR_DATA_TYPE_UINT8)
+ {
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(outputDataType, inputTensor.sizes, builder->GetTensorPolicy());
+
+ DML_ELEMENT_WISE_IS_INFINITY_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>();
+ desc.InfinityMode = infinityMode;
+
+ detail::NodeOutput* const inputs[] = { input.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ELEMENT_WISE_IS_INFINITY, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression ModulusTruncate(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE, DML_ELEMENT_WISE_MODULUS_TRUNCATE_OPERATOR_DESC>(a, b);
+ }
+
+ inline Expression ModulusFloor(Expression a, Expression b)
+ {
+ return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR, DML_ELEMENT_WISE_MODULUS_FLOOR_OPERATOR_DESC>(a, b);
+ }
+
+#pragma region detail
+#define DMLX_ACTIVATION_IMPL(_name) \
+ do { \
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); \
+ \
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc(); \
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); \
+ \
+ DML_##_name##_OPERATOR_DESC desc = {}; \
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>(); \
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>(); \
+ \
+ detail::NodeOutput* const inputs[] = { input.Impl() }; \
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_##_name, &desc, inputs); \
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); \
+ \
+ return output; \
+ } while(0)
+
+#define DMLX_ACTIVATION_IMPL_1(_name, _param1Name, _param1) \
+ do { \
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); \
+ \
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc(); \
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); \
+ \
+ DML_##_name##_OPERATOR_DESC desc = {}; \
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>(); \
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>(); \
+ desc._param1Name = _param1; \
+ \
+ detail::NodeOutput* const inputs[] = { input.Impl() }; \
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_##_name, &desc, inputs); \
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); \
+ \
+ return output; \
+ } while(0)
+
+#define DMLX_ACTIVATION_IMPL_2(_name, _param1Name, _param1, _param2Name, _param2) \
+ do { \
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); \
+ \
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc(); \
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy()); \
+ \
+ DML_##_name##_OPERATOR_DESC desc = {}; \
+ desc.InputTensor = inputTensor.AsPtr<DML_TENSOR_DESC>(); \
+ desc.OutputTensor = outputTensor.AsPtr<DML_TENSOR_DESC>(); \
+ desc._param1Name = _param1; \
+ desc._param2Name = _param2; \
+ \
+ detail::NodeOutput* const inputs[] = { input.Impl() }; \
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_##_name, &desc, inputs); \
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor)); \
+ \
+ return output; \
+ } while(0)
+#pragma endregion
+
+ inline Expression ActivationElu(Expression input, float alpha = 1.0f)
+ {
+ DMLX_ACTIVATION_IMPL_1(ACTIVATION_ELU, Alpha, alpha);
+ }
+
+ inline Expression ActivationHardmax(Expression input)
+ {
+ DMLX_ACTIVATION_IMPL(ACTIVATION_HARDMAX);
+ }
+
+ inline Expression ActivationHardSigmoid(Expression input, float alpha = 0.2f, float beta = 0.5f)
+ {
+ DMLX_ACTIVATION_IMPL_2(ACTIVATION_HARD_SIGMOID, Alpha, alpha, Beta, beta);
+ }
+
+ inline Expression ActivationIdentity(Expression input)
+ {
+ DMLX_ACTIVATION_IMPL(ACTIVATION_IDENTITY);
+ }
+
+ inline Expression ActivationLeakyRelu(Expression input, float alpha = 0.01f)
+ {
+ DMLX_ACTIVATION_IMPL_1(ACTIVATION_LEAKY_RELU, Alpha, alpha);
+ }
+
+ inline Expression ActivationLinear(Expression input, float alpha, float beta)
+ {
+ DMLX_ACTIVATION_IMPL_2(ACTIVATION_LINEAR, Alpha, alpha, Beta, beta);
+ }
+
+ inline Expression ActivationLogSoftmax(Expression input)
+ {
+ DMLX_ACTIVATION_IMPL(ACTIVATION_LOG_SOFTMAX);
+ }
+
+ inline Expression ActivationParameterizedRelu(Expression input, Expression slope)
+ {
+ assert(detail::HasSameOwner({ input, slope }));
+
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc slopeTensor = slope.Impl()->GetOutputDesc();
+ TensorDesc outputTensor(inputTensor.dataType, inputTensor.sizes, builder->GetTensorPolicy());
+
+ DML_ACTIVATION_PARAMETERIZED_RELU_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr();
+ desc.SlopeTensor = slopeTensor.AsPtr();
+ desc.OutputTensor = outputTensor.AsPtr();
+
+ detail::NodeOutput* const inputs[] = { input.Impl(), slope.Impl() };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
+
+ inline Expression ActivationParametricSoftplus(Expression input, float alpha, float beta)
+ {
+ DMLX_ACTIVATION_IMPL_2(ACTIVATION_PARAMETRIC_SOFTPLUS, Alpha, alpha, Beta, beta);
+ }
+
+ inline Expression ActivationRelu(Expression input)
+ {
+ DMLX_ACTIVATION_IMPL(ACTIVATION_RELU);
+ }
+
+ inline Expression ActivationScaledElu(Expression input, float alpha = 1.67326319217681884765625f, float gamma = 1.05070102214813232421875f)
+ {
+ DMLX_ACTIVATION_IMPL_2(ACTIVATION_SCALED_ELU, Alpha, alpha, Gamma, gamma);
+ }
+
+ inline Expression ActivationScaledTanh(Expression input, float alpha = 1.0f, float beta = 0.5f)
+ {
+ DMLX_ACTIVATION_IMPL_2(ACTIVATION_SCALED_TANH, Alpha, alpha, Beta, beta);
+ }
+
+ inline Expression ActivationSigmoid(Expression input)
+ {
+ DMLX_ACTIVATION_IMPL(ACTIVATION_SIGMOID);
+ }
+
+ inline Expression ActivationSoftmax(Expression input)
+ {
+ DMLX_ACTIVATION_IMPL(ACTIVATION_SOFTMAX);
+ }
+
+ inline Expression ActivationSoftplus(Expression input, float steepness = 1.0f)
+ {
+ DMLX_ACTIVATION_IMPL_1(ACTIVATION_SOFTPLUS, Steepness, steepness);
+ }
+
+ inline Expression ActivationSoftsign(Expression input)
+ {
+ DMLX_ACTIVATION_IMPL(ACTIVATION_SOFTSIGN);
+ }
+
+ inline Expression ActivationTanh(Expression input)
+ {
+ DMLX_ACTIVATION_IMPL(ACTIVATION_TANH);
+ }
+
+ inline Expression ActivationThresholdedRelu(Expression input, float alpha = 1.0f)
+ {
+ DMLX_ACTIVATION_IMPL_1(ACTIVATION_THRESHOLDED_RELU, Alpha, alpha);
+ }
+
+ inline Expression ActivationShrink(Expression input, float bias = 0.0f, float threshold = 0.5f)
+ {
+ DMLX_ACTIVATION_IMPL_2(ACTIVATION_SHRINK, Bias, bias, Threshold, threshold);
+ }
+
+ inline Expression ActivationCelu(Expression input, float alpha = 1.0f)
+ {
+ DMLX_ACTIVATION_IMPL_1(ACTIVATION_CELU, Alpha, alpha);
+ }
+
+#undef DMLX_ACTIVATION_IMPL
+#undef DMLX_ACTIVATION_IMPL_1
+#undef DMLX_ACTIVATION_IMPL_2
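+
+ // Illustrative usage of the activation helpers above (hypothetical shapes and a hypothetical
+ // dmlDevice ComPtr; dml::Graph, dml::InputTensor and dml::TensorDesc are defined earlier in
+ // this header):
+ //
+ // dml::Graph graph(dmlDevice.Get());
+ // auto x = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, { 1, 3, 224, 224 }));
+ // auto y = dml::ActivationLeakyRelu(x, 0.01f); // builds a DML_ACTIVATION_LEAKY_RELU node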
+
+ // ---------------------------------------------------------------------------------------------------------------
+
+ // If not specified, parameters are defaulted to the following values:
+ // Mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION
+ // Direction = DML_CONVOLUTION_DIRECTION_FORWARD
+ // Strides = { 1, 1 } for 2D convolution, { 1, 1, 1 } for 3D convolution
+ // Dilations = { 1, 1 } for 2D convolution, { 1, 1, 1 } for 3D convolution
+ // StartPadding = { 0, 0 } for 2D convolution, { 0, 0, 0 } for 3D convolution
+ // EndPadding = { 0, 0 } for 2D convolution, { 0, 0, 0 } for 3D convolution
+ // OutputPadding = { 0, 0 } for 2D convolution, { 0, 0, 0 } for 3D convolution
+ // GroupCount = 1
+ // FusedActivation = nullptr
+ // OutputSizes = computed from other parameters
+ inline Expression Convolution(
+ Expression input,
+ Expression filter,
+ Optional<Expression> bias = NullOpt,
+ DML_CONVOLUTION_MODE mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION,
+ DML_CONVOLUTION_DIRECTION direction = DML_CONVOLUTION_DIRECTION_FORWARD,
+ Span<const uint32_t> strides = {},
+ Span<const uint32_t> dilations = {},
+ Span<const uint32_t> startPadding = {},
+ Span<const uint32_t> endPadding = {},
+ Span<const uint32_t> outputPadding = {},
+ uint32_t groupCount = 1,
+ FusedActivation fusedActivation = FusedActivation::None(),
+ TensorDimensions outputSizes = {})
+ {
+ assert(detail::HasSameOwner({ input, filter }));
+ assert(!bias || detail::HasSameOwner({ input, *bias }));
+
+ detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
+
+ TensorDesc inputTensor = input.Impl()->GetOutputDesc();
+ TensorDesc filterTensor = filter.Impl()->GetOutputDesc();
+ TensorDesc biasTensor;
+ if (bias)
+ {
+ biasTensor = bias->Impl()->GetOutputDesc();
+ }
+
+ uint32_t dimensionCount = static_cast<uint32_t>(inputTensor.sizes.size());
+
+ assert(dimensionCount == 4 || dimensionCount == 5);
+ uint32_t spatialDimensionCount = dimensionCount - 2;
+
+ // If the spatial dimension count is 2, we'll just use the first two elements by setting
+ // DimensionCount = 2 in the desc
+ const uint32_t defaultStridesAndDilations[3] = { 1, 1, 1 };
+ const uint32_t defaultPadding[3] = { 0, 0, 0 };
+
+ assert(strides.empty() || strides.size() == spatialDimensionCount);
+ assert(dilations.empty() || dilations.size() == spatialDimensionCount);
+ assert(startPadding.empty() || startPadding.size() == spatialDimensionCount);
+ assert(endPadding.empty() || endPadding.size() == spatialDimensionCount);
+ assert(outputPadding.empty() || outputPadding.size() == spatialDimensionCount);
+ assert(outputSizes.empty() || outputSizes.size() == inputTensor.sizes.size());
+
+ strides = strides.empty() ? Span<const uint32_t>{ defaultStridesAndDilations } : strides;
+ dilations = dilations.empty() ? Span<const uint32_t>{ defaultStridesAndDilations } : dilations;
+ startPadding = startPadding.empty() ? Span<const uint32_t>{ defaultPadding } : startPadding;
+ endPadding = endPadding.empty() ? Span<const uint32_t>{ defaultPadding } : endPadding;
+ outputPadding = outputPadding.empty() ? Span<const uint32_t>{ defaultPadding } : outputPadding;
+
+ // Compute the output shapes
+
+ if (outputSizes.empty())
+ {
+ if (direction == DML_CONVOLUTION_DIRECTION_FORWARD)
+ {
+ outputSizes.push_back(inputTensor.sizes[0]); // output[N] = input[N]
+ outputSizes.push_back(filterTensor.sizes[0]); // output[C] = filter[N]
+
+ for (uint32_t dim = 0; dim < spatialDimensionCount; ++dim)
+ {
+ uint32_t inputSize = inputTensor.sizes[dim + 2];
+ uint32_t paddedSize = inputSize + startPadding[dim] + endPadding[dim];
+
+ uint32_t windowSize = filterTensor.sizes[dim + 2];
+ uint32_t kernelSize = 1 + (windowSize - 1) * dilations[dim];
+
+ assert(kernelSize <= paddedSize);
+ assert(strides[dim] != 0);
+
+ outputSizes.push_back(1 + (paddedSize - kernelSize) / strides[dim]);
+ }
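+
+ // Worked example (illustrative): with inputSize = 224, startPadding = endPadding = 3,
+ // windowSize = 7, dilation = 1, and stride = 2, then paddedSize = 230, kernelSize = 7,
+ // and the spatial output size is 1 + (230 - 7) / 2 = 112 (integer division).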
+ }
+ else if (direction == DML_CONVOLUTION_DIRECTION_BACKWARD)
+ {
+ // TODO: implement me
+ assert(false);
+ }
+ else
+ {
+ assert(false);
+ DMLX_THROW(E_UNEXPECTED);
+ }
+ }
+
+ TensorDesc outputTensor(inputTensor.dataType, std::move(outputSizes), builder->GetTensorPolicy());
+ detail::FusedActivationStorage storage;
+
+ DML_CONVOLUTION_OPERATOR_DESC desc = {};
+ desc.InputTensor = inputTensor.AsPtr();
+ desc.FilterTensor = filterTensor.AsPtr();
+ desc.BiasTensor = bias ? biasTensor.AsPtr() : nullptr;
+ desc.OutputTensor = outputTensor.AsPtr();
+ desc.Mode = mode;
+ desc.Direction = direction;
+ desc.DimensionCount = spatialDimensionCount;
+ desc.Strides = strides.data();
+ desc.Dilations = dilations.data();
+ desc.StartPadding = startPadding.data();
+ desc.EndPadding = endPadding.data();
+ desc.OutputPadding = outputPadding.data();
+ desc.GroupCount = groupCount;
+ desc.FusedActivation = detail::GetFusedActivationPtr(fusedActivation, &storage);
+
+ detail::NodeOutput* const inputs[] = { input.Impl(), filter.Impl(), bias ? bias->Impl() : nullptr };
+ detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_CONVOLUTION, &desc, inputs);
+ detail::NodeOutput* output = builder->CreateNodeOutput(node, 0, std::move(outputTensor));
+
+ return output;
+ }
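+
+ // Illustrative call (hypothetical shapes; filter sizes are ordered [FilterN, FilterC, H, W]):
+ //
+ // auto input = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, { 1, 3, 224, 224 }));
+ // auto filter = dml::InputTensor(graph, 1, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, { 32, 3, 3, 3 }));
+ // auto conv = dml::Convolution(input, filter); // defaults: unit strides, no padding -> output sizes { 1, 32, 222, 222 }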
+
+ // Helper for setting parameters for the Convolution operator. Sample usage:
+ //
+ // auto conv = dml::ConvolutionBuilder(...)
+ // .StartPadding(...)
+ // .EndPadding(...)
+ // .Strides(...)
+ // .Build();
+ //
+ // Parameters left unspecified will be defaulted with the same values as dml::Convolution().
+ class ConvolutionBuilder
+ {
+ public:
+ ConvolutionBuilder(Expression input, Expression filter, Optional<Expression> bias = NullOpt)
+ : m_input(input), m_filter(filter), m_bias(bias)
+ {}
+
+ ConvolutionBuilder& Mode(DML_CONVOLUTION_MODE mode) { m_mode = mode; return *this; }
+ ConvolutionBuilder& Direction(DML_CONVOLUTION_DIRECTION direction) { m_direction = direction; return *this; }
+ ConvolutionBuilder& Strides(Span<const uint32_t> strides) { m_strides.assign(strides.begin(), strides.end()); return *this; }
+ ConvolutionBuilder& Dilations(Span<const uint32_t> dilations) { m_dilations.assign(dilations.begin(), dilations.end()); return *this; }
+ ConvolutionBuilder& StartPadding(Span<const uint32_t> startPadding) { m_startPadding.assign(startPadding.begin(), startPadding.end()); return *this; }
+ ConvolutionBuilder& EndPadding(Span<const uint32_t> endPadding) { m_endPadding.assign(endPadding.begin(), endPadding.end()); return *this; }
+ ConvolutionBuilder& OutputPadding(Span<const uint32_t> outputPadding) { m_outputPadding.assign(outputPadding.begin(), outputPadding.end()); return *this; }
+ ConvolutionBuilder& GroupCount(uint32_t groupCount) { m_groupCount = groupCount; return *this; }
+ ConvolutionBuilder& FusedActivation(FusedActivation fusedActivation) { m_fusedActivation = fusedActivation; return *this; }
+ ConvolutionBuilder& OutputSizes(TensorDimensions outputSizes) { m_outputSizes = std::move(outputSizes); return *this; }
+
+ Expression Build() const
+ {
+ return Convolution(
+ m_input,
+ m_filter,
+ m_bias,
+ m_mode,
+ m_direction,
+ m_strides,
+ m_dilations,
+ m_startPadding,
+ m_endPadding,
+ m_outputPadding,
+ m_groupCount,
+ m_fusedActivation,
+ m_outputSizes);
+ }
+
+ private:
+ Expression m_input;
+ Expression m_filter;
+ Optional<Expression> m_bias;
+ DML_CONVOLUTION_MODE m_mode = DML_CONVOLUTION_MODE_CROSS_CORRELATION;
+ DML_CONVOLUTION_DIRECTION m_direction = DML_CONVOLUTION_DIRECTION_FORWARD;
+ SmallVector<uint32_t, 4> m_strides = {};
+ SmallVector