[Relay][Training][Pass] Factor out first-order AD to a module pass (a…
altanh authored and Trevor Morris committed May 6, 2021
1 parent 14e8449 commit 2fd5c04
Showing 4 changed files with 391 additions and 272 deletions.
26 changes: 25 additions & 1 deletion python/tvm/relay/transform/transform.py
@@ -800,12 +800,36 @@ def gradient(expr, mod=None, mode="higher_order"):
         The transformed expression.
     """
     if mode == "first_order":
-        return _ffi_api.first_order_gradient(expr, mod)
+        warnings.warn(
+            "using transform.gradient for first-order AD is deprecated, please use the "
+            "FirstOrderGradient module pass",
+            DeprecationWarning,
+        )
+        if mod is not None:
+            raise RuntimeError(
+                "to run first-order AD on a module, please use the FirstOrderGradient module pass."
+            )
+        return FirstOrderGradient()(tvm.IRModule.from_expr(expr))["main"]
     if mode == "higher_order":
         return _ffi_api.gradient(expr, mod)
     raise Exception("unknown mode")
+
+
+def FirstOrderGradient():
+    """
+    Transforms all global functions in the module to return the original result, paired with the
+    gradients of the inputs. This pass transforms each global function independently and does not
+    support interprocedural AD. Additionally, this pass does not support any control-flow or
+    references, and should only be used on pure data-flow graphs.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered FirstOrderGradient pass.
+    """
+    return _ffi_api.FirstOrderGradient()
 
 
 def Defunctionalization(func, mod):
     """
     Performs defunctionalization on func,
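For context, a minimal sketch of invoking the new FirstOrderGradient pass as a standard module pass (the toy function, shapes, and dtype below are illustrative, not part of this diff):

import tvm
from tvm import relay

x = relay.var("x", shape=(3, 4), dtype="float32")
y = relay.var("y", shape=(3, 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x, y], x * y))
mod = relay.transform.InferType()(mod)  # the pass reads checked_type, so infer types first

grad_mod = relay.transform.FirstOrderGradient()(mod)
# "main" now returns (original_result, (grad_x, grad_y))
print(grad_mod["main"])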
309 changes: 309 additions & 0 deletions src/relay/transforms/first_order_gradient.cc
@@ -0,0 +1,309 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file first_order_gradient.cc
* \brief First-order Automatic Differentiation in Relay for pure dataflow graphs.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>

#include "gradient.h"
#include "let_list.h"
#include "pass_utils.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

template <typename F>
Expr MultiFactory(const Type& t, F factory, DiagnosticContext diag_ctx) {
if (auto* tt = t.as<TensorTypeNode>()) {
return factory(tt->shape, tt->dtype);
} else if (auto* tt = t.as<TupleTypeNode>()) {
std::vector<Expr> res;
for (size_t i = 0; i < tt->fields.size(); i++) {
res.push_back(MultiFactory(tt->fields[i], factory, diag_ctx));
}
return Tuple(res);
} else {
diag_ctx.EmitFatal(Diagnostic::Error(t->span)
<< "could not build tensors using factory for type " << PrettyPrint(t));
throw;
}
}

template <typename F, typename F2>
Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like,
DiagnosticContext diag_ctx) {
if (t.as<TensorTypeNode>()) {
return factory_like(e);
  } else if (t.as<TupleTypeNode>()) {
return MultiFactory(t, factory, diag_ctx);
} else {
diag_ctx.EmitFatal(Diagnostic::Error(t->span)
<< "could not build tensors using factory for type " << PrettyPrint(t));
throw;
}
}
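For intuition, a rough Python analogue of MultiFactory specialized to the Zeros factory (multi_zeros is a hypothetical helper, not TVM API): recurse over a possibly nested tuple type, emitting a constant tensor at each tensor leaf and rebuilding the tuple structure around it.

from tvm import relay

def multi_zeros(ty):
    if isinstance(ty, relay.TensorType):
        # tensor leaf: build a zeros tensor of matching shape and dtype
        return relay.zeros([int(d) for d in ty.shape], ty.dtype)
    if isinstance(ty, relay.TupleType):
        # tuple: rebuild the same structure field by field
        return relay.Tuple([multi_zeros(f) for f in ty.fields])
    raise ValueError("cannot build zeros for type " + str(ty))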

/*! \brief A fragment of the program being built by the automatic differentiation
* pass.
*/
struct ADValueNode {
virtual ~ADValueNode() {}
template <typename T>
T& get() {
auto ret = dynamic_cast<T*>(this);
ICHECK(ret) << "cannot downcast";
return *ret;
}
};

using ADValue = std::shared_ptr<ADValueNode>;

/*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode {
Expr forward;
mutable Expr reverse; // must be a variable to avoid duplication
ADTensor(LetList* ll, const Expr& forward, DiagnosticContext diag_ctx)
: forward(ll->Push(forward)),
reverse(ll->Push(
MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike, diag_ctx))) {
this->forward->checked_type_ = forward->checked_type();
}
};

/*! \brief A staged representation of the program, we reflect
* Relay functions into a function over fragments of AD. We
* can compute away this function to obtain a reverse mode program.
*/
struct ADFunction : ADValueNode {
// (ad_args, orig) -> ad_ret
using ADFunctionType = ADValue(const std::vector<ADValue>&, const Call&);
std::function<ADFunctionType> func;
explicit ADFunction(const std::function<ADFunctionType>& func) : func(func) {}
};

struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
std::vector<std::function<void(LetList* ll)>> backprop_actions;
// we assume no closure so no need for lexical scoping
std::unordered_map<Expr, ADValue, ObjectPtrHash, ObjectPtrEqual> env;
LetList* ll;
DiagnosticContext diag_ctx;

FirstOrderReverseAD(LetList* ll, DiagnosticContext diag_ctx) : ll(ll), diag_ctx(diag_ctx) {}

ADValue VisitExpr(const Expr& n) final {
if (env.count(n)) {
return env.at(n);
}
auto ret = ExprFunctor::VisitExpr(n);
env[n] = ret;
return ret;
}

static Expr LiftedAdd(const Type& t, const Expr& x, const Expr& y, LetList* ll) {
if (t.as<TensorTypeNode>()) {
return ll->Push(Add(x, y));
} else if (auto* tt = t.as<TupleTypeNode>()) {
Array<Expr> fields;
for (size_t i = 0; i < tt->fields.size(); ++i) {
fields.push_back(
LiftedAdd(tt->fields[i], ll->Push(GetField(x, i)), ll->Push(GetField(y, i)), ll));
}
return ll->Push(Tuple(fields));
} else {
LOG(FATAL) << "cannot lift addition for type " << PrettyPrint(t);
throw;
}
}

ADValue VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
if (!rev_map.count(op_ref)) {
diag_ctx.EmitFatal(Diagnostic::Error(op->span)
<< "the operator " << op->name << " does not have a registered gradient.");
}
return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& ad_args,
const Call& orig) {
std::vector<Expr> orig_args;
for (const ADValue& adval : ad_args) {
orig_args.push_back(adval->get<ADTensor>().forward);
}
auto orig_new = Call(op_ref, orig_args, orig->attrs, orig->type_args);
orig_new->checked_type_ = orig->checked_type();
auto ret = std::make_shared<ADTensor>(ll, orig_new, diag_ctx);
backprop_actions.push_back([this, ad_args, orig_new, ret, op_ref](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig_new, ret->reverse);
if (ad_args.size() != rev.size()) {
diag_ctx.EmitFatal(Diagnostic::Error(op_ref->span)
<< "arity mismatch for operator " << op_ref->name
<< " and its registered gradient: expected " << ad_args.size()
<< " but got " << rev.size() << " gradients.");
}
for (size_t i = 0; i < ad_args.size(); ++i) {
auto& ad_arg = ad_args[i]->get<ADTensor>();
ad_arg.reverse = LiftedAdd(ad_arg.forward->checked_type(), ad_arg.reverse, rev[i], ll);
}
});
return ret;
});
}
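This visitor is the whole algorithm in miniature: the forward Call is emitted immediately, and a closure is queued on backprop_actions that later scatters the output adjoint into the arguments; the pass driver below seeds the output adjoint with ones and runs the actions in reverse. A toy scalar analogue (illustrative only, not TVM API):

class Value:
    def __init__(self, forward):
        self.forward = forward
        self.reverse = 0.0  # adjoint starts at zero, like the ZerosLike init above

backprop_actions = []

def mul(a, b):
    out = Value(a.forward * b.forward)
    def backprop():
        # mirror of LiftedAdd: accumulate into the existing adjoints
        a.reverse += out.reverse * b.forward
        b.reverse += out.reverse * a.forward
    backprop_actions.append(backprop)
    return out

x, y = Value(3.0), Value(4.0)
z = mul(x, y)
z.reverse = 1.0  # seeded with ones by the pass driver (OnesLike)
for action in reversed(backprop_actions):
    action()
print(x.reverse, y.reverse)  # 4.0 3.0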

ADValue VisitExpr_(const TupleGetItemNode* op) final {
Expr e = GetRef<Expr>(op);
ADValue tup = VisitExpr(op->tuple);
auto tt = op->tuple->checked_type().as<TupleTypeNode>();
size_t idx = op->index;
auto ret = std::make_shared<ADTensor>(ll, e, diag_ctx);
backprop_actions.push_back([tup, tt, idx, ret](LetList* ll) {
auto& ad_tup = tup->get<ADTensor>();
std::vector<Expr> updated_grads;
for (size_t i = 0; i < tt->fields.size(); ++i) {
Expr grad_pre = GetField(ad_tup.reverse, i);
updated_grads.push_back(i != idx ? grad_pre
: LiftedAdd(tt->fields[i], grad_pre, ret->reverse, ll));
}
ad_tup.reverse = ll->Push(Tuple(updated_grads));
});
return ret;
}
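In toy list form (illustrative only), the backprop action above adds the output adjoint into slot idx of the tuple's adjoint and leaves every other slot untouched:

def tuple_get_item_backprop(tuple_adjoint, out_adjoint, idx):
    return [g + out_adjoint if i == idx else g for i, g in enumerate(tuple_adjoint)]

print(tuple_get_item_backprop([0.0, 0.0, 0.0], 1.0, 1))  # [0.0, 1.0, 0.0]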

ADValue VisitExpr_(const TupleNode* op) final {
Expr e = GetRef<Expr>(op);
std::vector<ADValue> fields;
for (const auto& f : op->fields) {
fields.push_back(VisitExpr(f));
}
auto tt = op->checked_type().as<TupleTypeNode>();
auto ret = std::make_shared<ADTensor>(ll, e, diag_ctx);
backprop_actions.push_back([fields, tt, ret](LetList* ll) {
for (size_t i = 0; i < fields.size(); ++i) {
auto& ad_field = fields[i]->get<ADTensor>();
ad_field.reverse =
LiftedAdd(tt->fields[i], ad_field.reverse, GetField(ret->reverse, i), ll);
}
});
return ret;
}

ADValue VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return std::make_shared<ADTensor>(ll, e, diag_ctx);
}

ADValue VisitExpr_(const CallNode* op) final {
ADValue f = VisitExpr(op->op);
std::vector<ADValue> args;
for (const auto& arg : op->args) {
args.push_back(VisitExpr(arg));
}
return f->get<ADFunction>().func(args, GetRef<Call>(op));
}

ADValue VisitExpr_(const FunctionNode* op) final {
Function f = GetRef<Function>(op);
// todo: assert no closure
return std::make_shared<ADFunction>(
[this, f](const std::vector<ADValue>& ad_args, const Call& orig) {
ICHECK_EQ(f->params.size(), ad_args.size());
for (size_t i = 0; i < f->params.size(); ++i) {
env[f->params[i]] = ad_args[i];
}
return VisitExpr(f->body);
});
}

// Var will always be in env, handled in VisitExpr (without _), so we don't need
// to implement its VisitExpr_.
};

namespace transform {

Pass FirstOrderGradient() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> f = [](IRModule mod, PassContext ctx) {
CheckFeature(
mod, FeatureSet({fVar, fConstant, fTuple, fTupleGetItem, fFunction, fOp, fCall, fGraph}));
IRModule ad_mod = GetRef<IRModule>(mod.CopyOnWrite());
DiagnosticContext diag_ctx = DiagnosticContext::Default(ad_mod);

if (mod->functions.size() > 1) {
      LOG(WARNING) << "IRModule contains multiple global functions: first-order AD will transform "
                      "them independently!";
}

for (const auto& pr : mod->functions) {
const FunctionNode* func = pr.second.as<FunctionNode>();
      if (!func) {
        diag_ctx.Emit(Diagnostic::Warning(pr.second->span)
                      << "AD can only be performed on Relay functions, skipping "
                      << PrettyPrint(pr.first));
        // skip non-Function definitions rather than dereferencing a null pointer below
        continue;
      }
if (func->type_params.size() > 0) {
diag_ctx.EmitFatal(Diagnostic::Error(pr.second->span)
<< "first-order AD does not support polymorphism yet.");
}
Expr body = LetList::With([&](LetList* ll) {
FirstOrderReverseAD reverse_ad(ll, diag_ctx);
ADValue rev = reverse_ad(pr.second);
std::vector<ADValue> args;
for (const auto& p : func->params) {
args.push_back(std::make_shared<ADTensor>(ll, p, diag_ctx));
}
Call placeholder = Call(GetRef<Function>(func), {});
placeholder->checked_type_ = func->checked_type().as<FuncTypeNode>()->ret_type;
auto grad_call = rev->get<ADFunction>().func(args, placeholder);
auto& res = grad_call->get<ADTensor>();
Expr grad_tuple = LetList::With([&](LetList* ll) {
res.reverse =
MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike, diag_ctx);
for (auto it = reverse_ad.backprop_actions.rbegin();
it != reverse_ad.backprop_actions.rend(); ++it) {
(*it)(ll);
}
std::vector<Expr> grads;
for (const auto& a : args) {
grads.push_back(a->get<ADTensor>().reverse);
}
return Tuple(grads);
});
return Pair(res.forward, grad_tuple);
});
ad_mod->Update(pr.first,
Function(func->params, body, GradRetType(GetRef<Function>(func)), {}));
}

return ad_mod;
};
return CreateModulePass(f, 0, "FirstOrderGradient", {});
}

TVM_REGISTER_GLOBAL("relay._transform.FirstOrderGradient").set_body_typed(FirstOrderGradient);

} // namespace transform

} // namespace relay
} // namespace tvm