Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NFC][DirectX] Infrastructure to collect shader flags for each function #112967

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

bharadwajy
Copy link
Contributor

Currently, ShaderFlagsAnalysis pass represents various module-level properties as well as function-level properties of a DXIL Module using a single mask. However, separate flags to represent module-level properties and function-level properties are needed for accurate computation of shader flags mask, such as for entry function metadata creation.

This change introduces a structure that allows separate representation of

(a) shader flag mask to represent module properties
(b) a map of function to shader flag mask that represent function properties

instead of a single shader flag mask that represents module properties and properties of all function. The result type of ShaderFlagsAnalysis pass is changed to newly-defined structure type instead of a single shader flags mask.

This seperation allows accurate computation of shader flags of an entry function for use during its metadata generation (DXILTranslateMetadata pass) and its feature flags in DX container globals construction (DXContainerGlobals pass) based on the shader flags mask of functions called in entry function. However, note that the change to implement such callee-based shader flags mask computation is planned in a follow-on PR. Consequently, this PR changes shader flag mask computation in DXILTranslateMetadata and DXContainerGlobals passes to simply be a union of module flags and shader flags of all functions, thereby retaining the existing effect of using a single shader flag mask.

Currently, ShaderFlagsAnalysis pass represents various module-level
properties as well as function-level properties of a DXIL Module
using a single mask. However, separate flags to represent module-level
properties and function-level properties are needed for accurate computation
of shader flags mask, such as for entry function metadata creation.

This change introduces a structure that allows separate representation of

(a) shader flag mask to represent module properties
(b) a map of function to shader flag mask that represent function properties

instead of a single shader flag mask that represents module properties
and properties of all function. The result type of ShaderFlagsAnalysis
pass is changed to newly-defined structure type instead of a single shader
flags mask.

This seperation allows accurate computation of shader flags of an entry
function for use during its metadata generation (DXILTranslateMetadata pass)
and its feature flags in DX container globals construction (DXContainerGlobals
pass) based on the shader flags mask of functions called in entry function.
However, note that the change to implement such callee-based shader flags mask
computation is planned in a follow-on PR. Consequently, this PR changes shader
flag mask computation in DXILTranslateMetadata and DXContainerGlobals passes
to simply be a union of module flags and shader flags of all functions, thereby
retaining the existing effect of using a single shader flag mask.
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 18, 2024

@llvm/pr-subscribers-backend-directx

Author: S. Bharadwaj Yadavalli (bharadwajy)

Changes

Currently, ShaderFlagsAnalysis pass represents various module-level properties as well as function-level properties of a DXIL Module using a single mask. However, separate flags to represent module-level properties and function-level properties are needed for accurate computation of shader flags mask, such as for entry function metadata creation.

This change introduces a structure that allows separate representation of

(a) shader flag mask to represent module properties
(b) a map of function to shader flag mask that represent function properties

instead of a single shader flag mask that represents module properties and properties of all function. The result type of ShaderFlagsAnalysis pass is changed to newly-defined structure type instead of a single shader flags mask.

This seperation allows accurate computation of shader flags of an entry function for use during its metadata generation (DXILTranslateMetadata pass) and its feature flags in DX container globals construction (DXContainerGlobals pass) based on the shader flags mask of functions called in entry function. However, note that the change to implement such callee-based shader flags mask computation is planned in a follow-on PR. Consequently, this PR changes shader flag mask computation in DXILTranslateMetadata and DXContainerGlobals passes to simply be a union of module flags and shader flags of all functions, thereby retaining the existing effect of using a single shader flag mask.


Full diff: https://github.com/llvm/llvm-project/pull/112967.diff

4 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXContainerGlobals.cpp (+10-5)
  • (modified) llvm/lib/Target/DirectX/DXILShaderFlags.cpp (+32-14)
  • (modified) llvm/lib/Target/DirectX/DXILShaderFlags.h (+17-9)
  • (modified) llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (+28-18)
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 2c11373504e8c7..c7202cc04c26dc 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -78,13 +78,18 @@ bool DXContainerGlobals::runOnModule(Module &M) {
 }
 
 GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
-  const uint64_t FeatureFlags =
-      static_cast<uint64_t>(getAnalysis<ShaderFlagsAnalysisWrapper>()
-                                .getShaderFlags()
-                                .getFeatureFlags());
+  const DXILModuleShaderFlagsInfo &MSFI =
+      getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
+  // TODO: Feature flags mask is obtained as a collection of feature flags
+  // of the shader flags of all functions in the module. Need to verify
+  // and modify the computation of feature flags to be used.
+  uint64_t ConsolidatedFeatureFlags = 0;
+  for (const auto &FuncFlags : MSFI.FuncShaderFlagsMap) {
+    ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
+  }
 
   Constant *FeatureFlagsConstant =
-      ConstantInt::get(M.getContext(), APInt(64, FeatureFlags));
+      ConstantInt::get(M.getContext(), APInt(64, ConsolidatedFeatureFlags));
   return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 9fa137b4c025e1..8c590862008862 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -20,33 +20,41 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
-static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
+static void updateFlags(DXILModuleShaderFlagsInfo &MSFI, const Instruction &I) {
+  ComputedShaderFlags &FSF = MSFI.FuncShaderFlagsMap[I.getFunction()];
   Type *Ty = I.getType();
   if (Ty->isDoubleTy()) {
-    Flags.Doubles = true;
+    FSF.Doubles = true;
     switch (I.getOpcode()) {
     case Instruction::FDiv:
     case Instruction::UIToFP:
     case Instruction::SIToFP:
     case Instruction::FPToUI:
     case Instruction::FPToSI:
-      Flags.DX11_1_DoubleExtensions = true;
+      FSF.DX11_1_DoubleExtensions = true;
       break;
     }
   }
 }
 
-ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
-  ComputedShaderFlags Flags;
-  for (const auto &F : M)
+static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
+  DXILModuleShaderFlagsInfo MSFI;
+  for (const auto &F : M) {
+    if (F.isDeclaration())
+      continue;
+    if (!MSFI.FuncShaderFlagsMap.contains(&F)) {
+      ComputedShaderFlags CSF{};
+      MSFI.FuncShaderFlagsMap[&F] = CSF;
+    }
     for (const auto &BB : F)
       for (const auto &I : BB)
-        updateFlags(Flags, I);
-  return Flags;
+        updateFlags(MSFI, I);
+  }
+  return MSFI;
 }
 
 void ComputedShaderFlags::print(raw_ostream &OS) const {
-  uint64_t FlagVal = (uint64_t) * this;
+  uint64_t FlagVal = (uint64_t)*this;
   OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
   if (FlagVal == 0)
     return;
@@ -65,15 +73,25 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
 
 AnalysisKey ShaderFlagsAnalysis::Key;
 
-ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M,
-                                             ModuleAnalysisManager &AM) {
-  return ComputedShaderFlags::computeFlags(M);
+DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
+                                                   ModuleAnalysisManager &AM) {
+  return computeFlags(M);
+}
+
+bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
+  MSFI = computeFlags(M);
+  return false;
 }
 
 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
                                                   ModuleAnalysisManager &AM) {
-  ComputedShaderFlags Flags = AM.getResult<ShaderFlagsAnalysis>(M);
-  Flags.print(OS);
+  DXILModuleShaderFlagsInfo Flags = AM.getResult<ShaderFlagsAnalysis>(M);
+  OS << "; Shader Flags mask for Module:\n";
+  Flags.ModuleFlags.print(OS);
+  for (auto SF : Flags.FuncShaderFlagsMap) {
+    OS << "; Shader Flags mash for Function: " << SF.first->getName() << "\n";
+    SF.second.print(OS);
+  }
   return PreservedAnalyses::all();
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 1df7d27de13d3c..6f81ff74384d0c 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -14,6 +14,8 @@
 #ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
 #define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
 
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Function.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Compiler.h"
@@ -60,11 +62,20 @@ struct ComputedShaderFlags {
     return FeatureFlags;
   }
 
-  static ComputedShaderFlags computeFlags(Module &M);
   void print(raw_ostream &OS = dbgs()) const;
   LLVM_DUMP_METHOD void dump() const { print(); }
 };
 
+using FunctionShaderFlagsMap =
+    SmallDenseMap<Function const *, ComputedShaderFlags>;
+struct DXILModuleShaderFlagsInfo {
+  // Shader Flag mask representing module-level properties
+  ComputedShaderFlags ModuleFlags;
+  // Map representing shader flag mask representing properties of each of the
+  // functions in the module
+  FunctionShaderFlagsMap FuncShaderFlagsMap;
+};
+
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
   friend AnalysisInfoMixin<ShaderFlagsAnalysis>;
   static AnalysisKey Key;
@@ -72,9 +83,9 @@ class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
 public:
   ShaderFlagsAnalysis() = default;
 
-  using Result = ComputedShaderFlags;
+  using Result = DXILModuleShaderFlagsInfo;
 
-  ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM);
+  DXILModuleShaderFlagsInfo run(Module &M, ModuleAnalysisManager &AM);
 };
 
 /// Printer pass for ShaderFlagsAnalysis results.
@@ -92,19 +103,16 @@ class ShaderFlagsAnalysisPrinter
 /// This is required because the passes that will depend on this are codegen
 /// passes which run through the legacy pass manager.
 class ShaderFlagsAnalysisWrapper : public ModulePass {
-  ComputedShaderFlags Flags;
+  DXILModuleShaderFlagsInfo MSFI;
 
 public:
   static char ID;
 
   ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}
 
-  const ComputedShaderFlags &getShaderFlags() { return Flags; }
+  const DXILModuleShaderFlagsInfo &getShaderFlags() { return MSFI; }
 
-  bool runOnModule(Module &M) override {
-    Flags = ComputedShaderFlags::computeFlags(M);
-    return false;
-  }
+  bool runOnModule(Module &M) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index be370e10df6943..2da4fe83a066c2 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -286,11 +286,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
   MDTuple *Properties = nullptr;
   if (ShaderFlags != 0) {
     SmallVector<Metadata *> MDVals;
-    // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
-    // ShaderFlags for each entry function. Currently, ShaderFlags value
-    // provided by ShaderFlagsAnalysis pass is created by walking *all* the
-    // function instructions of the module. Is it is correct to use this value
-    // for metadata of the empty library entry?
     MDVals.append(
         getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
     Properties = MDNode::get(Ctx, MDVals);
@@ -302,7 +297,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
 
 static void translateMetadata(Module &M, const DXILResourceMap &DRM,
                               const Resources &MDResources,
-                              const ComputedShaderFlags &ShaderFlags,
+                              const DXILModuleShaderFlagsInfo &ShaderFlags,
                               const ModuleMetadataInfo &MMDI) {
   LLVMContext &Ctx = M.getContext();
   IRBuilder<> IRB(Ctx);
@@ -318,22 +313,37 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
   // See https://github.com/llvm/llvm-project/issues/57928
   MDTuple *Signatures = nullptr;
 
-  if (MMDI.ShaderProfile == Triple::EnvironmentType::Library)
+  if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
+    // Create a consolidated shader flag mask of all functions in the library
+    // to be used as shader flags mask value associated with top-level library
+    // entry metadata.
+    uint64_t ConsolidatedMask = ShaderFlags.ModuleFlags;
+    for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsMap) {
+      ConsolidatedMask |= FunFlags.second;
+    }
     EntryFnMDNodes.emplace_back(
-        emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags));
-  else if (MMDI.EntryPropertyVec.size() > 1) {
+        emitTopLevelLibraryNode(M, ResourceMD, ConsolidatedMask));
+  } else if (MMDI.EntryPropertyVec.size() > 1) {
     M.getContext().diagnose(DiagnosticInfoTranslateMD(
         M, "Non-library shader: One and only one entry expected"));
   }
 
   for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
-    // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
-    // ShaderFlags for each entry function. For now, assume shader flags value
-    // of entry functions being compiled for lib_* shader profile viz.,
-    // EntryPro.Entry is 0.
-    uint64_t EntryShaderFlags =
-        (MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0
-                                                                 : ShaderFlags;
+    auto FSFIt = ShaderFlags.FuncShaderFlagsMap.find(EntryProp.Entry);
+    if (FSFIt == ShaderFlags.FuncShaderFlagsMap.end()) {
+      M.getContext().diagnose(DiagnosticInfoTranslateMD(
+          M, "Shader Flags of Function '" + Twine(EntryProp.Entry->getName()) +
+                 "' not found"));
+    }
+    // If ShaderProfile is Library, mask is already consolidated in the
+    // top-level library node. Hence it is not emitted.
+    uint64_t EntryShaderFlags = 0;
+    if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
+      // TODO: Create a consolidated shader flag mask of all the entry
+      // functions and its callees. The following is correct only if
+      // (*FSIt).first has no call instructions.
+      EntryShaderFlags = (*FSFIt).second | ShaderFlags.ModuleFlags;
+    }
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
         M.getContext().diagnose(DiagnosticInfoTranslateMD(
@@ -361,7 +371,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
   const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
   const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
-  const ComputedShaderFlags &ShaderFlags =
+  const DXILModuleShaderFlagsInfo &ShaderFlags =
       MAM.getResult<ShaderFlagsAnalysis>(M);
   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
 
@@ -393,7 +403,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILResourceWrapperPass>().getResourceMap();
     const dxil::Resources &MDResources =
         getAnalysis<DXILResourceMDWrapper>().getDXILResource();
-    const ComputedShaderFlags &ShaderFlags =
+    const DXILModuleShaderFlagsInfo &ShaderFlags =
         getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
     dxil::ModuleMetadataInfo MMDI =
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();

Copy link

github-actions bot commented Oct 18, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

2 participants