Skip to content

Commit

Permalink
[CodeGen] change prototype of regalloc filter function
Browse files Browse the repository at this point in the history
change the prototype of the filter function so that we can
filter not just by RegClass. We need to implement more
complicated filter based upon some other info associated
with each register.

Patch provided by: Gang Chen ([email protected])
  • Loading branch information
cdevadas committed Jun 27, 2024
1 parent 9a9ec22 commit 29e59f4
Show file tree
Hide file tree
Showing 13 changed files with 54 additions and 48 deletions.
6 changes: 3 additions & 3 deletions llvm/include/llvm/CodeGen/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,20 +205,20 @@ namespace llvm {
/// possible. It is best suited for debug code where live ranges are short.
///
FunctionPass *createFastRegisterAllocator();
FunctionPass *createFastRegisterAllocator(RegClassFilterFunc F,
FunctionPass *createFastRegisterAllocator(RegAllocFilterFunc F,
bool ClearVirtRegs);

/// BasicRegisterAllocation Pass - This pass implements a degenerate global
/// register allocator using the basic regalloc framework.
///
FunctionPass *createBasicRegisterAllocator();
FunctionPass *createBasicRegisterAllocator(RegClassFilterFunc F);
FunctionPass *createBasicRegisterAllocator(RegAllocFilterFunc F);

/// Greedy register allocation pass - This pass implements a global register
/// allocator for optimized builds.
///
FunctionPass *createGreedyRegisterAllocator();
FunctionPass *createGreedyRegisterAllocator(RegClassFilterFunc F);
FunctionPass *createGreedyRegisterAllocator(RegAllocFilterFunc F);

/// PBQPRegisterAllocation Pass - This pass implements the Partitioned Boolean
/// Quadratic Prograaming (PBQP) based register allocator.
Expand Down
6 changes: 4 additions & 2 deletions llvm/include/llvm/CodeGen/RegAllocCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@
#ifndef LLVM_CODEGEN_REGALLOCCOMMON_H
#define LLVM_CODEGEN_REGALLOCCOMMON_H

#include "llvm/CodeGen/Register.h"
#include <functional>

namespace llvm {

class TargetRegisterClass;
class TargetRegisterInfo;
class MachineRegisterInfo;

/// Filter function for register classes during regalloc. Default register class
/// filter is nullptr, where all registers should be allocated.
typedef std::function<bool(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC)>
RegClassFilterFunc;
const MachineRegisterInfo &MRI, const Register Reg)>
RegAllocFilterFunc;
}

#endif // LLVM_CODEGEN_REGALLOCCOMMON_H
2 changes: 1 addition & 1 deletion llvm/include/llvm/CodeGen/RegAllocFast.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
namespace llvm {

struct RegAllocFastPassOptions {
RegClassFilterFunc Filter = nullptr;
RegAllocFilterFunc Filter = nullptr;
StringRef FilterName = "all";
bool ClearVRegs = true;
};
Expand Down
10 changes: 5 additions & 5 deletions llvm/include/llvm/Passes/PassBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,9 @@ class PassBuilder {
/// returns false.
Error parseAAPipeline(AAManager &AA, StringRef PipelineText);

/// Parse RegClassFilterName to get RegClassFilterFunc.
std::optional<RegClassFilterFunc>
parseRegAllocFilter(StringRef RegClassFilterName);
/// Parse RegAllocFilterName to get RegAllocFilterFunc.
std::optional<RegAllocFilterFunc>
parseRegAllocFilter(StringRef RegAllocFilterName);

/// Print pass names.
void printPassNames(raw_ostream &OS);
Expand Down Expand Up @@ -586,7 +586,7 @@ class PassBuilder {
/// needs it. E.g. AMDGPU requires regalloc passes can handle sgpr and vgpr
/// separately.
void registerRegClassFilterParsingCallback(
const std::function<RegClassFilterFunc(StringRef)> &C) {
const std::function<RegAllocFilterFunc(StringRef)> &C) {
RegClassFilterParsingCallbacks.push_back(C);
}

Expand Down Expand Up @@ -807,7 +807,7 @@ class PassBuilder {
2>
MachineFunctionPipelineParsingCallbacks;
// Callbacks to parse `filter` parameter in register allocation passes
SmallVector<std::function<RegClassFilterFunc(StringRef)>, 2>
SmallVector<std::function<RegAllocFilterFunc(StringRef)>, 2>
RegClassFilterParsingCallbacks;
};

Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/CodeGen/RegAllocBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class RegAllocBase {

private:
/// Private, callees should go through shouldAllocateRegister
const RegClassFilterFunc ShouldAllocateClass;
const RegAllocFilterFunc shouldAllocateRegisterImpl;

protected:
/// Inst which is a def of an original reg and whose defs are already all
Expand All @@ -81,7 +81,8 @@ class RegAllocBase {
/// always available for the remat of all the siblings of the original reg.
SmallPtrSet<MachineInstr *, 32> DeadRemats;

RegAllocBase(const RegClassFilterFunc F = nullptr) : ShouldAllocateClass(F) {}
RegAllocBase(const RegAllocFilterFunc F = nullptr)
: shouldAllocateRegisterImpl(F) {}

virtual ~RegAllocBase() = default;

Expand All @@ -90,9 +91,9 @@ class RegAllocBase {

/// Get whether a given register should be allocated
bool shouldAllocateRegister(Register Reg) {
if (!ShouldAllocateClass)
if (!shouldAllocateRegisterImpl)
return true;
return ShouldAllocateClass(*TRI, *MRI->getRegClass(Reg));
return shouldAllocateRegisterImpl(*TRI, *MRI, Reg);
}

// The top-level driver. The output is a VirtRegMap that us updated with
Expand Down
10 changes: 4 additions & 6 deletions llvm/lib/CodeGen/RegAllocBasic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class RABasic : public MachineFunctionPass,
void LRE_WillShrinkVirtReg(Register) override;

public:
RABasic(const RegClassFilterFunc F = nullptr);
RABasic(const RegAllocFilterFunc F = nullptr);

/// Return the pass name.
StringRef getPassName() const override { return "Basic Register Allocator"; }
Expand Down Expand Up @@ -168,10 +168,8 @@ void RABasic::LRE_WillShrinkVirtReg(Register VirtReg) {
enqueue(&LI);
}

RABasic::RABasic(RegClassFilterFunc F):
MachineFunctionPass(ID),
RegAllocBase(F) {
}
RABasic::RABasic(RegAllocFilterFunc F)
: MachineFunctionPass(ID), RegAllocBase(F) {}

void RABasic::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
Expand Down Expand Up @@ -334,6 +332,6 @@ FunctionPass* llvm::createBasicRegisterAllocator() {
return new RABasic();
}

FunctionPass* llvm::createBasicRegisterAllocator(RegClassFilterFunc F) {
FunctionPass *llvm::createBasicRegisterAllocator(RegAllocFilterFunc F) {
return new RABasic(F);
}
16 changes: 8 additions & 8 deletions llvm/lib/CodeGen/RegAllocFast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ class InstrPosIndexes {

class RegAllocFastImpl {
public:
RegAllocFastImpl(const RegClassFilterFunc F = nullptr,
RegAllocFastImpl(const RegAllocFilterFunc F = nullptr,
bool ClearVirtRegs_ = true)
: ShouldAllocateClass(F), StackSlotForVirtReg(-1),
: ShouldAllocateRegisterImpl(F), StackSlotForVirtReg(-1),
ClearVirtRegs(ClearVirtRegs_) {}

private:
Expand All @@ -188,7 +188,7 @@ class RegAllocFastImpl {
const TargetRegisterInfo *TRI = nullptr;
const TargetInstrInfo *TII = nullptr;
RegisterClassInfo RegClassInfo;
const RegClassFilterFunc ShouldAllocateClass;
const RegAllocFilterFunc ShouldAllocateRegisterImpl;

/// Basic block currently being allocated.
MachineBasicBlock *MBB = nullptr;
Expand Down Expand Up @@ -397,7 +397,7 @@ class RegAllocFast : public MachineFunctionPass {
public:
static char ID;

RegAllocFast(const RegClassFilterFunc F = nullptr, bool ClearVirtRegs_ = true)
RegAllocFast(const RegAllocFilterFunc F = nullptr, bool ClearVirtRegs_ = true)
: MachineFunctionPass(ID), Impl(F, ClearVirtRegs_) {}

bool runOnMachineFunction(MachineFunction &MF) override {
Expand Down Expand Up @@ -440,10 +440,10 @@ INITIALIZE_PASS(RegAllocFast, "regallocfast", "Fast Register Allocator", false,

bool RegAllocFastImpl::shouldAllocateRegister(const Register Reg) const {
assert(Reg.isVirtual());
if (!ShouldAllocateClass)
if (!ShouldAllocateRegisterImpl)
return true;
const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
return ShouldAllocateClass(*TRI, RC);

return ShouldAllocateRegisterImpl(*TRI, *MRI, Reg);
}

void RegAllocFastImpl::setPhysRegState(MCPhysReg PhysReg, unsigned NewState) {
Expand Down Expand Up @@ -1841,7 +1841,7 @@ void RegAllocFastPass::printPipeline(

FunctionPass *llvm::createFastRegisterAllocator() { return new RegAllocFast(); }

FunctionPass *llvm::createFastRegisterAllocator(RegClassFilterFunc Ftor,
FunctionPass *llvm::createFastRegisterAllocator(RegAllocFilterFunc Ftor,
bool ClearVirtRegs) {
return new RegAllocFast(Ftor, ClearVirtRegs);
}
10 changes: 4 additions & 6 deletions llvm/lib/CodeGen/RegAllocGreedy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,12 @@ FunctionPass* llvm::createGreedyRegisterAllocator() {
return new RAGreedy();
}

FunctionPass *llvm::createGreedyRegisterAllocator(RegClassFilterFunc Ftor) {
FunctionPass *llvm::createGreedyRegisterAllocator(RegAllocFilterFunc Ftor) {
return new RAGreedy(Ftor);
}

RAGreedy::RAGreedy(RegClassFilterFunc F):
MachineFunctionPass(ID),
RegAllocBase(F) {
}
RAGreedy::RAGreedy(RegAllocFilterFunc F)
: MachineFunctionPass(ID), RegAllocBase(F) {}

void RAGreedy::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
Expand Down Expand Up @@ -2306,7 +2304,7 @@ void RAGreedy::tryHintRecoloring(const LiveInterval &VirtReg) {
if (Reg.isPhysical())
continue;

// This may be a skipped class
// This may be a skipped register.
if (!VRM->hasPhys(Reg)) {
assert(!shouldAllocateRegister(Reg) &&
"We have an unallocated variable which should have been handled");
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/RegAllocGreedy.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class LLVM_LIBRARY_VISIBILITY RAGreedy : public MachineFunctionPass,
bool ReverseLocalAssignment = false;

public:
RAGreedy(const RegClassFilterFunc F = nullptr);
RAGreedy(const RegAllocFilterFunc F = nullptr);

/// Return the pass name.
StringRef getPassName() const override { return "Greedy Register Allocator"; }
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1159,7 +1159,7 @@ parseRegAllocFastPassOptions(PassBuilder &PB, StringRef Params) {
std::tie(ParamName, Params) = Params.split(';');

if (ParamName.consume_front("filter=")) {
std::optional<RegClassFilterFunc> Filter =
std::optional<RegAllocFilterFunc> Filter =
PB.parseRegAllocFilter(ParamName);
if (!Filter) {
return make_error<StringError>(
Expand Down Expand Up @@ -2169,7 +2169,7 @@ Error PassBuilder::parseAAPipeline(AAManager &AA, StringRef PipelineText) {
return Error::success();
}

std::optional<RegClassFilterFunc>
std::optional<RegAllocFilterFunc>
PassBuilder::parseRegAllocFilter(StringRef FilterName) {
if (FilterName == "all")
return nullptr;
Expand Down
15 changes: 9 additions & 6 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,19 @@ class VGPRRegisterRegAlloc : public RegisterRegAllocBase<VGPRRegisterRegAlloc> {
};

static bool onlyAllocateSGPRs(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC) {
return static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(&RC);
const MachineRegisterInfo &MRI,
const Register Reg) {
const TargetRegisterClass *RC = MRI.getRegClass(Reg);
return static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(RC);
}

static bool onlyAllocateVGPRs(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC) {
return !static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(&RC);
const MachineRegisterInfo &MRI,
const Register Reg) {
const TargetRegisterClass *RC = MRI.getRegClass(Reg);
return !static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(RC);
}


/// -{sgpr|vgpr}-regalloc=... command line option.
static FunctionPass *useDefaultRegisterAllocator() { return nullptr; }

Expand Down Expand Up @@ -741,7 +744,7 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
});

PB.registerRegClassFilterParsingCallback(
[](StringRef FilterName) -> RegClassFilterFunc {
[](StringRef FilterName) -> RegAllocFilterFunc {
if (FilterName == "sgpr")
return onlyAllocateSGPRs;
if (FilterName == "vgpr")
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,10 @@ class RVVRegisterRegAlloc : public RegisterRegAllocBase<RVVRegisterRegAlloc> {
};

static bool onlyAllocateRVVReg(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC) {
return RISCVRegisterInfo::isRVVRegClass(&RC);
const MachineRegisterInfo &MRI,
const Register Reg) {
const TargetRegisterClass *RC = MRI.getRegClass(Reg);
return RISCVRegisterInfo::isRVVRegClass(RC);
}

static FunctionPass *useDefaultRegisterAllocator() { return nullptr; }
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/X86/X86TargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,10 @@ std::unique_ptr<CSEConfigBase> X86PassConfig::getCSEConfig() const {
}

static bool onlyAllocateTileRegisters(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC) {
return static_cast<const X86RegisterInfo &>(TRI).isTileRegisterClass(&RC);
const MachineRegisterInfo &MRI,
const Register Reg) {
const TargetRegisterClass *RC = MRI.getRegClass(Reg);
return static_cast<const X86RegisterInfo &>(TRI).isTileRegisterClass(RC);
}

bool X86PassConfig::addRegAssignAndRewriteOptimized() {
Expand Down

0 comments on commit 29e59f4

Please sign in to comment.