From 957aefbb174749602fcb138f88bd877f6ff9cd3c Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Tue, 30 Jun 2020 05:34:20 +0200
Subject: [PATCH] [RELAY][GRAD] handle Tuple/TupleGetItem in first order
 gradient (#5946)

* handle Tuple/TupleGetItem in first order gradient

* Unify MultiOnes/MultiZeros.
---
 src/relay/transforms/gradient.cc         | 79 +++++++++++++++++++++++-
 src/relay/transforms/pattern_util.h      |  6 ++
 tests/python/relay/test_pass_gradient.py | 28 ++++++---
 3 files changed, 104 insertions(+), 9 deletions(-)

diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc
index 4bc643935dc9..6fee40c51337 100644
--- a/src/relay/transforms/gradient.cc
+++ b/src/relay/transforms/gradient.cc
@@ -106,6 +106,34 @@ struct ADValueNode {
   }
 };
 
+template <typename F>
+Expr MultiFactory(const Type& t, F factory) {
+  if (auto* tt = t.as<TensorTypeNode>()) {
+    return factory(tt->shape, tt->dtype);
+  } else if (auto* tt = t.as<TupleTypeNode>()) {
+    std::vector<Expr> res;
+    for (size_t i = 0; i < tt->fields.size(); i++) {
+      res.push_back(MultiFactory(tt->fields[i], factory));
+    }
+    return Tuple(res);
+  } else {
+    LOG(FATAL) << "unsupported type to create tensors of: " << tt;
+    throw;
+  }
+}
+
+template <typename F, typename F2>
+Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like) {
+  if (t.as<TensorTypeNode>()) {
+    return factory_like(e);
+  } else if (auto* tt = t.as<TupleTypeNode>()) {
+    return MultiFactory(t, factory);
+  } else {
+    LOG(FATAL) << "unsupported type to tensors of: " << tt;
+    throw;
+  }
+}
+
 using ADValue = std::shared_ptr<ADValueNode>;
 
 /*! \brief AD over a program which generates a tensor output. */
@@ -113,7 +141,9 @@ struct ADTensor : ADValueNode {
   Expr forward;
   mutable Expr reverse;  // must be a variable to avoid duplication
   ADTensor(LetList* ll, const Expr& forward)
-      : forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) {
+      : forward(ll->Push(forward)),
+        reverse(
+            ll->Push(MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike))) {
     this->forward->checked_type_ = forward->checked_type();
   }
 };
@@ -165,6 +195,51 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
     });
   }
 
+  ADValue VisitExpr_(const TupleGetItemNode* op) final {
+    Expr e = GetRef<Expr>(op);
+    ADValue tup = VisitExpr(op->tuple);
+    auto tt = op->tuple->checked_type().as<TupleTypeNode>();
+    size_t size = tt->fields.size();
+    size_t idx = op->index;
+    auto ret = std::make_shared<ADTensor>(ll, e);
+    backprop_actions.push_back([tup, idx, size, ret](LetList* ll) {
+      auto rev = tup->get<ADTensor>().reverse;
+      // special-case Tuple, to avoid long chains of GetItem/Tuple,
+      // but we might have functions using tuples, so we don't know
+      // that the reverse node is always a tuple
+      std::vector<Expr> grfields;
+      if (auto tup_node = rev.as<TupleNode>()) {
+        for (size_t i = 0; i < size; ++i) {
+          grfields.push_back(i != idx ? tup_node->fields[i]
+                                      : Add(tup_node->fields[i], ret->reverse));
+        }
+      } else {
+        for (size_t i = 0; i < size; ++i) {
+          grfields.push_back(i != idx ? TupleGetItem(rev, i)
+                                      : Add(TupleGetItem(rev, i), ret->reverse));
+        }
+      }
+      tup->get<ADTensor>().reverse = ll->Push(Tuple(grfields));
+    });
+    return ret;
+  }
+
+  ADValue VisitExpr_(const TupleNode* op) final {
+    Expr e = GetRef<Expr>(op);
+    std::vector<ADValue> fields;
+    for (const auto& f : op->fields) {
+      fields.push_back(VisitExpr(f));
+    }
+    auto ret = std::make_shared<ADTensor>(ll, e);
+    backprop_actions.push_back([fields, ret](LetList* ll) {
+      for (size_t i = 0; i < fields.size(); ++i) {
+        fields[i]->get<ADTensor>().reverse =
+            ll->Push(Add(fields[i]->get<ADTensor>().reverse, TupleGetItem(ret->reverse, i)));
+      }
+    });
+    return ret;
+  }
+
   ADValue VisitExpr_(const ConstantNode* op) final {
     Expr e = GetRef<Expr>(op);
     return std::make_shared<ADTensor>(ll, e);
@@ -235,7 +310,7 @@ Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {
     auto c = rev->get<ADFunction>().func(f->checked_type(), args, Attrs(), {});
     const auto& res = c->get<ADTensor>();
     Expr grad = LetList::With([&](LetList* ll) {
-      res.reverse = OnesLike(res.forward);
+      res.reverse = MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike);
       for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend();
            ++it) {
         (*it)(ll);
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index 7518eb9ac81a..d55041163054 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -524,6 +524,12 @@ inline Expr OnesLike(Expr e) {
   return Call(op, {e});
 }
 
+Expr MakeOnes(Expr shape, DataType dtype);
+
+inline Expr Ones(Array<Integer> shape, DataType dtype) {
+  return MakeOnes(CheckConstantShape(shape), dtype);
+}
+
 inline Expr CollapseSumLike(Expr e) {
   static const Op& op = Op::Get("collapse_sum_like");
   return Call(op, {e});
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index efd01cbe1a6b..e28eb4a6b249 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -158,20 +158,27 @@ def test_broadcast_subtract():
         -np.ones_like(expected_forward).sum(axis=(0, 1), keepdims=True).squeeze(axis=0))
 
 
-def test_tuple():
+def _test_tuple(mode):
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     y = relay.var("y", t)
     z = relay.var("z", t)
-    tup = relay.Var("tup")
-    func = relay.Function([x, y, z], relay.Let(tup, relay.Tuple([x, y, z]),
-                                               relay.TupleGetItem(tup, 0) +
-                                               relay.TupleGetItem(tup, 1) -
-                                               relay.TupleGetItem(tup, 2)))
+    if mode == "higher_order":
+        tup = relay.Var("tup")
+        func = relay.Function([x, y, z], relay.Let(tup, relay.Tuple([x, y, z]),
+                                                   relay.TupleGetItem(tup, 0) +
+                                                   relay.TupleGetItem(tup, 1) -
+                                                   relay.TupleGetItem(tup, 2)))
+    else:
+        # first order does not do let.
+        tup = relay.Tuple([x, y, z])
+        func = relay.Function([x, y, z], relay.TupleGetItem(tup, 0) +
+                                         relay.TupleGetItem(tup, 1) -
+                                         relay.TupleGetItem(tup, 2))
     func = run_infer_type(func)
-    back_func = run_infer_type(gradient(func))
+    back_func = run_infer_type(gradient(func, mode=mode))
     assert back_func.checked_type == relay.FuncType([t, t, t], relay.TupleType([t, relay.TupleType([t, t, t])]))
     x_nd = rand(dtype, *shape)
     y_nd = rand(dtype, *shape)
@@ -188,6 +195,12 @@ def test_tuple():
     tvm.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))
 
 
+def test_tuple():
+    _test_tuple("higher_order")
+
+def test_tuple_first_order():
+    _test_tuple("first_order")
+
 def test_pow():
     mod = tvm.IRModule()
     p = Prelude(mod)
@@ -304,6 +317,7 @@ def test_concat():
     test_broadcast_add()
     test_broadcast_subtract()
     test_tuple()
+    test_tuple_first_order()
    test_pow()
     test_ref()
     test_square_second_order()
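
---

Usage sketch: with this patch, `gradient(..., mode="first_order")` accepts
functions whose bodies build and index tuples directly; the reverse values of
tuple-typed nodes are seeded with `Zeros`/`Ones` via `MultiFactoryLike` instead
of failing on `ZerosLike`/`OnesLike`. A minimal example mirroring the
first-order branch of `_test_tuple` above, assuming a TVM build that includes
this change (all imports follow the test file):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay import create_executor
    from tvm.relay.transform import gradient
    from tvm.relay.testing import run_infer_type, rand

    # f(x, y) = tup[0] + tup[1], with the tuple built inline (no let-binding),
    # which previously only the higher-order AD could differentiate.
    t = relay.TensorType((10, 10), "float32")
    x = relay.var("x", t)
    y = relay.var("y", t)
    tup = relay.Tuple([x, y])
    func = relay.Function([x, y], relay.TupleGetItem(tup, 0) +
                                  relay.TupleGetItem(tup, 1))
    func = run_infer_type(func)
    back_func = run_infer_type(gradient(func, mode="first_order"))

    # The transformed function returns the forward value plus a tuple of
    # gradients, one per parameter; for x + y both gradients are all ones.
    x_nd = rand("float32", 10, 10)
    y_nd = rand("float32", 10, 10)
    forward, (grad_x, grad_y) = create_executor().evaluate(back_func)(x_nd, y_nd)
    tvm.testing.assert_allclose(grad_x.asnumpy(), np.ones((10, 10), "float32"))
    tvm.testing.assert_allclose(grad_y.asnumpy(), np.ones((10, 10), "float32"))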