
[Relay][Training][Pass] Factor out first-order AD to a module pass #7677

Merged · 6 commits merged on Mar 18, 2021

Conversation

@altanh (Contributor) commented Mar 17, 2021

Since there is negligible code sharing between the first-order and higher-order AD in Relay, I've factored out the first-order AD pass to a separate file. Additionally, I've made it a proper IRModule pass, and fixed a few spots where adding tuples wasn't being lifted correctly. I tried my best to add some diagnostics.

To complement AD, I've also added a pass called ConcretizeLike that transforms the *_like operators generated during AD into their explicit-shape equivalents (e.g. zeros_like(x) -> zeros(x.shape)), which removes some forward dependencies and could enable more opportunities for operator fusion. Note that this is the same thing as LikeZapp from @t-vi's blog post, with perhaps a bit more error checking.

cc @t-vi @MarisaKirisame @tqchen @jroesch

Comment on lines 50 to 53
auto attrs = make_object<InitOpAttrs>();
auto cshape =
    MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
attrs->shape = shape;

Contributor:

Suggested change (assign attrs->shape before constructing the constant shape tensor):

auto attrs = make_object<InitOpAttrs>();
attrs->shape = shape;
auto cshape =
    MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);

Comment on lines 108 to 112
if (const auto* imm = dim.as<IntImmNode>()) {
  cshape.push_back(Integer(GetRef<IntImm>(imm)));
  continue;
}
return post;

Contributor:

Suggested change (use an else branch instead of continue):

if (const auto* imm = dim.as<IntImmNode>()) {
  cshape.push_back(Integer(GetRef<IntImm>(imm)));
} else {
  return post;
}

if (call_node->args.size() == 2) {
  return concrete_map_.at(op_ref)(node_map[data_pat_][0], cshape, like_ty->dtype);
}
return concrete_map_.at(op_ref)(Expr(), cshape, like_ty->dtype);

Contributor:

why empty Expr()?

Contributor Author:

Yeah, maybe this is too much of a hack. I'm just using it as a placeholder since the unary matches won't have a corresponding data Expr node. I'll rework this tomorrow.

Contributor:

It might be better to refer to the SimplifyExpr pass to separate unary and binary ops. Then maybe we could have a base struct to hold the shareable logic.
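A minimal sketch of that base-struct idea (hypothetical names and simplified signatures; not actual SimplifyExpr or PR code):

#include <tvm/relay/dataflow_pattern.h>
#include <tvm/relay/expr.h>

namespace tvm {
namespace relay {

// Hypothetical sketch: the base owns the shared "like" pattern and the
// concretization hook; the unary and binary variants only differ in which
// matched arguments they forward to the explicit-shape op.
class ConcretizeLikeBaseRewrite {
 public:
  virtual ~ConcretizeLikeBaseRewrite() = default;

  // Build the explicit-shape call from the matched sub-expressions.
  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> cshape,
                          DataType dtype) const = 0;

 protected:
  DFPattern like_pat_ = IsWildcard();
};

// zeros_like / ones_like: no data argument to forward.
class ConcretizeUnaryLikeRewrite : public ConcretizeLikeBaseRewrite { /* ... */ };

// reshape_like / collapse_sum_like / broadcast_to_like: forward the data argument too.
class ConcretizeBinaryLikeRewrite : public ConcretizeLikeBaseRewrite { /* ... */ };

}  // namespace relay
}  // namespace tvm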

@tqchen (Member) commented Mar 17, 2021

cc @comaniac

@tqchen self-assigned this on Mar 17, 2021
@comaniac (Contributor) left a comment

Also cc @yzhliu

Comment on lines 61 to 67
for (const auto& pr : concrete_map_) {
  if (!op_pat_.defined()) {
    op_pat_ = IsExpr(pr.first);
  } else {
    op_pat_ = op_pat_ || IsExpr(pr.first);
  }
}
Contributor:

  1. Could you elaborate on what this loop does? It seems to me that it unions all the patterns that match ops in concrete_map_, but I didn't find op_pat_ being used elsewhere.
  2. The construction of concrete_map_ could be static, so we should be able to move it out of the constructor.

Contributor Author:

  1. Thanks for the catch, this was left over from before.
  2. True, I'm not sure what the code style is for defining static variables like this, but I'll try something.

Contributor Author:

Just to follow up on the idea of making these static: I'm mainly concerned about how the static lifetime interacts with the operator registry. Do you know how that works?

Contributor:

I don't think it would be an issue. If you search for static const Op& op in the code base, you'll find lots of use cases.
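For illustration, the idiom looks roughly like this (a generic example, not code from this PR); the op registry is a process-wide singleton, so a function-local static reference resolved on first use is safe:

#include <tvm/ir/op.h>

// Generic illustration of the static const Op& idiom: the reference is
// resolved from the global op registry once, on first call.
const tvm::Op& GetZerosOp() {
  static const tvm::Op& op = tvm::Op::Get("zeros");
  return op;
}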

Comment on lines 69 to 73
data_pat_ = IsWildcard();
like_pat_ = IsWildcard();
unary_like_pat_ = (IsOp("zeros_like") || IsOp("ones_like"))({like_pat_});
binary_like_pat_ = (IsOp("reshape_like") || IsOp("collapse_sum_like") ||
                    IsOp("broadcast_to_like"))({data_pat_, like_pat_});
Contributor:

  1. These patterns can also be defined statically.
  2. Now we have two places that specify the supported *_like ops. I feel we should define the lists of "unary like" and "binary like" ops once and use them both in concrete_map_ and here. As a result, we could have logic similar to L61-67 to construct the pattern (see the sketch below).
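A rough sketch of that suggestion (hypothetical names; assumes the same op lists drive both the pattern construction and the concrete_map_ registration):

#include <tvm/ir/op.h>
#include <tvm/relay/dataflow_pattern.h>

#include <string>
#include <vector>

namespace tvm {
namespace relay {

// Single source of truth for the supported *_like ops (hypothetical constants).
static const std::vector<std::string> kUnaryLikeOps = {"zeros_like", "ones_like"};
static const std::vector<std::string> kBinaryLikeOps = {"reshape_like", "collapse_sum_like",
                                                        "broadcast_to_like"};

// Union the per-op patterns, mirroring the loop on L61-67 of the diff.
DFPattern MakeAltOpPattern(const std::vector<std::string>& names) {
  DFPattern pat;
  for (const auto& name : names) {
    DFPattern op_pat = IsExpr(Op::Get(name));
    pat = pat.defined() ? (pat || op_pat) : op_pat;
  }
  return pat;
}

// The constructor patterns would then become, e.g.:
//   unary_like_pat_ = MakeAltOpPattern(kUnaryLikeOps)({like_pat_});
//   binary_like_pat_ = MakeAltOpPattern(kBinaryLikeOps)({data_pat_, like_pat_});

}  // namespace relay
}  // namespace tvm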

Contributor Author:

good points, thanks

Comment on lines 93 to 99
// skip trying to support this for now (ironic, as I was the one who added the feature)
if (const auto* attrs = call_node->attrs.as<ReshapeLikeAttrs>()) {
  if (attrs->lhs_begin != 0 || attrs->rhs_begin != 0 || attrs->lhs_end.defined() ||
      attrs->rhs_end.defined()) {
    return post;
  }
}
Contributor:

This is too ad hoc; it means we may not concretize *_like ops in certain situations. Instead of hacking the unified callback function, we should keep this logic in the op-specific function. I can think of two ways to achieve this (the second is sketched below):

  1. Put the logic at the beginning of the concrete_map_ functions and return the same op if not applicable.
  2. Construct another checker map that includes a checker function for each op, and invoke the corresponding checker function here.
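A minimal sketch of the second option (hypothetical names; keyed by op name purely for illustration):

#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr.h>

#include <functional>
#include <string>
#include <unordered_map>

namespace tvm {
namespace relay {

// Hypothetical checker map: each supported *_like op provides a predicate that
// says whether a given call site can be concretized.
using FCheck = std::function<bool(const CallNode*)>;

static const std::unordered_map<std::string, FCheck> kConcretizeCheckers = {
    {"reshape_like",
     [](const CallNode* call) {
       const auto* attrs = call->attrs.as<ReshapeLikeAttrs>();
       // Only the whole-tensor form of reshape_like is concretizable for now.
       return attrs != nullptr && attrs->lhs_begin == 0 && attrs->rhs_begin == 0 &&
              !attrs->lhs_end.defined() && !attrs->rhs_end.defined();
     }},
};

// In the unified callback, before rewriting (sketch):
//   auto it = kConcretizeCheckers.find(op_name);
//   if (it != kConcretizeCheckers.end() && !it->second(call_node)) return post;

}  // namespace relay
}  // namespace tvm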

Contributor Author:

yeah I agree, thanks

Comment on lines 101 to 102
CHECK(like->checked_type_.defined())
    << "ConcretizeLike requires checked types to be populated, please run type inference";

Contributor:

We already have L88-90 so this check seems useless.

data = relay.var("data", shape=(2, 3, 4), dtype="float32")
shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
f = relay.Function([data, shape_like], relay.reshape_like(data, shape_like))
f_expected = relay.Function([data, shape_like], relay.reshape(data, (6, 2, 2)))
Contributor:

This exposes a question about the current implementation: after the ConcretizeLike pass we are left with unused arguments. If we don't need the shape_like tensor anymore, we should remove it from the function.

Contributor Author:

This is a good point that I was wondering about while writing the tests. However, it might be complicated: removing a parameter means we'll need to remove the argument at each call site (and then maybe re-apply the optimization until a fixed point). I imagine we won't want to remove the parameter from a global function, since then the user would have to make sure they don't pass the removed argument.

Contributor:

It should be fine to have another internal pass, called after concretizing the *_like ops, that checks and mutates the function signatures.


mod = tvm.IRModule.from_expr(f)
mod_concrete = relay.transform.ConcretizeLike()(mod)
assert tvm.ir.structural_equal(mod_concrete["main"], f_expected)
Contributor:

Add if __name__ == "__main__":

@altanh (Contributor Author) commented Mar 17, 2021

@comaniac thanks for the helpful feedback; it seems there are some design decisions here that would be worth aligning with existing approaches. I'll spend some time this week incorporating the feedback, but in the meantime I think I'll remove the ConcretizeLike pass from this PR and send a follow-up (since the AD refactor can stand alone). How does that sound?

@comaniac (Contributor):

Sounds good to me and I'm good with the first-order AD changes in general.

@tqchen changed the title from "[Relay][Training][Pass] Factor out first-order AD to a module pass, and add ConcretizeLike pass" to "[Relay][Training][Pass] Factor out first-order AD to a module pass" on Mar 17, 2021
@tqchen merged commit 45442ed into apache:main on Mar 18, 2021
@tqchen (Member) commented Mar 18, 2021

Thanks @altanh @comaniac @MarisaKirisame
