Skip to content

Commit

Permalink
Add support for SPV_KHR_compute_shader_derivative (#5817)
Browse files Browse the repository at this point in the history
* Add support for SPV_KHR_compute_shader_derivative

* Update tests for SPV_KHR_compute_shader_derivatives

---------

Co-authored-by: MagicPoncho <[email protected]>
  • Loading branch information
EpicJeanNoeMorissette and MagicPoncho authored Sep 25, 2024
1 parent 362ce7c commit 44936c4
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 69 deletions.
2 changes: 1 addition & 1 deletion source/opt/aggressive_dead_code_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ void AggressiveDCEPass::InitExtensions() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives",
"SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"
Expand Down
6 changes: 3 additions & 3 deletions source/opt/local_access_chain_convert_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ void LocalAccessChainConvertPass::InitExtensions() {
"SPV_KHR_uniform_group_instructions",
"SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives",
"SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"});
"SPV_EXT_fragment_shader_interlock",
"SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch"});
}

bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
Expand Down
2 changes: 1 addition & 1 deletion source/opt/local_single_block_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives",
"SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"});
Expand Down
2 changes: 1 addition & 1 deletion source/opt/local_single_store_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives",
"SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"});
Expand Down
4 changes: 2 additions & 2 deletions source/opt/trim_capabilities_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ class TrimCapabilitiesPass : public Pass {
// contains unsupported instruction, the pass could yield bad results.
static constexpr std::array kSupportedCapabilities{
// clang-format off
spv::Capability::ComputeDerivativeGroupLinearNV,
spv::Capability::ComputeDerivativeGroupQuadsNV,
spv::Capability::ComputeDerivativeGroupLinearKHR,
spv::Capability::ComputeDerivativeGroupQuadsKHR,
spv::Capability::Float16,
spv::Capability::Float64,
spv::Capability::FragmentShaderPixelInterlockEXT,
Expand Down
30 changes: 18 additions & 12 deletions source/val/validate_derivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
std::string* message) {
if (model != spv::ExecutionModel::Fragment &&
model != spv::ExecutionModel::GLCompute) {
model != spv::ExecutionModel::GLCompute &&
model != spv::ExecutionModel::MeshEXT &&
model != spv::ExecutionModel::TaskEXT) {
if (message) {
*message =
std::string(
"Derivative instructions require Fragment or GLCompute "
"execution model: ") +
"Derivative instructions require Fragment, GLCompute, "
"MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode);
}
return false;
Expand All @@ -79,19 +81,23 @@ spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
const auto* models = state.GetExecutionModels(entry_point->id());
const auto* modes = state.GetExecutionModes(entry_point->id());
if (models &&
models->find(spv::ExecutionModel::GLCompute) != models->end() &&
(models->find(spv::ExecutionModel::GLCompute) !=
models->end() ||
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
(!modes ||
(modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
(modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
modes->end()))) {
if (message) {
*message = std::string(
"Derivative instructions require "
"DerivativeGroupQuadsNV "
"or DerivativeGroupLinearNV execution mode for "
"GLCompute execution model: ") +
spvOpcodeString(opcode);
*message =
std::string(
"Derivative instructions require "
"DerivativeGroupQuadsKHR "
"or DerivativeGroupLinearKHR execution mode for "
"GLCompute, MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode);
}
return false;
}
Expand Down
57 changes: 34 additions & 23 deletions source/val/validate_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2026,11 +2026,13 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _,
->RegisterExecutionModelLimitation(
[&](spv::ExecutionModel model, std::string* message) {
if (model != spv::ExecutionModel::Fragment &&
model != spv::ExecutionModel::GLCompute) {
model != spv::ExecutionModel::GLCompute &&
model != spv::ExecutionModel::MeshEXT &&
model != spv::ExecutionModel::TaskEXT) {
if (message) {
*message = std::string(
"OpImageQueryLod requires Fragment or GLCompute execution "
"model");
"OpImageQueryLod requires Fragment, GLCompute, MeshEXT or "
"TaskEXT execution model");
}
return false;
}
Expand All @@ -2042,16 +2044,20 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _,
std::string* message) {
const auto* models = state.GetExecutionModels(entry_point->id());
const auto* modes = state.GetExecutionModes(entry_point->id());
if (models->find(spv::ExecutionModel::GLCompute) != models->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
modes->end()) {
if (models &&
(models->find(spv::ExecutionModel::GLCompute) != models->end() ||
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
(!modes ||
(modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
modes->end()))) {
if (message) {
*message = std::string(
"OpImageQueryLod requires DerivativeGroupQuadsNV "
"or DerivativeGroupLinearNV execution mode for GLCompute "
"execution model");
"OpImageQueryLod requires DerivativeGroupQuadsKHR "
"or DerivativeGroupLinearKHR execution mode for GLCompute, "
"MeshEXT or TaskEXT execution model");
}
return false;
}
Expand Down Expand Up @@ -2320,12 +2326,14 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) {
->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
std::string* message) {
if (model != spv::ExecutionModel::Fragment &&
model != spv::ExecutionModel::GLCompute) {
model != spv::ExecutionModel::GLCompute &&
model != spv::ExecutionModel::MeshEXT &&
model != spv::ExecutionModel::TaskEXT) {
if (message) {
*message =
std::string(
"ImplicitLod instructions require Fragment or GLCompute "
"execution model: ") +
"ImplicitLod instructions require Fragment, GLCompute, "
"MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode);
}
return false;
Expand All @@ -2339,19 +2347,22 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) {
const auto* models = state.GetExecutionModels(entry_point->id());
const auto* modes = state.GetExecutionModes(entry_point->id());
if (models &&
models->find(spv::ExecutionModel::GLCompute) != models->end() &&
(models->find(spv::ExecutionModel::GLCompute) != models->end() ||
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
(!modes ||
(modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
(modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
modes->end()))) {
if (message) {
*message =
std::string(
"ImplicitLod instructions require DerivativeGroupQuadsNV "
"or DerivativeGroupLinearNV execution mode for GLCompute "
"execution model: ") +
spvOpcodeString(opcode);
*message = std::string(
"ImplicitLod instructions require "
"DerivativeGroupQuadsKHR "
"or DerivativeGroupLinearKHR execution mode for "
"GLCompute, "
"MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode);
}
return false;
}
Expand Down
9 changes: 5 additions & 4 deletions test/val/val_derivatives_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ TEST_F(ValidateDerivatives, OpDPdxWrongExecutionModel) {
CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Derivative instructions require Fragment or GLCompute "
"execution model: DPdx"));
HasSubstr("Derivative instructions require Fragment, GLCompute, "
"MeshEXT or TaskEXT execution model: DPdx"));
}

TEST_F(ValidateDerivatives, NoExecutionModeGLCompute) {
Expand All @@ -181,8 +181,9 @@ OpFunctionEnd
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Derivative instructions require "
"DerivativeGroupQuadsNV or DerivativeGroupLinearNV "
"execution mode for GLCompute execution model"));
"DerivativeGroupQuadsKHR or DerivativeGroupLinearKHR "
"execution mode for GLCompute, MeshEXT or TaskEXT "
"execution model"));
}

using ValidateHalfDerivatives = spvtest::ValidateBase<std::string>;
Expand Down
47 changes: 25 additions & 22 deletions test/val/val_image_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4780,7 +4780,8 @@ TEST_F(ValidateImage, QueryLodWrongExecutionModel) {
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"OpImageQueryLod requires Fragment or GLCompute execution model"));
"OpImageQueryLod requires Fragment, GLCompute, MeshEXT or TaskEXT "
"execution model"));
}

TEST_F(ValidateImage, QueryLodWrongExecutionModelWithFunc) {
Expand All @@ -4801,7 +4802,8 @@ OpFunctionEnd
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"OpImageQueryLod requires Fragment or GLCompute execution model"));
"OpImageQueryLod requires Fragment, GLCompute, MeshEXT or TaskEXT "
"execution model"));
}

TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) {
Expand All @@ -4813,12 +4815,12 @@ TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) {
)";

const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV
OpExtension "SPV_NV_compute_shader_derivatives"
OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_KHR_compute_shader_derivatives"
)";
const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1
OpExecutionMode %main DerivativeGroupLinearNV
OpExecutionMode %main DerivativeGroupLinearKHR
)";
CompileSuccessfully(
GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
Expand Down Expand Up @@ -4930,8 +4932,8 @@ TEST_F(ValidateImage, QueryLodComputeShaderDerivativesMissingMode) {
)";

const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV
OpExtension "SPV_NV_compute_shader_derivatives"
OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_KHR_compute_shader_derivatives"
)";
const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1
Expand All @@ -4940,9 +4942,9 @@ OpExecutionMode %main LocalSize 8 8 1
GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("OpImageQueryLod requires DerivativeGroupQuadsNV or "
"DerivativeGroupLinearNV execution mode for GLCompute "
"execution model"));
HasSubstr("OpImageQueryLod requires DerivativeGroupQuadsKHR or "
"DerivativeGroupLinearKHR execution mode for "
"GLCompute, MeshEXT or TaskEXT execution model"));
}

TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) {
Expand All @@ -4956,8 +4958,8 @@ TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) {
CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("ImplicitLod instructions require Fragment or "
"GLCompute execution model"));
HasSubstr("ImplicitLod instructions require Fragment, "
"GLCompute, MeshEXT or TaskEXT execution model"));
}

TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) {
Expand All @@ -4969,12 +4971,12 @@ TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) {
)";

const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV
OpExtension "SPV_NV_compute_shader_derivatives"
OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_KHR_compute_shader_derivatives"
)";
const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1
OpExecutionMode %main DerivativeGroupLinearNV
OpExecutionMode %main DerivativeGroupLinearKHR
)";
CompileSuccessfully(
GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
Expand All @@ -4990,8 +4992,8 @@ TEST_F(ValidateImage, ImplicitLodComputeShaderDerivativesMissingMode) {
)";

const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV
OpExtension "SPV_NV_compute_shader_derivatives"
OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_KHR_compute_shader_derivatives"
)";
const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1
Expand All @@ -5001,9 +5003,9 @@ OpExecutionMode %main LocalSize 8 8 1
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("ImplicitLod instructions require DerivativeGroupQuadsNV or "
"DerivativeGroupLinearNV execution mode for GLCompute "
"execution model"));
HasSubstr("ImplicitLod instructions require DerivativeGroupQuadsKHR or "
"DerivativeGroupLinearKHR execution mode for GLCompute, "
"MeshEXT or TaskEXT execution model"));
}

TEST_F(ValidateImage, ReadSubpassDataWrongExecutionModel) {
Expand Down Expand Up @@ -6505,8 +6507,9 @@ OpFunctionEnd
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("ImplicitLod instructions require "
"DerivativeGroupQuadsNV or DerivativeGroupLinearNV "
"execution mode for GLCompute execution model"));
"DerivativeGroupQuadsKHR or DerivativeGroupLinearKHR "
"execution mode for GLCompute, MeshEXT or TaskEXT "
"execution model"));
}

TEST_F(ValidateImage, TypeSampledImageNotBufferPost1p6) {
Expand Down

0 comments on commit 44936c4

Please sign in to comment.