[TIR][PASS] dtype rewrite for indexing variables #5092

Merged (25 commits) on Apr 2, 2020

hzfan
Contributor

@hzfan hzfan commented Mar 18, 2020

Changes:

  • enable indexing with i64 vars, so that large tensors with more than 2^32 elements can be properly indexed.
  • narrow i64 index which trivially fits into i32 to i32.

Some background:
https://discuss.tvm.ai/t/rfc-support-for-large-tensors/5643

Take the following as an example:

A = te.placeholder((m, n), name='A')
B = te.placeholder((m, n), name='B')
C = te.compute((m, n), lambda *idx: A[idx] + B[idx])

m, n = te.var('m', dtype='int64'), te.var('n', dtype='int64') yields

produce compute {
  for (i0.int64, (int64)0, m.int64) {
    for (i1.int64, (int64)0, n.int64) {
      compute[((i0.int64*stride.int64) + (i1.int64*stride.int64))] = (A[((i0.int64*stride.int64) + (i1.int64*stride.int64))] + B[((i0.int64*stride.int64) + (i1.int64*stride.int64))])
    }
  }
}

m, n = tvm.tir.const(2, dtype="int64"), tvm.tir.const(2, dtype="int64") yields

produce compute {
  for (i0.int32, 0, 2) {
    for (i1.int32, 0, 2) {
      compute[((i0.int32*2) + i1.int32)] = (A[((i0.int32*2) + i1.int32)] + B[((i0.int32*2) + i1.int32)])
    }
  }
}
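
For illustration, the narrowing decision in the second example can be modelled without TVM. This is a minimal sketch, not the pass itself: `fits_signed` and `index_dtype` are hypothetical helpers, and the model only looks at the largest flattened offset of a dense row-major tensor with a constant shape, whereas the real pass inspects every integer expression.

```python
def fits_signed(value_min, value_max, bits):
    """Return True if [value_min, value_max] lies within a signed `bits`-bit range."""
    return -(1 << (bits - 1)) <= value_min and value_max <= (1 << (bits - 1)) - 1


def index_dtype(shape, target_bits=32):
    """Pick an index dtype for a dense row-major tensor with a constant shape.

    Hypothetical helper for illustration only; not part of TVM.
    """
    num_elements = 1
    for extent in shape:
        num_elements *= extent
    max_index = num_elements - 1  # largest flattened offset
    return f"int{target_bits}" if fits_signed(0, max_index, target_bits) else "int64"


print(index_dtype((2, 2)))              # int32
print(index_dtype((1 << 16, 1 << 16)))  # int64
```

With symbolic i64 extents such as te.var('m', dtype='int64') no constant bound is available, which is why the first example keeps i64 indices.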

@yzhliu Could you review?

Member

@yzhliu yzhliu left a comment

will do another round of review tomorrow. @tqchen could you also take a look?

@tqchen tqchen self-assigned this Mar 19, 2020
Member

tqchen commented Mar 19, 2020

@junrushao1994 @ZihengJiang @merrymercy @Hzfengsy would be great if you can also help to take a look

@hzfan hzfan force-pushed the large-tensor-pass_pr branch 4 times, most recently from c668c00 to 3ba40c2 Compare March 21, 2020 17:13
bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, 32)
Member

It would be great if we can add more test cases:

  • Please also consider using ir_builder to build loops directly
  • Add a test case that narrows to i16
  • Add a case where a loop variable occurs in multiple expressions, one of which overflows while another does not

Contributor Author

Tests added. test_slice covers the last case you mentioned.

ConstIntBound bound = analyzer_.const_int_bound(e);
int64_t ubound = Downcast<IntImm, PrimExpr>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm, PrimExpr>(min_value(DataType::Int(target_bits_)))->value;
if (e.dtype().bits() <= target_bits_ ||
Member

Shall we rewrite lower bits into higher ones? cc @yzhliu

Member

Do you have an example? We previously reviewed some of the scenarios and did not see a need to do so.

Member

To clarify, this code seems to indicate that we can rewrite lower bits into higher ones, and I think we do not need this behavior.

Contributor Author

Yes, I think this pass aligns with what you said: it only narrows and never promotes. This code checks that an expression with fewer bits fits into the wider target. For example, consider e = (i.i64 <= j.i64), a bool expression with only 1 bit. e fits into i32 and so does not hinder narrowing i to i32.

https://github.com/apache/incubator-tvm/pull/5092/files#diff-98ae729cf00e30cff311ed80b4a25df9R129 ensures that dtype promotion does not occur.
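
The "fits" check described in this reply can be sketched in plain Python. This is only a model under stated assumptions: `fits_target` is a hypothetical name, and the second half of the condition (comparing the constant bounds against the target range) is an assumption, since the quoted C++ snippet is cut off after the `||`.

```python
def fits_target(expr_bits, bound_min, bound_max, target_bits=32):
    """Return True if an expression does not block narrowing to `target_bits`.

    Either its dtype is already narrow enough, or (assumed here) its constant
    bounds lie inside the signed target range.
    """
    lbound = -(1 << (target_bits - 1))
    ubound = (1 << (target_bits - 1)) - 1
    return expr_bits <= target_bits or (lbound <= bound_min and bound_max <= ubound)


# `i <= j` on i64 operands is a 1-bit boolean: it never forces i or j to stay i64.
print(fits_target(expr_bits=1, bound_min=0, bound_max=1))         # True
# An i64 expression whose bound exceeds the i32 range does block narrowing.
print(fits_target(expr_bits=64, bound_min=0, bound_max=1 << 40))  # False
```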

if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
} else {
vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
Member

It would be great to add more comments here, e.g. that we are taking the maximum bits over all the Exprs in which a Var occurs.

Contributor Author

Thanks for the example. Comments added.


void VisitExpr_(const VarNode* op) {
if (vextent_.find(op) != vextent_.end()) {
int bits = std::min(vextent_[op].bits(), bits_);
Member

Add more comments about the algorithm

Contributor Author

Added. Here we ensure that datatype is not promoted.
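
The two rules from these threads (take the maximum width over every expression a variable occurs in, and never exceed the variable's original width) can be sketched together. This is a simplified model, not the visitor from the PR: `required_bits` and `infer_var_bits` are hypothetical names, and each occurrence is reduced to a precomputed constant bound.

```python
def required_bits(bound_min, bound_max, target_bits=32, max_bits=64):
    """Return target_bits if the bound fits its signed range, else max_bits."""
    lo, hi = -(1 << (target_bits - 1)), (1 << (target_bits - 1)) - 1
    return target_bits if lo <= bound_min and bound_max <= hi else max_bits


def infer_var_bits(var_bits, occurrences):
    """Infer a width per variable.

    var_bits:    {var: original bit width}
    occurrences: (var, bound_min, bound_max) for every expression the var appears in
    """
    vmap = {}
    for var, bound_min, bound_max in occurrences:
        # Never promote: clamp to the variable's original width.
        bits = min(var_bits[var], required_bits(bound_min, bound_max))
        # A variable must be wide enough for every expression it occurs in.
        vmap[var] = max(vmap.get(var, bits), bits)
    return vmap


occurrences = [
    ("i", 0, 9),        # i used directly as a small index
    ("i", 0, 1 << 40),  # an expression containing i that overflows i32
    ("j", 0, 9),        # j only appears in a small expression
    ("k", 0, 1 << 40),  # k is an i16 variable inside a wide expression
]
print(infer_var_bits({"i": 64, "j": 64, "k": 16}, occurrences))
# {'i': 64, 'j': 32, 'k': 16}
```

Here i stays i64 because one of its uses overflows, j narrows to i32, and k keeps its i16 width rather than being promoted.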

using arith::IRMutatorWithAnalyzer;
using arith::ConstIntBound;

class DataTypeVisitor final : public StmtExprVisitor {
Member

Please document:

  • the input, and what vmap will eventually store
  • the general algorithm we use (e.g. we propagate the bits backwards into vmap)

Contributor Author

Documented.

void VisitExpr(const PrimExpr& e) {
if (e.dtype().is_int()) {
int bits = max_bits_;
ConstIntBound bound = analyzer_.const_int_bound(e);
Member

@tqchen tqchen Mar 25, 2020

NOTE: the constant int bound here is not necessarily the most efficient for deeply nested expressions.

Because we visit recursively, we also call const int bound on every sub-expression of e. Perhaps we want to add a const int bound with a memoization option that lets the caller pass the analyzer a memo (mapping each sub-expression to its const int bound).

We could add it as a TODO item or do it directly in this PR, but it would be great to have it resolved in the next few weeks. cc @yzhliu

Contributor Author

What about a state flag like with_memo in class ConstIntBoundAnalyzer? Once it is set, we first look up in the table when a new Expr comes. We can have an unordered_map like unordered_map<PrimExpr*, Entry> in ConstIntBoundAnalyzer to achieve this.

Member

@hzfan I think it is good, but why not always do memoization? @tqchen

Member

A memo can have unintended consequences if the vars can be bound to different context-dependent info (e.g. in if (x<10) {x+1; } else x; the condition x<10 is only effective in the then branch).

I would say perhaps we could have another API that takes an unordered map and asks the analyzer to record every intermediate step into the map.
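
The concern about always memoizing can be shown with a toy interval analysis. This is a sketch under loose assumptions, not TVM's ConstIntBoundAnalyzer: `bound_x_plus_1`, `AlwaysMemo` and `bound_and_record` are hypothetical names, and the only expression modelled is x + 1.

```python
def bound_x_plus_1(x_bound):
    """Interval of `x + 1` given the interval of x (toy stand-in for a bound analyzer)."""
    lo, hi = x_bound
    return (lo + 1, hi + 1)


class AlwaysMemo:
    """Caches by expression only, ignoring the constraints in scope (unsafe)."""

    def __init__(self):
        self._cache = {}

    def bound(self, expr, x_bound):
        if expr not in self._cache:
            self._cache[expr] = bound_x_plus_1(x_bound)
        return self._cache[expr]


def bound_and_record(expr, x_bound, record):
    """Always recomputes, and records the result in a caller-owned map."""
    record[expr] = bound_x_plus_1(x_bound)
    return record[expr]


memo = AlwaysMemo()
inside_if = memo.bound("x + 1", (0, 9))     # under the constraint x < 10
outside_if = memo.bound("x + 1", (0, 100))  # constraint no longer holds
print(inside_if, outside_if)                # (1, 10) (1, 10)  <- second result is stale

record = {}
print(bound_and_record("x + 1", (0, 100), record))  # (1, 101)
```

Letting the caller own the map keeps the recorded bounds tied to one traversal, so the caller decides when they are still valid.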

Member

tqchen commented Mar 25, 2020

Some more comments, mainly with respect to the clarity of the code, test coverage and efficiency concerns. Thanks for bringing in the PR. Given that this is critical to a lot of the codebase, let us try to hold the highest standard (https://docs.tvm.ai/contribute/code_review.html#hold-the-highest-standard) :) Let us work to polish it to the best state. Thanks @hzfan for the good work so far.

@tqchen tqchen added the status: need update need update based on feedbacks label Mar 25, 2020
@tqchen tqchen changed the title [PASS] dtype rewrite for indexing variables [TIR][PASS] dtype rewrite for indexing variables Mar 25, 2020
Member

tqchen commented Mar 25, 2020

@spectrometerHBH @FrozenGene @Hzfengsy please also help to take a look

if (v1 != nullptr) {
// integers that do not fit in int32_t are treated as symbolic,
// as it's impossible to unroll such large loops
if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {
Member

I wonder if we should use int32_t here rather than int. I'm not sure, but I worry that int may represent different types (int16_t, int32_t or int64_t) on different systems and devices.

Contributor Author

My motivation here is to prevent overflow in the next line (which uses int):

value = static_cast<int>(v1->value);

IMO it might be fine to use int here, as it's consistent with other parts of the pass. What do you think?

Member

That's fine
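
The overflow that this guard prevents can be demonstrated in Python. This is a minimal sketch assuming a platform where int is 32 bits wide and out-of-range conversions wrap around; `cast_to_int32` and `constant_extent` are hypothetical helpers, not code from the unroll pass.

```python
import ctypes

# std::numeric_limits<int>::max() on platforms where int is 32 bits (assumed).
INT_MAX = 2**31 - 1


def cast_to_int32(value):
    """Mimic a wrapping conversion of a 64-bit value to a 32-bit int."""
    return ctypes.c_int32(value).value


def constant_extent(value):
    """Return the extent as a 32-bit int if it fits, else None (treated as symbolic)."""
    return cast_to_int32(value) if value <= INT_MAX else None


print(cast_to_int32(2**32 + 4))    # 4: a huge loop would look like a 4-iteration loop
print(constant_extent(2**32 + 4))  # None
print(constant_extent(16))         # 16
```

Without the guard, the unroller could mistake a loop with more than 2^32 iterations for a tiny one and unroll it.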

Member

@yzhliu yzhliu left a comment

@tqchen @Hzfengsy please check again.

Member

tqchen commented Apr 1, 2020

@hzfan Good work. We are in the process of migrating to the new transform pass manager API. Can you also add a variant of the pass for IRModule and change the test cases to the new style? We can keep using the old API until we have migrated everything into the new pass style.

reference #5198

@tqchen tqchen merged commit 4e5c584 into apache:master Apr 2, 2020
@tqchen tqchen added status: accepted and removed status: need review status: need update need update based on feedbacks labels Apr 2, 2020
Member

tqchen commented Apr 2, 2020

Thanks @hzfan for continuing to improve the code and maintaining a high standard. Thanks @yzhliu @Hzfengsy for the helpful reviews. This PR is now merged.

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020