From 44936c4a9d42f1c67e34babb5792adf5bce7f76b Mon Sep 17 00:00:00 2001 From: JN Mo <115294419+EpicJeanNoeMorissette@users.noreply.github.com> Date: Wed, 25 Sep 2024 09:59:33 -0400 Subject: [PATCH] Add support for SPV_KHR_compute_shader_derivative (#5817) * Add support for SPV_KHR_compute_shader_derivative * Update tests for SPV_KHR_compute_shader_derivatives --------- Co-authored-by: MagicPoncho --- source/opt/aggressive_dead_code_elim_pass.cpp | 2 +- .../opt/local_access_chain_convert_pass.cpp | 6 +- source/opt/local_single_block_elim_pass.cpp | 2 +- source/opt/local_single_store_elim_pass.cpp | 2 +- source/opt/trim_capabilities_pass.h | 4 +- source/val/validate_derivatives.cpp | 30 ++++++---- source/val/validate_image.cpp | 57 +++++++++++-------- test/val/val_derivatives_test.cpp | 9 +-- test/val/val_image_test.cpp | 47 ++++++++------- 9 files changed, 90 insertions(+), 69 deletions(-) diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index 6e86c378b9..953a7f5a3f 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -1010,7 +1010,7 @@ void AggressiveDCEPass::InitExtensions() { "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", - "SPV_NV_compute_shader_derivatives", + "SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch" diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp index f46c9136fe..ad0277474e 100644 --- a/source/opt/local_access_chain_convert_pass.cpp +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -428,9 +428,9 @@ void LocalAccessChainConvertPass::InitExtensions() { "SPV_KHR_uniform_group_instructions", "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model", "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add", - "SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives", - "SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix", - "SPV_KHR_ray_tracing_position_fetch"}); + "SPV_EXT_fragment_shader_interlock", + "SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix", + "SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch"}); } bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds( diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp index e0e4f06b9d..7df17d5d62 100644 --- a/source/opt/local_single_block_elim_pass.cpp +++ b/source/opt/local_single_block_elim_pass.cpp @@ -291,7 +291,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() { "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", - "SPV_NV_compute_shader_derivatives", + "SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch"}); diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp index 8bdd0f4ea7..bb7bba87eb 100644 --- a/source/opt/local_single_store_elim_pass.cpp +++ b/source/opt/local_single_store_elim_pass.cpp @@ -141,7 +141,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() { "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", - "SPV_NV_compute_shader_derivatives", + "SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch"}); diff --git a/source/opt/trim_capabilities_pass.h b/source/opt/trim_capabilities_pass.h index 3ff6dba2d2..a9d018faa0 100644 --- a/source/opt/trim_capabilities_pass.h +++ b/source/opt/trim_capabilities_pass.h @@ -74,8 +74,8 @@ class TrimCapabilitiesPass : public Pass { // contains unsupported instruction, the pass could yield bad results. static constexpr std::array kSupportedCapabilities{ // clang-format off - spv::Capability::ComputeDerivativeGroupLinearNV, - spv::Capability::ComputeDerivativeGroupQuadsNV, + spv::Capability::ComputeDerivativeGroupLinearKHR, + spv::Capability::ComputeDerivativeGroupQuadsKHR, spv::Capability::Float16, spv::Capability::Float64, spv::Capability::FragmentShaderPixelInterlockEXT, diff --git a/source/val/validate_derivatives.cpp b/source/val/validate_derivatives.cpp index 90cf6645c4..1a473ba880 100644 --- a/source/val/validate_derivatives.cpp +++ b/source/val/validate_derivatives.cpp @@ -60,12 +60,14 @@ spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) { ->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model, std::string* message) { if (model != spv::ExecutionModel::Fragment && - model != spv::ExecutionModel::GLCompute) { + model != spv::ExecutionModel::GLCompute && + model != spv::ExecutionModel::MeshEXT && + model != spv::ExecutionModel::TaskEXT) { if (message) { *message = std::string( - "Derivative instructions require Fragment or GLCompute " - "execution model: ") + + "Derivative instructions require Fragment, GLCompute, " + "MeshEXT or TaskEXT execution model: ") + spvOpcodeString(opcode); } return false; @@ -79,19 +81,23 @@ spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) { const auto* models = state.GetExecutionModels(entry_point->id()); const auto* modes = state.GetExecutionModes(entry_point->id()); if (models && - models->find(spv::ExecutionModel::GLCompute) != models->end() && + (models->find(spv::ExecutionModel::GLCompute) != + models->end() || + models->find(spv::ExecutionModel::MeshEXT) != models->end() || + models->find(spv::ExecutionModel::TaskEXT) != models->end()) && (!modes || - (modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) == + (modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) == modes->end() && - modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) == + modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) == modes->end()))) { if (message) { - *message = std::string( - "Derivative instructions require " - "DerivativeGroupQuadsNV " - "or DerivativeGroupLinearNV execution mode for " - "GLCompute execution model: ") + - spvOpcodeString(opcode); + *message = + std::string( + "Derivative instructions require " + "DerivativeGroupQuadsKHR " + "or DerivativeGroupLinearKHR execution mode for " + "GLCompute, MeshEXT or TaskEXT execution model: ") + + spvOpcodeString(opcode); } return false; } diff --git a/source/val/validate_image.cpp b/source/val/validate_image.cpp index e77fc12994..83c9db9209 100644 --- a/source/val/validate_image.cpp +++ b/source/val/validate_image.cpp @@ -2026,11 +2026,13 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _, ->RegisterExecutionModelLimitation( [&](spv::ExecutionModel model, std::string* message) { if (model != spv::ExecutionModel::Fragment && - model != spv::ExecutionModel::GLCompute) { + model != spv::ExecutionModel::GLCompute && + model != spv::ExecutionModel::MeshEXT && + model != spv::ExecutionModel::TaskEXT) { if (message) { *message = std::string( - "OpImageQueryLod requires Fragment or GLCompute execution " - "model"); + "OpImageQueryLod requires Fragment, GLCompute, MeshEXT or " + "TaskEXT execution model"); } return false; } @@ -2042,16 +2044,20 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _, std::string* message) { const auto* models = state.GetExecutionModels(entry_point->id()); const auto* modes = state.GetExecutionModes(entry_point->id()); - if (models->find(spv::ExecutionModel::GLCompute) != models->end() && - modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) == - modes->end() && - modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) == - modes->end()) { + if (models && + (models->find(spv::ExecutionModel::GLCompute) != models->end() || + models->find(spv::ExecutionModel::MeshEXT) != models->end() || + models->find(spv::ExecutionModel::TaskEXT) != models->end()) && + (!modes || + (modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) == + modes->end() && + modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) == + modes->end()))) { if (message) { *message = std::string( - "OpImageQueryLod requires DerivativeGroupQuadsNV " - "or DerivativeGroupLinearNV execution mode for GLCompute " - "execution model"); + "OpImageQueryLod requires DerivativeGroupQuadsKHR " + "or DerivativeGroupLinearKHR execution mode for GLCompute, " + "MeshEXT or TaskEXT execution model"); } return false; } @@ -2320,12 +2326,14 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) { ->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model, std::string* message) { if (model != spv::ExecutionModel::Fragment && - model != spv::ExecutionModel::GLCompute) { + model != spv::ExecutionModel::GLCompute && + model != spv::ExecutionModel::MeshEXT && + model != spv::ExecutionModel::TaskEXT) { if (message) { *message = std::string( - "ImplicitLod instructions require Fragment or GLCompute " - "execution model: ") + + "ImplicitLod instructions require Fragment, GLCompute, " + "MeshEXT or TaskEXT execution model: ") + spvOpcodeString(opcode); } return false; @@ -2339,19 +2347,22 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) { const auto* models = state.GetExecutionModels(entry_point->id()); const auto* modes = state.GetExecutionModes(entry_point->id()); if (models && - models->find(spv::ExecutionModel::GLCompute) != models->end() && + (models->find(spv::ExecutionModel::GLCompute) != models->end() || + models->find(spv::ExecutionModel::MeshEXT) != models->end() || + models->find(spv::ExecutionModel::TaskEXT) != models->end()) && (!modes || - (modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) == + (modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) == modes->end() && - modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) == + modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) == modes->end()))) { if (message) { - *message = - std::string( - "ImplicitLod instructions require DerivativeGroupQuadsNV " - "or DerivativeGroupLinearNV execution mode for GLCompute " - "execution model: ") + - spvOpcodeString(opcode); + *message = std::string( + "ImplicitLod instructions require " + "DerivativeGroupQuadsKHR " + "or DerivativeGroupLinearKHR execution mode for " + "GLCompute, " + "MeshEXT or TaskEXT execution model: ") + + spvOpcodeString(opcode); } return false; } diff --git a/test/val/val_derivatives_test.cpp b/test/val/val_derivatives_test.cpp index e605f3a032..6ddafe493e 100644 --- a/test/val/val_derivatives_test.cpp +++ b/test/val/val_derivatives_test.cpp @@ -156,8 +156,8 @@ TEST_F(ValidateDerivatives, OpDPdxWrongExecutionModel) { CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str()); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Derivative instructions require Fragment or GLCompute " - "execution model: DPdx")); + HasSubstr("Derivative instructions require Fragment, GLCompute, " + "MeshEXT or TaskEXT execution model: DPdx")); } TEST_F(ValidateDerivatives, NoExecutionModeGLCompute) { @@ -181,8 +181,9 @@ OpFunctionEnd EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("Derivative instructions require " - "DerivativeGroupQuadsNV or DerivativeGroupLinearNV " - "execution mode for GLCompute execution model")); + "DerivativeGroupQuadsKHR or DerivativeGroupLinearKHR " + "execution mode for GLCompute, MeshEXT or TaskEXT " + "execution model")); } using ValidateHalfDerivatives = spvtest::ValidateBase; diff --git a/test/val/val_image_test.cpp b/test/val/val_image_test.cpp index 07f0200e21..adba2c652b 100644 --- a/test/val/val_image_test.cpp +++ b/test/val/val_image_test.cpp @@ -4780,7 +4780,8 @@ TEST_F(ValidateImage, QueryLodWrongExecutionModel) { EXPECT_THAT( getDiagnosticString(), HasSubstr( - "OpImageQueryLod requires Fragment or GLCompute execution model")); + "OpImageQueryLod requires Fragment, GLCompute, MeshEXT or TaskEXT " + "execution model")); } TEST_F(ValidateImage, QueryLodWrongExecutionModelWithFunc) { @@ -4801,7 +4802,8 @@ OpFunctionEnd EXPECT_THAT( getDiagnosticString(), HasSubstr( - "OpImageQueryLod requires Fragment or GLCompute execution model")); + "OpImageQueryLod requires Fragment, GLCompute, MeshEXT or TaskEXT " + "execution model")); } TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) { @@ -4813,12 +4815,12 @@ TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) { )"; const std::string extra = R"( -OpCapability ComputeDerivativeGroupLinearNV -OpExtension "SPV_NV_compute_shader_derivatives" +OpCapability ComputeDerivativeGroupLinearKHR +OpExtension "SPV_KHR_compute_shader_derivatives" )"; const std::string mode = R"( OpExecutionMode %main LocalSize 8 8 1 -OpExecutionMode %main DerivativeGroupLinearNV +OpExecutionMode %main DerivativeGroupLinearKHR )"; CompileSuccessfully( GenerateShaderCode(body, extra, "GLCompute", mode).c_str()); @@ -4930,8 +4932,8 @@ TEST_F(ValidateImage, QueryLodComputeShaderDerivativesMissingMode) { )"; const std::string extra = R"( -OpCapability ComputeDerivativeGroupLinearNV -OpExtension "SPV_NV_compute_shader_derivatives" +OpCapability ComputeDerivativeGroupLinearKHR +OpExtension "SPV_KHR_compute_shader_derivatives" )"; const std::string mode = R"( OpExecutionMode %main LocalSize 8 8 1 @@ -4940,9 +4942,9 @@ OpExecutionMode %main LocalSize 8 8 1 GenerateShaderCode(body, extra, "GLCompute", mode).c_str()); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("OpImageQueryLod requires DerivativeGroupQuadsNV or " - "DerivativeGroupLinearNV execution mode for GLCompute " - "execution model")); + HasSubstr("OpImageQueryLod requires DerivativeGroupQuadsKHR or " + "DerivativeGroupLinearKHR execution mode for " + "GLCompute, MeshEXT or TaskEXT execution model")); } TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) { @@ -4956,8 +4958,8 @@ TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) { CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str()); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("ImplicitLod instructions require Fragment or " - "GLCompute execution model")); + HasSubstr("ImplicitLod instructions require Fragment, " + "GLCompute, MeshEXT or TaskEXT execution model")); } TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) { @@ -4969,12 +4971,12 @@ TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) { )"; const std::string extra = R"( -OpCapability ComputeDerivativeGroupLinearNV -OpExtension "SPV_NV_compute_shader_derivatives" +OpCapability ComputeDerivativeGroupLinearKHR +OpExtension "SPV_KHR_compute_shader_derivatives" )"; const std::string mode = R"( OpExecutionMode %main LocalSize 8 8 1 -OpExecutionMode %main DerivativeGroupLinearNV +OpExecutionMode %main DerivativeGroupLinearKHR )"; CompileSuccessfully( GenerateShaderCode(body, extra, "GLCompute", mode).c_str()); @@ -4990,8 +4992,8 @@ TEST_F(ValidateImage, ImplicitLodComputeShaderDerivativesMissingMode) { )"; const std::string extra = R"( -OpCapability ComputeDerivativeGroupLinearNV -OpExtension "SPV_NV_compute_shader_derivatives" +OpCapability ComputeDerivativeGroupLinearKHR +OpExtension "SPV_KHR_compute_shader_derivatives" )"; const std::string mode = R"( OpExecutionMode %main LocalSize 8 8 1 @@ -5001,9 +5003,9 @@ OpExecutionMode %main LocalSize 8 8 1 ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT( getDiagnosticString(), - HasSubstr("ImplicitLod instructions require DerivativeGroupQuadsNV or " - "DerivativeGroupLinearNV execution mode for GLCompute " - "execution model")); + HasSubstr("ImplicitLod instructions require DerivativeGroupQuadsKHR or " + "DerivativeGroupLinearKHR execution mode for GLCompute, " + "MeshEXT or TaskEXT execution model")); } TEST_F(ValidateImage, ReadSubpassDataWrongExecutionModel) { @@ -6505,8 +6507,9 @@ OpFunctionEnd EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), HasSubstr("ImplicitLod instructions require " - "DerivativeGroupQuadsNV or DerivativeGroupLinearNV " - "execution mode for GLCompute execution model")); + "DerivativeGroupQuadsKHR or DerivativeGroupLinearKHR " + "execution mode for GLCompute, MeshEXT or TaskEXT " + "execution model")); } TEST_F(ValidateImage, TypeSampledImageNotBufferPost1p6) {