Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NFC][DirectX] Infrastructure to collect shader flags for each function #112967

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions llvm/lib/Target/DirectX/DXContainerGlobals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,18 @@ bool DXContainerGlobals::runOnModule(Module &M) {
}

GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
const uint64_t FeatureFlags =
static_cast<uint64_t>(getAnalysis<ShaderFlagsAnalysisWrapper>()
.getShaderFlags()
.getFeatureFlags());
const DXILModuleShaderFlagsInfo &MSFI =
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
// TODO: Feature flags mask is obtained as a collection of feature flags
// of the shader flags of all functions in the module. Need to verify
// and modify the computation of feature flags to be used.
uint64_t ConsolidatedFeatureFlags = 0;
for (const auto &FuncFlags : MSFI.FuncShaderFlagsMap) {
ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
}

Constant *FeatureFlagsConstant =
ConstantInt::get(M.getContext(), APInt(64, FeatureFlags));
ConstantInt::get(M.getContext(), APInt(64, ConsolidatedFeatureFlags));
return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
}

Expand Down
44 changes: 31 additions & 13 deletions llvm/lib/Target/DirectX/DXILShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,37 @@
using namespace llvm;
using namespace llvm::dxil;

static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
static void updateFlags(DXILModuleShaderFlagsInfo &MSFI, const Instruction &I) {
ComputedShaderFlags &FSF = MSFI.FuncShaderFlagsMap[I.getFunction()];
Type *Ty = I.getType();
if (Ty->isDoubleTy()) {
Flags.Doubles = true;
FSF.Doubles = true;
switch (I.getOpcode()) {
case Instruction::FDiv:
case Instruction::UIToFP:
case Instruction::SIToFP:
case Instruction::FPToUI:
case Instruction::FPToSI:
Flags.DX11_1_DoubleExtensions = true;
FSF.DX11_1_DoubleExtensions = true;
break;
}
}
}

ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
ComputedShaderFlags Flags;
for (const auto &F : M)
static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
DXILModuleShaderFlagsInfo MSFI;
for (const auto &F : M) {
if (F.isDeclaration())
continue;
if (!MSFI.FuncShaderFlagsMap.contains(&F)) {
ComputedShaderFlags CSF{};
MSFI.FuncShaderFlagsMap[&F] = CSF;
}
for (const auto &BB : F)
for (const auto &I : BB)
updateFlags(Flags, I);
return Flags;
updateFlags(MSFI, I);
}
return MSFI;
}

void ComputedShaderFlags::print(raw_ostream &OS) const {
Expand All @@ -65,15 +73,25 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {

AnalysisKey ShaderFlagsAnalysis::Key;

ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M,
ModuleAnalysisManager &AM) {
return ComputedShaderFlags::computeFlags(M);
DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
ModuleAnalysisManager &AM) {
return computeFlags(M);
}

bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
MSFI = computeFlags(M);
return false;
}

PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
ModuleAnalysisManager &AM) {
ComputedShaderFlags Flags = AM.getResult<ShaderFlagsAnalysis>(M);
Flags.print(OS);
DXILModuleShaderFlagsInfo Flags = AM.getResult<ShaderFlagsAnalysis>(M);
OS << "; Shader Flags mask for Module:\n";
Flags.ModuleFlags.print(OS);
for (auto SF : Flags.FuncShaderFlagsMap) {
OS << "; Shader Flags mash for Function: " << SF.first->getName() << "\n";
SF.second.print(OS);
}
return PreservedAnalyses::all();
}

Expand Down
26 changes: 17 additions & 9 deletions llvm/lib/Target/DirectX/DXILShaderFlags.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
#define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/Compiler.h"
Expand Down Expand Up @@ -60,21 +62,30 @@ struct ComputedShaderFlags {
return FeatureFlags;
}

static ComputedShaderFlags computeFlags(Module &M);
void print(raw_ostream &OS = dbgs()) const;
LLVM_DUMP_METHOD void dump() const { print(); }
};

using FunctionShaderFlagsMap =
SmallDenseMap<Function const *, ComputedShaderFlags>;
struct DXILModuleShaderFlagsInfo {
// Shader Flag mask representing module-level properties
ComputedShaderFlags ModuleFlags;
// Map representing shader flag mask representing properties of each of the
// functions in the module
FunctionShaderFlagsMap FuncShaderFlagsMap;
};

class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
friend AnalysisInfoMixin<ShaderFlagsAnalysis>;
static AnalysisKey Key;

public:
ShaderFlagsAnalysis() = default;

using Result = ComputedShaderFlags;
using Result = DXILModuleShaderFlagsInfo;

ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM);
DXILModuleShaderFlagsInfo run(Module &M, ModuleAnalysisManager &AM);
};

/// Printer pass for ShaderFlagsAnalysis results.
Expand All @@ -92,19 +103,16 @@ class ShaderFlagsAnalysisPrinter
/// This is required because the passes that will depend on this are codegen
/// passes which run through the legacy pass manager.
class ShaderFlagsAnalysisWrapper : public ModulePass {
ComputedShaderFlags Flags;
DXILModuleShaderFlagsInfo MSFI;

public:
static char ID;

ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}

const ComputedShaderFlags &getShaderFlags() { return Flags; }
const DXILModuleShaderFlagsInfo &getShaderFlags() { return MSFI; }

bool runOnModule(Module &M) override {
Flags = ComputedShaderFlags::computeFlags(M);
return false;
}
bool runOnModule(Module &M) override;

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
Expand Down
46 changes: 28 additions & 18 deletions llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
MDTuple *Properties = nullptr;
if (ShaderFlags != 0) {
SmallVector<Metadata *> MDVals;
// FIXME: ShaderFlagsAnalysis pass needs to collect and provide
// ShaderFlags for each entry function. Currently, ShaderFlags value
// provided by ShaderFlagsAnalysis pass is created by walking *all* the
// function instructions of the module. Is it is correct to use this value
// for metadata of the empty library entry?
MDVals.append(
getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
Properties = MDNode::get(Ctx, MDVals);
Expand All @@ -302,7 +297,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,

static void translateMetadata(Module &M, const DXILResourceMap &DRM,
const Resources &MDResources,
const ComputedShaderFlags &ShaderFlags,
const DXILModuleShaderFlagsInfo &ShaderFlags,
const ModuleMetadataInfo &MMDI) {
LLVMContext &Ctx = M.getContext();
IRBuilder<> IRB(Ctx);
Expand All @@ -318,22 +313,37 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
// See https://github.com/llvm/llvm-project/issues/57928
MDTuple *Signatures = nullptr;

if (MMDI.ShaderProfile == Triple::EnvironmentType::Library)
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
// Create a consolidated shader flag mask of all functions in the library
// to be used as shader flags mask value associated with top-level library
// entry metadata.
uint64_t ConsolidatedMask = ShaderFlags.ModuleFlags;
for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsMap) {
ConsolidatedMask |= FunFlags.second;
}
EntryFnMDNodes.emplace_back(
emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags));
else if (MMDI.EntryPropertyVec.size() > 1) {
emitTopLevelLibraryNode(M, ResourceMD, ConsolidatedMask));
} else if (MMDI.EntryPropertyVec.size() > 1) {
M.getContext().diagnose(DiagnosticInfoTranslateMD(
M, "Non-library shader: One and only one entry expected"));
}

for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
// FIXME: ShaderFlagsAnalysis pass needs to collect and provide
// ShaderFlags for each entry function. For now, assume shader flags value
// of entry functions being compiled for lib_* shader profile viz.,
// EntryPro.Entry is 0.
uint64_t EntryShaderFlags =
(MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0
: ShaderFlags;
auto FSFIt = ShaderFlags.FuncShaderFlagsMap.find(EntryProp.Entry);
if (FSFIt == ShaderFlags.FuncShaderFlagsMap.end()) {
M.getContext().diagnose(DiagnosticInfoTranslateMD(
M, "Shader Flags of Function '" + Twine(EntryProp.Entry->getName()) +
"' not found"));
}
// If ShaderProfile is Library, mask is already consolidated in the
// top-level library node. Hence it is not emitted.
uint64_t EntryShaderFlags = 0;
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
// TODO: Create a consolidated shader flag mask of all the entry
// functions and its callees. The following is correct only if
// (*FSIt).first has no call instructions.
EntryShaderFlags = (*FSFIt).second | ShaderFlags.ModuleFlags;
}
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
M.getContext().diagnose(DiagnosticInfoTranslateMD(
Expand Down Expand Up @@ -361,7 +371,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
ModuleAnalysisManager &MAM) {
const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
const ComputedShaderFlags &ShaderFlags =
const DXILModuleShaderFlagsInfo &ShaderFlags =
MAM.getResult<ShaderFlagsAnalysis>(M);
const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);

Expand Down Expand Up @@ -393,7 +403,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
const dxil::Resources &MDResources =
getAnalysis<DXILResourceMDWrapper>().getDXILResource();
const ComputedShaderFlags &ShaderFlags =
const DXILModuleShaderFlagsInfo &ShaderFlags =
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
dxil::ModuleMetadataInfo MMDI =
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
Expand Down
Loading