Skip to content

Commit

Permalink
Fix wrong DIFFE_TYPE for forward mode in GetOrCreateShadowFunction (r…
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich authored Mar 7, 2022
1 parent f08a1bf commit aee3008
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3542,7 +3542,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction(
std::pair<Argument *, std::set<int64_t>>(&a, {}));
DIFFE_TYPE typ;
if (a.getType()->isFPOrFPVectorTy()) {
typ = DIFFE_TYPE::OUT_DIFF;
typ = mode == DerivativeMode::ForwardMode ? DIFFE_TYPE::DUP_ARG
: DIFFE_TYPE::OUT_DIFF;
} else if (a.getType()->isIntegerTy() &&
cast<IntegerType>(a.getType())->getBitWidth() < 16) {
typ = DIFFE_TYPE::CONSTANT;
Expand All @@ -3554,7 +3555,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction(
types.push_back(typ);
}

DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy()
DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy() &&
mode != DerivativeMode::ForwardMode
? DIFFE_TYPE::OUT_DIFF
: DIFFE_TYPE::DUP_ARG;
if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() ||
Expand Down

0 comments on commit aee3008

Please sign in to comment.