Skip to content

Commit

Permalink
Merge pull request #328 from hvdijk/manual-type-legalization-fma
Browse files Browse the repository at this point in the history
Include FMA in manual type legalization.
  • Loading branch information
hvdijk authored Feb 2, 2024
2 parents b1d1e75 + 84d055b commit 9a570d4
Showing 1 changed file with 84 additions and 41 deletions.
125 changes: 84 additions & 41 deletions modules/compiler/utils/source/manual_type_legalization_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/InstrTypes.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/IntrinsicInst.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Type.h>
Expand All @@ -31,25 +32,39 @@ using namespace llvm;

PreservedAnalyses compiler::utils::ManualTypeLegalizationPass::run(
Function &F, FunctionAnalysisManager &FAM) {
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);

auto *HalfT = Type::getHalfTy(F.getContext());
auto *FloatT = Type::getFloatTy(F.getContext());

// Targets where half is a legal type do not need this pass. Targets where
// half is promoted using "soft promotion" rules also do not need this pass.
// We cannot reliably determine which targets these are, but that is okay, on
// targets where this pass is not needed it does no harm, it merely wastes
// time.
auto *DoubleT = Type::getDoubleTy(F.getContext());

// Targets where half is a legal type, and targets where half is promoted
// using "soft promotion" rules, are assumed to implement basic operators
// correctly. We cannot reliably determine which targets use "soft promotion"
// rules so we hardcode the list here.
//
// FMA is promoted incorrectly on all targets without hardware support, even
// when using "soft promotion" rules; only targets that have native support
// implement it correctly at the moment.
//
// Both for operators and FMA, whether the target implements the operation
// correctly may depend on the target feature string. We ignore that here for
// simplicity.
const llvm::Triple TT(F.getParent()->getTargetTriple());
if (TTI.isTypeLegal(HalfT) || TT.isX86() || TT.isRISCV()) {

auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
const bool HaveCorrectHalfOps =
TTI.isTypeLegal(HalfT) || TT.isX86() || TT.isRISCV();
const bool HaveCorrectHalfFMA = TT.isRISCV();
if (HaveCorrectHalfOps && HaveCorrectHalfFMA) {
return PreservedAnalyses::all();
}

DenseMap<Value *, Value *> FPExtVals;
IRBuilder<> B(F.getContext());

auto CreateFPExt = [&](Value *V, Type *ExtTy) {
auto CreateFPExt = [&](Value *V, Type *Ty, Type *ExtTy) {
(void)Ty;
assert(V->getType() == Ty &&
"Expected matching types for floating point operation");
auto *&FPExt = FPExtVals[V];
if (!FPExt) {
if (auto *I = dyn_cast<Instruction>(V)) {
Expand Down Expand Up @@ -78,43 +93,71 @@ PreservedAnalyses compiler::utils::ManualTypeLegalizationPass::run(

for (auto &BB : F) {
for (auto &I : make_early_inc_range(BB)) {
auto *BO = dyn_cast<BinaryOperator>(&I);
if (!BO) continue;

auto *T = BO->getType();
auto *T = I.getType();
auto *VecT = dyn_cast<VectorType>(T);
auto *ElT = VecT ? VecT->getElementType() : T;

if (ElT != HalfT) continue;

auto *LHS = BO->getOperand(0);
auto *RHS = BO->getOperand(1);
assert(LHS->getType() == T &&
"Expected matching types for floating point operation");
assert(RHS->getType() == T &&
"Expected matching types for floating point operation");

auto *ExtElT = FloatT;
auto *ExtT =
VecT ? VectorType::get(ExtElT, VecT->getElementCount()) : ExtElT;

auto *LHSExt = CreateFPExt(LHS, ExtT);
auto *RHSExt = CreateFPExt(RHS, ExtT);

B.SetInsertPoint(BO);

B.setFastMathFlags(BO->getFastMathFlags());
auto *OpExt = B.CreateBinOp(BO->getOpcode(), LHSExt, RHSExt,
BO->getName() + ".fpext");
B.clearFastMathFlags();

auto *Trunc = B.CreateFPTrunc(OpExt, T);
Trunc->takeName(BO);

BO->replaceAllUsesWith(Trunc);
BO->eraseFromParent();
if (!HaveCorrectHalfOps) {
if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
Type *const ExtElT = FloatT;
Type *const ExtT =
VecT ? VectorType::get(ExtElT, VecT->getElementCount()) : ExtElT;
Value *const PromotedOperands[] = {
CreateFPExt(BO->getOperand(0), T, ExtT),
CreateFPExt(BO->getOperand(1), T, ExtT),
};
B.SetInsertPoint(BO);
B.setFastMathFlags(BO->getFastMathFlags());
auto *const PromotedOperation =
B.CreateBinOp(BO->getOpcode(), PromotedOperands[0],
PromotedOperands[1], BO->getName() + ".fpext");
B.clearFastMathFlags();

auto *const Trunc = B.CreateFPTrunc(PromotedOperation, T);
Trunc->takeName(BO);

BO->replaceAllUsesWith(Trunc);
BO->eraseFromParent();

Changed = true;
continue;
}
}

Changed = true;
if (!HaveCorrectHalfFMA) {
if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
if (II->getIntrinsicID() == Intrinsic::fma) {
Type *const ExtElT = DoubleT;
Type *const ExtT =
VecT ? VectorType::get(ExtElT, VecT->getElementCount())
: ExtElT;
Value *const PromotedArguments[] = {
CreateFPExt(II->getArgOperand(0), T, ExtT),
CreateFPExt(II->getArgOperand(1), T, ExtT),
CreateFPExt(II->getArgOperand(2), T, ExtT),
};
B.SetInsertPoint(II);
// Because the arguments are promoted halfs, the multiplication in
// type double is exact and the result is the same even if multiply
// and add are kept as separate operations, so use FMulAdd rather
// than FMA.
auto *const PromotedOperation =
B.CreateIntrinsic(ExtT, Intrinsic::fmuladd, PromotedArguments,
II, II->getName() + ".fpext");

auto *const Trunc = B.CreateFPTrunc(PromotedOperation, T);
Trunc->takeName(II);

II->replaceAllUsesWith(Trunc);
II->eraseFromParent();

Changed = true;
continue;
}
}
}
}
}

Expand Down

0 comments on commit 9a570d4

Please sign in to comment.