Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update readme #471

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DxDispatch/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.18)
project(dxdispatch VERSION 0.15.1 LANGUAGES CXX)
project(dxdispatch VERSION 0.15.3 LANGUAGES CXX)

# ==============================================================================
# External Libraries/Helpers
Expand Down
2 changes: 1 addition & 1 deletion DxDispatch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ The default redistributable versions of components (e.g. nuget, archives):
- **Direct3D 12 (nuget)**: [Microsoft.Direct3D.D3D12 (1.610.2)](https://www.nuget.org/packages/Microsoft.Direct3D.D3D12/1.610.2) - 2023/04/20
- **DX Compiler (archive)**: [December 2022 (v1.7.2212.1)](https://github.com/microsoft/DirectXShaderCompiler/releases/tag/v1.7.2212.1) - 2023/03/02
- **PIX Event Runtime (nuget)**: [WinPixEventRuntime (1.0.230302001)](https://www.nuget.org/packages/WinPixEventRuntime/1.0.230302001) - 2023/03/02
- **ONNX Runtime (nuget)**: [Microsoft.ML.OnnxRuntime.DirectML (1.14.1)](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime.DirectML/1.14.1) - 2023/02/27
- **ONNX Runtime (nuget)**: [Microsoft.ML.OnnxRuntime.DirectML (1.15.0)](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime.DirectML/1.15.0) - 2023/05/24

Configuration is done using CMake cache variables. For example, Direct3D can be switched to a system dependency by adding `-DDXD_DIRECT3D_TYPE=winsdk` to the command line when first configuring the project. Use `cmake-gui` or `ccmake` to view the available variables.

Expand Down
12 changes: 6 additions & 6 deletions DxDispatch/cgmanifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"type": "nuget",
"nuget": {
"name": "Microsoft.Direct3D.D3D12",
"version": "1.608.2"
"version": "1.610.2"
}
}
},
Expand All @@ -24,7 +24,7 @@
"type": "nuget",
"nuget": {
"name": "Microsoft.AI.DirectML",
"version": "1.10.1"
"version": "1.12.0"
}
}
},
Expand All @@ -33,8 +33,8 @@
"type": "other",
"other": {
"name": "DirectX Shader Compiler",
"version": "2022_12_16",
"downloadUrl": "https://github.com/microsoft/DirectXShaderCompiler/releases/download/v1.7.2212/dxc_2022_12_16.zip"
"version": "2023_03_01",
"downloadUrl": "https://github.com/microsoft/DirectXShaderCompiler/releases/download/v1.7.2212.1/dxc_2023_03_01.zip"
}
}
},
Expand Down Expand Up @@ -89,7 +89,7 @@
"type": "nuget",
"nuget": {
"name": "WinPixEventRuntime",
"version": "1.0.220124001"
"version": "1.0.230302001"
}
}
},
Expand All @@ -116,7 +116,7 @@
"type": "nuget",
"nuget": {
"name": "Microsoft.ML.OnnxRuntime.DirectML",
"version": "1.14.1"
"version": "1.15.0"
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions DxDispatch/cmake/onnxruntime.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ function(init_onnxruntime_cache_variables prefix)

# <PREFIX>_ONNXRUNTIME_NUGET_VERSION
set(${prefix}_ONNXRUNTIME_NUGET_VERSION
1.14.1
1.15.0
CACHE STRING "Version of the ONNX Runtime NuGet package (TYPE == nuget)."
)

# <PREFIX>_ONNXRUNTIME_NUGET_HASH
set(${prefix}_ONNXRUNTIME_NUGET_HASH
c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a
C168D1C9C73E14041DF904E4B38F01A7F955AEF94AAFDEB4ED996F0656054062
CACHE STRING "SHA256 hash of the ONNX Runtime NuGet package (TYPE == nuget)."
)

Expand Down
15 changes: 13 additions & 2 deletions DxDispatch/src/dxdispatch/Executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,13 +357,15 @@ std::ostream& operator<<(std::ostream& os, const BufferDataView<T>& view)
{
uint32_t elementCount = view.desc.initialValues.size() / Device::GetSizeInBytes(view.desc.initialValuesDataType);
auto values = reinterpret_cast<const T*>(view.byteValues.data());
printf("elementCount=%d\n", elementCount);
for (uint32_t elementIndex = 0; elementIndex < elementCount; elementIndex++)
{
os << values[elementIndex];
if (elementIndex < elementCount - 1)
{
os << ", ";
}

}
return os;
}
Expand Down Expand Up @@ -399,7 +401,16 @@ void Executor::operator()(const Model::PrintCommand& command)
auto outputValues = m_device->Download(resource.Get());
auto& resourceDesc = m_model.GetResource(command.resourceName);
auto& bufferDesc = std::get<Model::BufferDesc>(resourceDesc.value);
LogInfo(fmt::format("Resource '{}': {}", command.resourceName, ToString(outputValues, bufferDesc)));
// print only output tensor
if (command.resourceName == "output")
{
LogInfo(fmt::format("Resource '{}': {}", command.resourceName, ToString(outputValues, bufferDesc)));
}
if (command.resourceName == "stackedKeyValue")
{
LogInfo(fmt::format("Resource '{}': {}", command.resourceName, ToString(outputValues, bufferDesc)));
}

}
catch (const std::exception& e)
{
Expand Down Expand Up @@ -441,7 +452,7 @@ void Executor::operator()(const Model::WriteFileCommand& command)
}

file.write(reinterpret_cast<const char*>(fileData.data()), fileData.size());
LogInfo(fmt::format("Resource '{}' written to '{}'", command.resourceName, command.targetPath));
//LogInfo(fmt::format("Resource '{}' written to '{}'", command.resourceName, command.targetPath));
}
catch (const std::exception& e)
{
Expand Down
6 changes: 3 additions & 3 deletions DxDispatch/src/dxdispatch/OnnxDispatchable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,9 @@ void OnnxDispatchable::Bind(const Bindings& jsonBindings, uint32_t iteration)
{
for (auto& binding : m_mergedBindings)
{
LogInfo(fmt::format("{} Tensor '{}':", (binding.isInput ? "Input" : "Output"), binding.name));
LogInfo(fmt::format(" Resource = {}", binding.resourceType));
LogInfo(fmt::format(" Data Type = {}", GetOnnxTensorTypeString(binding.dataType)));
//LogInfo(fmt::format("{} Tensor '{}':", (binding.isInput ? "Input" : "Output"), binding.name));
//LogInfo(fmt::format(" Resource = {}", binding.resourceType));
//LogInfo(fmt::format(" Data Type = {}", GetOnnxTensorTypeString(binding.dataType)));
std::string shapeString = "[";
for (size_t i = 0; i < binding.shape.size(); i++)
{
Expand Down
52 changes: 52 additions & 0 deletions DxDispatch/src/model/JsonParsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "StdSupport.h"
#include "NpyReaderWriter.h"

#include <random>

#ifndef WIN32
#define _stricmp strcasecmp
#endif
Expand Down Expand Up @@ -1049,6 +1051,51 @@ std::vector<std::byte> GenerateInitialValuesFromSequence(DML_TENSOR_DATA_TYPE da
}
}

// Generates 'valueCount' pseudo-random initial values for a buffer, drawn from a
// uniform real distribution over [min, max) and cast to the requested tensor
// element type. The generator is seeded from the JSON 'seed' field, so the same
// model file always produces the same data (reproducible runs).
//
// Expected JSON fields: "valueCount" (uint32), "seed" (uint32), "min" (float),
// "max" (float). Throws std::invalid_argument for an unsupported data type or
// an invalid [min, max] range.
std::vector<std::byte> GenerateInitialValuesFromRandom(DML_TENSOR_DATA_TYPE dataType, const rapidjson::Value& object)
{
    auto valueCount = ParseUInt32Field(object, "valueCount");
    auto seed = ParseUInt32Field(object, "seed");
    auto valueMin = ParseFloat32Field(object, "min");
    auto valueMax = ParseFloat32Field(object, "max");

    // std::uniform_real_distribution requires min <= max; anything else is UB.
    if (valueMin > valueMax)
    {
        throw std::invalid_argument("initialValues: 'min' must not be greater than 'max'.");
    }

    // Deliberately seeded (instead of std::random_device) so generated buffers
    // are deterministic for a given model file.
    std::mt19937 randomGenerator(seed);
    std::uniform_real_distribution<float> uniformDistribution(valueMin, valueMax);

    // Draws floats one at a time, converts each to the target element type, and
    // appends the element's raw bytes to the output buffer.
    auto generateAs = [&](auto defaultValue) -> std::vector<std::byte>
    {
        using ValueType = decltype(defaultValue);
        std::vector<std::byte> allBytes;
        allBytes.reserve(sizeof(ValueType) * valueCount);
        for (size_t i = 0; i < valueCount; i++)
        {
            const auto value = static_cast<ValueType>(uniformDistribution(randomGenerator));
            for (auto byte : gsl::as_bytes(gsl::make_span(&value, 1)))
            {
                allBytes.push_back(byte);
            }
        }
        return allBytes;
    };

    switch (dataType)
    {
    case DML_TENSOR_DATA_TYPE_FLOAT16: return generateAs(half_float::half(0));
    case DML_TENSOR_DATA_TYPE_FLOAT32: return generateAs(0.0f);
    case DML_TENSOR_DATA_TYPE_FLOAT64: return generateAs(0.0);
    case DML_TENSOR_DATA_TYPE_UINT8: return generateAs(static_cast<uint8_t>(0));
    case DML_TENSOR_DATA_TYPE_UINT16: return generateAs(static_cast<uint16_t>(0));
    case DML_TENSOR_DATA_TYPE_UINT32: return generateAs(static_cast<uint32_t>(0));
    case DML_TENSOR_DATA_TYPE_UINT64: return generateAs(static_cast<uint64_t>(0));
    case DML_TENSOR_DATA_TYPE_INT8: return generateAs(static_cast<int8_t>(0));
    case DML_TENSOR_DATA_TYPE_INT16: return generateAs(static_cast<int16_t>(0));
    case DML_TENSOR_DATA_TYPE_INT32: return generateAs(static_cast<int32_t>(0));
    case DML_TENSOR_DATA_TYPE_INT64: return generateAs(static_cast<int64_t>(0));
    default: throw std::invalid_argument("Invalid tensor data type.");
    }
}

std::filesystem::path ResolveInputFilePath(const std::filesystem::path& parentPath, std::string_view sourcePath)
{
auto filePathRelativeToParent = std::filesystem::absolute(parentPath / sourcePath);
Expand Down Expand Up @@ -1162,6 +1209,11 @@ Model::BufferDesc ParseModelBufferDesc(const std::filesystem::path& parentPath,
ensureInitialValuesDataType();
buffer.initialValues = GenerateInitialValuesFromSequence(buffer.initialValuesDataType, initialValuesField->value);
}
else if (initialValuesField->value.HasMember("seed"))
{
ensureInitialValuesDataType();
buffer.initialValues = GenerateInitialValuesFromRandom(buffer.initialValuesDataType, initialValuesField->value);
}
// e.g. "initialValues": { "sourcePath": "inputFile.npy" }
else if (initialValuesField->value.HasMember("sourcePath"))
{
Expand Down
77 changes: 77 additions & 0 deletions DxDispatch/src/model/JsonParsersGenerated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ DML_OPERATOR_TYPE ParseDmlOperatorType(const rapidjson::Value& value)
if (!strcmp(valueString, "DML_OPERATOR_RESAMPLE2") || !strcmp(valueString, "RESAMPLE2")) { return DML_OPERATOR_RESAMPLE2; }
if (!strcmp(valueString, "DML_OPERATOR_RESAMPLE_GRAD1") || !strcmp(valueString, "RESAMPLE_GRAD1")) { return DML_OPERATOR_RESAMPLE_GRAD1; }
if (!strcmp(valueString, "DML_OPERATOR_DIAGONAL_MATRIX1") || !strcmp(valueString, "DIAGONAL_MATRIX1")) { return DML_OPERATOR_DIAGONAL_MATRIX1; }
if (!strcmp(valueString, "DML_OPERATOR_MULTIHEAD_ATTENTION") || !strcmp(valueString, "MULTIHEAD_ATTENTION")) { return DML_OPERATOR_MULTIHEAD_ATTENTION; }
throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_OPERATOR_TYPE.", valueString));
}

Expand Down Expand Up @@ -429,6 +430,10 @@ DML_FEATURE_LEVEL ParseDmlFeatureLevel(const rapidjson::Value& value)
if (!strcmp(valueString, "DML_FEATURE_LEVEL_4_0") || !strcmp(valueString, "4_0")) { return DML_FEATURE_LEVEL_4_0; }
if (!strcmp(valueString, "DML_FEATURE_LEVEL_4_1") || !strcmp(valueString, "4_1")) { return DML_FEATURE_LEVEL_4_1; }
if (!strcmp(valueString, "DML_FEATURE_LEVEL_5_0") || !strcmp(valueString, "5_0")) { return DML_FEATURE_LEVEL_5_0; }
if (!strcmp(valueString, "DML_FEATURE_LEVEL_5_1") || !strcmp(valueString, "5_1")) { return DML_FEATURE_LEVEL_5_1; }
if (!strcmp(valueString, "DML_FEATURE_LEVEL_5_2") || !strcmp(valueString, "5_2")) { return DML_FEATURE_LEVEL_5_2; }
if (!strcmp(valueString, "DML_FEATURE_LEVEL_6_0") || !strcmp(valueString, "6_0")) { return DML_FEATURE_LEVEL_6_0; }
if (!strcmp(valueString, "DML_FEATURE_LEVEL_6_1") || !strcmp(valueString, "6_1")) { return DML_FEATURE_LEVEL_6_1; }
throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_FEATURE_LEVEL.", valueString));
}

Expand Down Expand Up @@ -535,6 +540,28 @@ DML_RANDOM_GENERATOR_TYPE ParseDmlRandomGeneratorTypeField(const rapidjson::Valu
});
}

DML_MULTIHEAD_ATTENTION_MASK_TYPE ParseDmlMultiheadAttentionMaskType(const rapidjson::Value& value)
{
    // Converts a JSON string into a DML_MULTIHEAD_ATTENTION_MASK_TYPE value.
    // Both the full enum identifier and its shortened suffix are accepted.
    if (value.GetType() != rapidjson::Type::kStringType)
    {
        throw std::invalid_argument("DML_MULTIHEAD_ATTENTION_MASK_TYPE must be a string.");
    }
    auto valueString = value.GetString();

    struct Mapping { const char* name; DML_MULTIHEAD_ATTENTION_MASK_TYPE maskType; };
    static const Mapping mappings[] =
    {
        { "DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE", DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE },
        { "NONE", DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE },
        { "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH },
        { "KEY_SEQUENCE_LENGTH", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH },
        { "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START },
        { "KEY_SEQUENCE_END_START", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START },
        { "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END },
        { "KEY_QUERY_SEQUENCE_LENGTH_START_END", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END },
        { "DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN", DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN },
        { "BOOLEAN", DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN },
    };

    for (const auto& mapping : mappings)
    {
        if (!strcmp(valueString, mapping.name))
        {
            return mapping.maskType;
        }
    }
    throw std::invalid_argument(fmt::format("'{}' is not a recognized value for DML_MULTIHEAD_ATTENTION_MASK_TYPE.", valueString));
}

DML_MULTIHEAD_ATTENTION_MASK_TYPE ParseDmlMultiheadAttentionMaskTypeField(const rapidjson::Value& object, std::string_view fieldName, bool required, DML_MULTIHEAD_ATTENTION_MASK_TYPE defaultValue)
{
    // Looks up 'fieldName' on 'object' (honoring required/defaultValue semantics
    // via ParseFieldHelper) and converts the field's value to the enum type.
    auto parseValue = [](auto& fieldValue)
    {
        return ParseDmlMultiheadAttentionMaskType(fieldValue);
    };
    return ParseFieldHelper<DML_MULTIHEAD_ATTENTION_MASK_TYPE>(object, fieldName, required, defaultValue, parseValue);
}

// ====================================================================================================
// DIRECTML FLAGS
// ====================================================================================================
Expand Down Expand Up @@ -3981,6 +4008,54 @@ Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_DIAGONAL_MATRIX1_
return bindPoints;
}

DML_OPERATOR_DESC* ParseDmlMultiheadAttentionOperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
{
    // Parses a DML_MULTIHEAD_ATTENTION_OPERATOR_DESC from JSON. When the
    // operator is fused, every tensor desc field is left null.
    if (!value.IsObject()) { throw std::invalid_argument("Expected a valid JSON object."); }

    // Shared helper for the many tensor fields: fused operators skip parsing
    // entirely; otherwise the field is parsed with the given required-ness.
    auto parseTensor = [&](const char* fieldName, bool required)
    {
        return fused ? nullptr : ParseDmlTensorDescField(value, fieldName, allocator, required);
    };

    auto desc = allocator.Allocate<DML_MULTIHEAD_ATTENTION_OPERATOR_DESC>();
    desc->QueryTensor = parseTensor("QueryTensor", false);
    desc->KeyTensor = parseTensor("KeyTensor", false);
    desc->ValueTensor = parseTensor("ValueTensor", false);
    desc->StackedQueryKeyTensor = parseTensor("StackedQueryKeyTensor", false);
    desc->StackedKeyValueTensor = parseTensor("StackedKeyValueTensor", false);
    desc->StackedQueryKeyValueTensor = parseTensor("StackedQueryKeyValueTensor", false);
    desc->BiasTensor = parseTensor("BiasTensor", false);
    desc->MaskTensor = parseTensor("MaskTensor", false);
    desc->RelativePositionBiasTensor = parseTensor("RelativePositionBiasTensor", false);
    desc->PastKeyTensor = parseTensor("PastKeyTensor", false);
    desc->PastValueTensor = parseTensor("PastValueTensor", false);
    desc->OutputTensor = parseTensor("OutputTensor", true);  // the only required tensor
    desc->OutputPresentKeyTensor = parseTensor("OutputPresentKeyTensor", false);
    desc->OutputPresentValueTensor = parseTensor("OutputPresentValueTensor", false);

    // Scalar attributes are required regardless of fusion.
    desc->Scale = ParseFloat32Field(value, "Scale", true);
    desc->MaskFilterValue = ParseFloat32Field(value, "MaskFilterValue", true);
    desc->HeadCount = ParseUInt32Field(value, "HeadCount", true);
    desc->MaskType = ParseDmlMultiheadAttentionMaskTypeField(value, "MaskType", true, {});

    auto opDesc = allocator.Allocate<DML_OPERATOR_DESC>();
    opDesc->Type = DML_OPERATOR_MULTIHEAD_ATTENTION;
    opDesc->Desc = desc;
    return opDesc;
}

Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_MULTIHEAD_ATTENTION_OPERATOR_DESC& desc)
{
    // Bind point names mirror the tensor fields of
    // DML_MULTIHEAD_ATTENTION_OPERATOR_DESC, in declaration order. Every input
    // is optional; OutputTensor is the only required binding.
    static const char* inputNames[] =
    {
        "QueryTensor",
        "KeyTensor",
        "ValueTensor",
        "StackedQueryKeyTensor",
        "StackedKeyValueTensor",
        "StackedQueryKeyValueTensor",
        "BiasTensor",
        "MaskTensor",
        "RelativePositionBiasTensor",
        "PastKeyTensor",
        "PastValueTensor",
    };

    Model::DmlDispatchableDesc::BindPoints bindPoints = {};
    for (const char* inputName : inputNames)
    {
        bindPoints.inputs.push_back({inputName, 1, false});
    }
    bindPoints.outputs.push_back({"OutputTensor", 1, true});
    bindPoints.outputs.push_back({"OutputPresentKeyTensor", 1, false});
    bindPoints.outputs.push_back({"OutputPresentValueTensor", 1, false});
    return bindPoints;
}

DML_OPERATOR_DESC* ParseDmlActivationEluOperatorDesc(const rapidjson::Value& value, bool fused, BucketAllocator& allocator)
{
if (!value.IsObject()) { throw std::invalid_argument("Expected a valid JSON object."); }
Expand Down Expand Up @@ -4651,6 +4726,7 @@ DML_OPERATOR_DESC* ParseDmlOperatorDesc(const rapidjson::Value& value, bool fuse
if (!strcmp(type, "DML_OPERATOR_RESAMPLE2") || !strcmp(type, "RESAMPLE2")) return ParseDmlResample2OperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_RESAMPLE_GRAD1") || !strcmp(type, "RESAMPLE_GRAD1")) return ParseDmlResampleGrad1OperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_DIAGONAL_MATRIX1") || !strcmp(type, "DIAGONAL_MATRIX1")) return ParseDmlDiagonalMatrix1OperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_MULTIHEAD_ATTENTION") || !strcmp(type, "MULTIHEAD_ATTENTION")) return ParseDmlMultiheadAttentionOperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_ACTIVATION_ELU") || !strcmp(type, "ACTIVATION_ELU")) return ParseDmlActivationEluOperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_ACTIVATION_CELU") || !strcmp(type, "ACTIVATION_CELU")) return ParseDmlActivationCeluOperatorDesc(descValue, fused, allocator);
if (!strcmp(type, "DML_OPERATOR_ACTIVATION_HARDMAX") || !strcmp(type, "ACTIVATION_HARDMAX")) return ParseDmlActivationHardmaxOperatorDesc(descValue, fused, allocator);
Expand Down Expand Up @@ -4821,6 +4897,7 @@ Model::DmlDispatchableDesc::BindPoints GetBindPoints(const DML_OPERATOR_DESC& de
case DML_OPERATOR_RESAMPLE2: return GetBindPoints(*reinterpret_cast<const DML_RESAMPLE2_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_RESAMPLE_GRAD1: return GetBindPoints(*reinterpret_cast<const DML_RESAMPLE_GRAD1_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_DIAGONAL_MATRIX1: return GetBindPoints(*reinterpret_cast<const DML_DIAGONAL_MATRIX1_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_MULTIHEAD_ATTENTION: return GetBindPoints(*reinterpret_cast<const DML_MULTIHEAD_ATTENTION_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_ACTIVATION_ELU: return GetBindPoints(*reinterpret_cast<const DML_ACTIVATION_ELU_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_ACTIVATION_CELU: return GetBindPoints(*reinterpret_cast<const DML_ACTIVATION_CELU_OPERATOR_DESC*>(desc.Desc));
case DML_OPERATOR_ACTIVATION_HARDMAX: return GetBindPoints(*reinterpret_cast<const DML_ACTIVATION_HARDMAX_OPERATOR_DESC*>(desc.Desc));
Expand Down
Loading