Skip to content

Commit

Permalink
Handle indirect function malloc (rust-lang#391)
Browse files Browse the repository at this point in the history
* Handle indirect function malloc

* Add sprintf
  • Loading branch information
wsmoses authored Dec 14, 2021
1 parent f7c9fe5 commit cb44221
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 25 deletions.
4 changes: 3 additions & 1 deletion enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,13 @@ const std::map<std::string, size_t> MPIInactiveCommAllocators = {
};

const std::set<std::string> KnownInactiveFunctions = {
"abort",
"__assert_fail",
"__cxa_guard_acquire",
"__cxa_guard_release",
"__cxa_guard_abort",
"snprintf",
"sprintf",
"printf",
"putchar",
"fprintf",
Expand Down Expand Up @@ -777,7 +780,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
// of the global
auto res = TR.query(GI).Data0();
auto dt = res[{-1}];
dt |= res[{0}];
if (dt.isIntegral()) {
if (EnzymePrintActivity)
llvm::errs() << " VALUE const as global int pointer " << *Val
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2294,6 +2294,11 @@ class AdjointGenerator
}
}
}
EmitWarning("CannotDeduceType", MTI.getDebugLoc(), gutils->oldFunc,
MTI.getParent(), &MTI, "failed to deduce type of copy ",
MTI);
vd = TypeTree(BaseType::Pointer).Only(0);
goto known;
}
EmitFailure("CannotDeduceType", MTI.getDebugLoc(), &MTI,
"failed to deduce type of copy ", MTI);
Expand Down
11 changes: 11 additions & 0 deletions enzyme/Enzyme/DifferentialUseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,17 @@ static inline bool is_value_needed_in_reverse(
return seen[idx] = true;
}
}
#if LLVM_VERSION_MAJOR >= 11
const Value *F = CI->getCalledOperand();
#else
const Value *F = CI->getCalledValue();
#endif
if (F == inst) {
if (!gutils->isConstantInstruction(const_cast<Instruction *>(user)) ||
!gutils->isConstantValue(const_cast<Value *>((Value *)user))) {
return seen[idx] = true;
}
}
}

if (isa<ReturnInst>(user)) {
Expand Down
10 changes: 8 additions & 2 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,14 @@ class Enzyme : public ModulePass {
}
}
}
if (!res->getType()->canLosslesslyBitCastTo(PTy)) {
if (res->getType()->canLosslesslyBitCastTo(PTy)) {
res = Builder.CreateBitCast(res, PTy);
}
if (res->getType() != PTy && res->getType()->isIntegerTy() &&
PTy->isIntegerTy(1)) {
res = Builder.CreateTrunc(res, PTy);
}
if (res->getType() != PTy) {
auto loc = CI->getDebugLoc();
if (auto arg = dyn_cast<Instruction>(res)) {
loc = arg->getDebugLoc();
Expand All @@ -581,7 +588,6 @@ class Enzyme : public ModulePass {
" - to arg ", truei, " ", *PTy);
return false;
}
res = Builder.CreateBitCast(res, PTy);
}
#if LLVM_VERSION_MAJOR >= 9
if (CI->isByValArgument(i)) {
Expand Down
34 changes: 25 additions & 9 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ llvm::cl::opt<bool>
EnzymePrint("enzyme-print", cl::init(false), cl::Hidden,
cl::desc("Print before and after fns for autodiff"));

llvm::cl::opt<bool>
EnzymePrintUnnecessary("enzyme-print-unnecessary", cl::init(false),
cl::Hidden,
cl::desc("Print unnecessary values in function"));

cl::opt<bool> looseTypeAnalysis("enzyme-loose-types", cl::init(false),
cl::Hidden,
cl::desc("Allow looser use of types"));
Expand Down Expand Up @@ -910,16 +915,27 @@ void calculateUnusedValuesInFunction(
}
return UseReq::Recur;
});
#if 0
llvm::errs() << "unnecessaryValues of " << func.getName() << ": mode=" << to_string(mode) << "\n";
for (auto a : unnecessaryValues) {
llvm::errs() << *a << "\n";
}
llvm::errs() << "unnecessaryInstructions " << func.getName() << ":\n";
for (auto a : unnecessaryInstructions) {
llvm::errs() << *a << "\n";

if (EnzymePrintUnnecessary) {
llvm::errs() << "unnecessaryValues of " << func.getName()
<< ": mode=" << to_string(mode) << "\n";
for (auto a : unnecessaryValues) {
bool ivn = is_value_needed_in_reverse<ValueType::Primal>(
TR, gutils, a, mode, PrimalSeen, oldUnreachable);
bool isn = is_value_needed_in_reverse<ValueType::ShadowPtr>(
TR, gutils, a, mode, PrimalSeen, oldUnreachable);
llvm::errs() << *a << " ivn=" << (int)ivn << " isn: " << (int)isn;
auto found = gutils->knownRecomputeHeuristic.find(a);
if (found != gutils->knownRecomputeHeuristic.end()) {
llvm::errs() << " krc=" << (int)found->second;
}
llvm::errs() << "\n";
}
llvm::errs() << "unnecessaryInstructions " << func.getName() << ":\n";
for (auto a : unnecessaryInstructions) {
llvm::errs() << *a << "\n";
}
}
#endif
}

void calculateUnusedStoresInFunction(
Expand Down
39 changes: 39 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2639,8 +2639,47 @@ Constant *GradientUtils::GetOrCreateShadowFunction(EnzymeLogic &Logic,
// indirect augmented calls), topLevel MUST be true otherwise subcalls will
// not be able to lookup the augmenteddata/subdata (triggering an assertion
// failure, among much worse)
bool isRealloc = false;
if (fn->empty()) {
if (hasMetadata(fn, "enzyme_callwrapper")) {
auto md = fn->getMetadata("enzyme_callwrapper");
if (!isa<MDTuple>(md)) {
llvm::errs() << *fn << "\n";
llvm::errs() << *md << "\n";
assert(0 && "callwrapper of incorrect type");
report_fatal_error("callwrapper of incorrect type");
}
auto md2 = cast<MDTuple>(md);
assert(md2->getNumOperands() == 1);
auto gvemd = cast<ConstantAsMetadata>(md2->getOperand(0));
fn = cast<Function>(gvemd->getValue());
} else {
auto oldfn = fn;
fn = Function::Create(oldfn->getFunctionType(), Function::InternalLinkage,
"callwrap_" + oldfn->getName(), oldfn->getParent());
BasicBlock *entry = BasicBlock::Create(fn->getContext(), "entry", fn);
IRBuilder<> B(entry);
SmallVector<Value *, 4> args;
for (auto &a : fn->args())
args.push_back(&a);
auto res = B.CreateCall(oldfn, args);
if (fn->getReturnType()->isVoidTy())
B.CreateRetVoid();
else
B.CreateRet(res);
oldfn->setMetadata(
"enzyme_callwrapper",
MDTuple::get(oldfn->getContext(), {ConstantAsMetadata::get(fn)}));
if (oldfn->getName() == "realloc")
isRealloc = true;
}
}
std::map<Argument *, bool> uncacheable_args;
FnTypeInfo type_args(fn);
if (isRealloc) {
llvm::errs() << "warning: assuming realloc only creates pointers\n";
type_args.Return.insert({-1, -1}, BaseType::Pointer);
}

// conservatively assume that we can only cache existing floating types
// (i.e. that all args are uncacheable)
Expand Down
41 changes: 30 additions & 11 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ class GradientUtils : public CacheUtility {

Value *getNewFromOriginal(const Value *originst) const {
assert(originst);
if (isa<ConstantData>(originst))
return const_cast<Value *>(originst);
auto f = originalToNewFn.find(originst);
if (f == originalToNewFn.end()) {
llvm::errs() << *oldFunc << "\n";
Expand Down Expand Up @@ -691,6 +693,20 @@ class GradientUtils : public CacheUtility {
placeholder->setName("");
IRBuilder<> bb(placeholder);

Function *Fn = orig->getCalledFunction();

#if LLVM_VERSION_MAJOR >= 11
if (auto castinst = dyn_cast<ConstantExpr>(orig->getCalledOperand()))
#else
if (auto castinst = dyn_cast<ConstantExpr>(orig->getCalledValue()))
#endif
{
if (castinst->isCast())
if (auto fn = dyn_cast<Function>(castinst->getOperand(0)))
Fn = fn;
}
assert(Fn);

SmallVector<Value *, 8> args;
#if LLVM_VERSION_MAJOR >= 14
for (auto &arg : orig->args())
Expand All @@ -701,14 +717,12 @@ class GradientUtils : public CacheUtility {
args.push_back(getNewFromOriginal(arg));
}

if (shadowHandlers.find(orig->getCalledFunction()->getName().str()) !=
shadowHandlers.end()) {
if (shadowHandlers.find(Fn->getName().str()) != shadowHandlers.end()) {
bb.SetInsertPoint(placeholder);
Value *anti = placeholder;

if (mode != DerivativeMode::ReverseModeGradient) {
anti = shadowHandlers[orig->getCalledFunction()->getName().str()](
bb, orig, args);
anti = shadowHandlers[Fn->getName().str()](bb, orig, args);

invertedPointers.erase(found);
bb.SetInsertPoint(placeholder);
Expand All @@ -726,8 +740,14 @@ class GradientUtils : public CacheUtility {
return anti;
}

#if LLVM_VERSION_MAJOR >= 11
Value *anti =
bb.CreateCall(orig->getCalledFunction(), args, orig->getName() + "'mi");
bb.CreateCall(orig->getFunctionType(), orig->getCalledOperand(), args,
orig->getName() + "'mi");
#else
Value *anti =
bb.CreateCall(orig->getCalledValue(), args, orig->getName() + "'mi");
#endif
cast<CallInst>(anti)->setAttributes(orig->getAttributes());
cast<CallInst>(anti)->setCallingConv(orig->getCallingConv());
cast<CallInst>(anti)->setTailCallKind(orig->getTailCallKind());
Expand All @@ -745,8 +765,7 @@ class GradientUtils : public CacheUtility {
Attribute::NonNull);
#endif
unsigned derefBytes = 0;
if (orig->getCalledFunction()->getName() == "malloc" ||
orig->getCalledFunction()->getName() == "_Znwm") {
if (Fn->getName() == "malloc" || Fn->getName() == "_Znwm") {
if (auto ci = dyn_cast<ConstantInt>(args[0])) {
derefBytes = ci->getLimitedValue();
CallInst *cal = cast<CallInst>(getNewFromOriginal(orig));
Expand Down Expand Up @@ -789,7 +808,7 @@ class GradientUtils : public CacheUtility {
std::make_pair((const Value *)orig, InvertedPointerVH(this, anti)));

if (tape == nullptr) {
if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj") {
if (Fn->getName() == "julia.gc_alloc_obj") {
Type *tys[] = {
PointerType::get(StructType::get(orig->getContext()), 10)};
FunctionType *FT =
Expand All @@ -799,7 +818,7 @@ class GradientUtils : public CacheUtility {
anti);
}

if (orig->getCalledFunction()->getName() == "swift_allocObject") {
if (Fn->getName() == "swift_allocObject") {
EmitFailure(
"SwiftShadowAllocation", orig->getDebugLoc(), orig,
"Haven't implemented shadow allocator for `swift_allocObject`",
Expand All @@ -817,7 +836,7 @@ class GradientUtils : public CacheUtility {
auto val_arg = ConstantInt::get(Type::getInt8Ty(orig->getContext()), 0);
Value *size;
// todo check if this memset is legal and if a write barrier is needed
if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj") {
if (Fn->getName() == "julia.gc_alloc_obj") {
size = args[1];
} else {
size = args[0];
Expand Down Expand Up @@ -1667,7 +1686,7 @@ class DiffeGradientUtils : public GradientUtils {
llvm::errs() << "module: " << *oldFunc->getParent() << "\n";
llvm::errs() << "oldFunc: " << *oldFunc << "\n";
llvm::errs() << "newFunc: " << *newFunc << "\n";
llvm::errs() << "val: " << *val << " old: " << old << "\n";
llvm::errs() << "val: " << *val << " old: " << *old << "\n";
}
assert(addingType);
assert(addingType->isFPOrFPVectorTy());
Expand Down
5 changes: 3 additions & 2 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2489,8 +2489,9 @@ void TypeAnalyzer::visitMemTransferCommon(llvm::CallInst &MTI) {
size_t sz = 1;
for (auto val :
fntypeinfo.knownIntegralValues(MTI.getArgOperand(2), *DT, intseen)) {
assert(val >= 0);
sz = max(sz, (size_t)val);
if (val >= 0) {
sz = max(sz, (size_t)val);
}
}

TypeTree res = getAnalysis(MTI.getArgOperand(0)).AtMost(sz).PurgeAnything();
Expand Down
58 changes: 58 additions & 0 deletions enzyme/test/Integration/ReverseMode/metamalloc.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -fno-unroll-loops -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
// RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -

#include <stdio.h>
#include <math.h>
#include <assert.h>

#include "test_utils.h"

double __enzyme_autodiff(void*, ...);

struct {
int count;
void* (*allocfn)(long int);
} tup = {0, malloc};
__attribute__((noinline))
void* metamalloc(long int size) {
void* ret = tup.allocfn(size);
//if (ret != 0)
// tup.count++;
return ret;
}
__attribute__((noinline))
void square(double* x) {
*x *= *x;
}
double alldiv(double x) {
double* mem = (double*)metamalloc(8);
*mem = x;
square(mem);
return mem[0];
}


static void* (*sallocfn)(int) = malloc;
__attribute__((noinline))
void* smetamalloc(int size) {
return sallocfn(size);
}
double salldiv(double x) {
double* mem = (double*)metamalloc(8);
*mem = x * x;
return mem[0];
}

int main(int argc, char** argv) {
double res = __enzyme_autodiff((void*)alldiv, 3.14);
APPROX_EQ(res, 6.28, 1e-6);
double sres = __enzyme_autodiff((void*)salldiv, 3.14);
APPROX_EQ(sres, 6.28, 1e-6);
return 0;
}

0 comments on commit cb44221

Please sign in to comment.