Skip to content

Commit

Permalink
eigen test (rust-lang#363)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich authored Oct 29, 2021
1 parent df488bd commit 51284d6
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 30 deletions.
52 changes: 33 additions & 19 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -7903,8 +7903,6 @@ class AdjointGenerator
IRBuilder<> Builder2(&call);
getForwardBuilder(Builder2);

bool retUsed = subretused;

SmallVector<Value *, 8> args;
std::vector<DIFFE_TYPE> argsInverted;
std::map<int, Type *> gradByVal;
Expand Down Expand Up @@ -7970,7 +7968,7 @@ class AdjointGenerator

auto newcalled = gutils->Logic.CreateForwardDiff(
cast<Function>(called), subretType, argsInverted, gutils->TLI,
TR.analyzer.interprocedural, /*returnValue*/ retUsed,
TR.analyzer.interprocedural, /*returnValue*/ subretused,
/*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr,
nextTypeInfo, {});

Expand All @@ -7989,30 +7987,46 @@ class AdjointGenerator
}
#endif

if (!newcalled->getReturnType()->isVoidTy()) {
bool structret = retUsed && subretType != DIFFE_TYPE::CONSTANT;
auto newcall = gutils->getNewFromOriginal(orig);
Value *diffe;
if (structret) {
diffe = Builder2.CreateExtractValue(diffes, 1);
} else {
diffe = diffes;
}
auto newcall = gutils->getNewFromOriginal(orig);
auto ifound = gutils->invertedPointers.find(orig);
Value *primal = nullptr;
Value *diffe = nullptr;

auto ifound = gutils->invertedPointers.find(orig);
if (ifound != gutils->invertedPointers.end()) {
auto placeholder = cast<PHINode>(&*ifound->second);
if (subretused && subretType != DIFFE_TYPE::CONSTANT) {
primal = Builder2.CreateExtractValue(diffes, 0);
diffe = Builder2.CreateExtractValue(diffes, 1);
} else if (!newcalled->getReturnType()->isVoidTy()) {
diffe = diffes;
}

if (ifound != gutils->invertedPointers.end()) {
auto placeholder = cast<PHINode>(&*ifound->second);
if (primal) {
gutils->replaceAWithB(newcall, primal);
gutils->erase(newcall);
}
if (diffe) {
gutils->replaceAWithB(placeholder, diffe);
gutils->erase(placeholder);
} else {
gutils->replaceAWithB(newcall, diffe);
gutils->invertedPointers.erase(ifound);
}
gutils->erase(placeholder);
} else {
if (primal && diffe) {
gutils->replaceAWithB(newcall, primal);
if (!gutils->isConstantValue(&call)) {
setDiffe(&call, diffe, Builder2);
}
gutils->erase(newcall);
} else if (diffe) {
gutils->replaceAWithB(newcall, diffe);
if (!gutils->isConstantValue(&call)) {
setDiffe(&call, diffe, Builder2);
}
gutils->erase(newcall);
} else {
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
}
} else {
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
}

return;
Expand Down
16 changes: 12 additions & 4 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2380,8 +2380,16 @@ void createTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
switch (retVal) {
case ReturnType::Return: {
auto ret = inst->getOperand(0);
toret = retType == DIFFE_TYPE::CONSTANT ? gutils->getNewFromOriginal(ret)
: gutils->diffe(ret, nBuilder);

if (retType == DIFFE_TYPE::CONSTANT) {
toret = gutils->getNewFromOriginal(ret);
} else if (!ret->getType()->isFPOrFPVectorTy() &&
TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
toret = gutils->invertPointerM(ret, nBuilder);
} else {
toret = gutils->diffe(ret, nBuilder);
}

break;
}
case ReturnType::TwoReturns: {
Expand All @@ -2392,7 +2400,8 @@ void createTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
toret =
nBuilder.CreateInsertValue(toret, gutils->getNewFromOriginal(ret), 0);

if (TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
if (!ret->getType()->isFPOrFPVectorTy() &&
TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
toret = nBuilder.CreateInsertValue(
toret, gutils->invertPointerM(ret, nBuilder), 1);
} else {
Expand Down Expand Up @@ -3717,7 +3726,6 @@ Function *EnzymeLogic::CreateForwardDiff(
return foundcalled;
}

auto TRo = TA.analyzeFunction(oldTypeInfo);
bool retActive = retType != DIFFE_TYPE::CONSTANT;

ReturnType retVal =
Expand Down
14 changes: 7 additions & 7 deletions enzyme/test/Enzyme/ForwardMode/ptr-ret.ll
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ declare dso_local double @_Z16__enzyme_fwddiffz(...)

; CHECK: define internal double @fwddiffe_Z6squared(double %x, double %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %call = call double* @_Z6toHeapd(double %x)
; CHECK-NEXT: %0 = call { double*, double* } @fwddiffe_Z6toHeapd(double %x, double %"x'")
; CHECK-NEXT: %1 = extractvalue { double*, double* } %0, 1
; CHECK-NEXT: %2 = load double, double* %call, align 8
; CHECK-NEXT: %1 = extractvalue { double*, double* } %0, 0
; CHECK-NEXT: %2 = extractvalue { double*, double* } %0, 1
; CHECK-NEXT: %3 = load double, double* %1, align 8
; CHECK-NEXT: %4 = fmul fast double %3, %x
; CHECK-NEXT: %5 = fmul fast double %"x'", %2
; CHECK-NEXT: %6 = fadd fast double %4, %5
; CHECK-NEXT: ret double %6
; CHECK-NEXT: %4 = load double, double* %2, align 8
; CHECK-NEXT: %5 = fmul fast double %4, %x
; CHECK-NEXT: %6 = fmul fast double %"x'", %3
; CHECK-NEXT: %7 = fadd fast double %5, %6
; CHECK-NEXT: ret double %7
; CHECK-NEXT: }

; CHECK: define internal { double*, double* } @fwddiffe_Z6toHeapd(double %x, double %"x'")
Expand Down
28 changes: 28 additions & 0 deletions enzyme/test/Integration/ForwardMode/eigen.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: %clang++ -mllvm -force-vector-width=1 -ffast-math -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
// RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
// RUN: %clang++ -fno-unroll-loops -fno-vectorize -fno-slp-vectorize -fno-exceptions -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -

#include "test_utils.h"
#include <eigen3/Eigen/Core>
#include <eigen3/Eigen/Dense>

double __enzyme_fwddiff(double(double), double, double);

double square(double x) {
Eigen::Vector3d v(x, x * x, x * x * x);
v *= 2;
return v[1];
}

double dsquare(double x) { return __enzyme_fwddiff(square, x, 1.0); }

int main() {
double x = 4;
double res = dsquare(x);
APPROX_EQ(res, 16.0, 1e-10);
printf("dsquare(%f)=%f\n", x, res);
return 0;
}

0 comments on commit 51284d6

Please sign in to comment.