
[RELAY][FUSION] Enhance fusion rule that starts from elemwise and broadcast #2932

Merged · 7 commits · May 1, 2019

Conversation

@zhiics (Member) commented Mar 31, 2019

In NNVM we can fuse injective ops into elemwise ops, but in Relay we cannot. Is this intended, or is it a bug? This PR enables that behavior.

The following script compares the fusion behavior of NNVM and Relay:

import tvm
from tvm import relay
import nnvm
import nnvm.symbol as sym
import nnvm.graph as graph

print("--------------------- relay ----------------------")
data = relay.var("data", relay.ty.TensorType((1, 32, 32, 3), "float32"))

log = relay.log(data)
seq = relay.squeeze(log)

func = relay.Function([data], seq)
func = relay.ir_pass.infer_type(func)

with relay.build_config(opt_level=2):
    graph, lib, params = relay.build(func, target="llvm")
    print(graph)

print(func)
print("--------------------- nnvm ----------------------")

data1 = sym.Variable("data1")
log1 = sym.log(data1)
seq1 = sym.squeeze(log1)
shape = {"data1": (1, 32, 32, 3)}
with nnvm.compiler.build_config(opt_level=2):
    graph, lib, params = nnvm.compiler.build(seq1, target="llvm", shape=shape)
    print(graph.json())

cc @tqchen @jroesch @wweic @masahi

@wweic (Contributor) commented Mar 31, 2019

We observed a 15% performance degradation on one model due to this bug. We would like to understand whether this is the right fix.

@@ -678,6 +678,7 @@ class GraphPartitioner {
     } else {
       return (kind <= kBroadcast ||
               kind == kCommReduce ||
+              kind == kInjective ||
               kind == kOutEWiseFusable);
     }
masahi (Member) commented on the diff:

I think this fix is correct, but in this case we can simplify the condition to kind <= kOutEWiseFusable.
I wonder what the original reason was for leaving out kInjective. @tqchen
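
For reference, a minimal sketch of the simplified check, assuming the usual OpPatternKind ordering (kElemWise < kBroadcast < kInjective < kCommReduce < kOutEWiseFusable):

// Sketch only: every kind enumerated above is <= kOutEWiseFusable,
// so the whole disjunction collapses to a single comparison.
return kind <= kOutEWiseFusable;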

zhiics (Member, Author) replied:

@masahi Yes, we can. I wasn't quite sure whether it was intentionally left out. If not, I'll change it to <=.

tqchen (Member) Mar 31, 2019:

This is something I overlooked. However, it is good to leave things as they are, because this way it is clearer. We also need a case that fuses into multiple injective ops, i.e., the condition in the if branch needs to be enhanced as well. Please add a test case for that too:
ewise -> parallel{injective, injective} -> injective
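
A minimal sketch of such a test case, using the same 2019-era API as the scripts in this thread; the specific ops (exp for elemwise, transpose/reshape for the parallel injective branches, concatenate as the final injective op) are illustrative choices, not the PR's actual test:

import tvm
from tvm import relay

# ewise -> parallel{injective, injective} -> injective
x = relay.var("x", relay.ty.TensorType((1, 16, 16, 3), "float32"))
y = relay.exp(x)                                    # elemwise
b1 = relay.reshape(y, newshape=(1, 768))            # injective branch 1
b2 = relay.reshape(relay.transpose(y, axes=[0, 3, 1, 2]),
                   newshape=(1, 768))               # injective branch 2
out = relay.concatenate([b1, b2], axis=0)           # injective combiner

func = relay.Function([x], out)
func = relay.ir_pass.infer_type(func)

with relay.build_config(opt_level=2):
    graph, lib, params = relay.build(func, target="llvm")
    # with the enhanced rule, the whole chain should end up in one fused function
    print(graph)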

@tqchen tqchen changed the title [relay][bugfix] fuse injective to elemwise and broadcast [RELAY][FUSION] fuse injective to elemwise and broadcast Mar 31, 2019
@tqchen tqchen changed the title [RELAY][FUSION] fuse injective to elemwise and broadcast [RELAY][FUSION] Enhance fusion rule that starts from elemwise and broadcast Mar 31, 2019
@tqchen (Member) commented Mar 31, 2019

Good catch! I will leave it to @masahi to manage the PR; see my comments in the review block.

@masahi (Member) commented Apr 2, 2019

@zhiics can you retrigger the CI?

@zhiics (Member, Author) commented Apr 2, 2019

@masahi Sorry, I didn't have much time to look into it until now. It looks like the failure in nnvm/tests/python/compiler/test_top_level4.py is triggered by this type of fusion. I can reproduce it locally:

import tvm
from tvm import relay

data = relay.var("data", relay.ty.TensorType((1, 32, 32, 3), "float32"))
data1 = relay.var("data1", relay.ty.TensorType((1, 32, 32, 3), "float32"))

log = relay.log(data)
c1 = relay.copy(data1)
c2 = relay.copy(data1)

func = relay.Function([data, data1],  relay.Tuple(tvm.convert([log, c1, c2])))
func = relay.ir_pass.infer_type(func)

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(func, target="llvm")

This gives me: Assertion failed: (getOperand(0)->getType()->isFPOrFPVectorTy() && "Invalid operand types for FCmp instruction"), function AssertOK, file /Users/chzhi/tools/llvm/llvm/include/llvm/IR/Instructions.h, line 1257

Without the change, the fused code looks like the following:

%12 = fn (%data: Tensor[(1, 32, 32, 3), float32], %data1: Tensor[(1, 32, 32, 3), float32]) -> (Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32]) {
  %1 = fn (%p0: Tensor[(1, 32, 32, 3), float32], __dict__=meta[StrMap][0]) -> Tensor[(1, 32, 32, 3), float32] {
    %0 = log(%p0)
    %0
  }
  %2 = %1(%data)
  %4 = fn (%p01: Tensor[(1, 32, 32, 3), float32], __dict__=meta[StrMap][1]) -> Tensor[(1, 32, 32, 3), float32] {
    %3 = copy(%p01)
    %3
  }
  %5 = %4(%data1)
  %10 = fn (%p02: Tensor[(1, 32, 32, 3), float32], %p1: Tensor[(1, 32, 32, 3), float32], __dict__=meta[StrMap][2]) -> (Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32]) {
    %6 = copy(%p02)
    %7 = copy(%p1)
    %8 = copy(%p1)
    %9 = (%6, %7, %8)
    %9
  }
  %11 = %10(%2, %5)
  %11
}
%12

With the change, the fused code looks like the following:

%5 = fn (%data: Tensor[(1, 32, 32, 3), float32], %data1: Tensor[(1, 32, 32, 3), float32]) -> (Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32]) {
  %3 = fn (%p0: Tensor[(1, 32, 32, 3), float32], %p1: Tensor[(1, 32, 32, 3), float32], __dict__=meta[StrMap][0]) -> (Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32]) {
    %0 = log(%p0)
    %1 = copy(%p1)
    %2 = (%0, %1, %1)  # CSE is used here
    %2
  }
  %4 = %3(%data, %data1)
  %4
}
%5

@masahi (Member) commented Apr 2, 2019

This seems like a similar error to this one.

@zhiics (Member, Author) commented Apr 4, 2019

@masahi They might not be the same. The last commit fixes the mentioned problem, but I am not entirely sure it is correct.

With the fix, the following function:

%2 = fn (%data: Tensor[(1, 32, 32, 3), float32]) -> (Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32]) {
  %0 = log(%data) // ty=Tensor[(1, 32, 32, 3), float32]
  %1 = (%0, %0)
  %1
}

will produce:

%5 = fn (%data: Tensor[(1, 32, 32, 3), float32]) -> (Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32]) {
  %3 = fn (%p0: Tensor[(1, 32, 32, 3), float32], __dict__=meta[StrMap][0]) -> (Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32]) {
    %0 = log(%p0)
    %1 = copy(%0)
    %2 = (%0, %1)
    %2
  }
  %4 = %3(%data)
  %4
}
%5

Instead of:

%4 = fn (%data: Tensor[(1, 32, 32, 3), float32]) -> (Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32]) {
  %2 = fn (%p0: Tensor[(1, 32, 32, 3), float32], __dict__=meta[StrMap][0]) -> (Tensor[(1, 32, 32, 3), float32], Tensor[(1, 32, 32, 3), float32]) {
    %0 = log(%p0)
    %1 = (%0, %0)
    %1
  }
  %3 = %2(%data)
  %3
}
%4

Another problem caused by this type of fusion is in tutorials/frontend/deploy_ssd_gluoncv.py. It happens because there is a scalar that cannot be found in the schedule here: https://github.com/dmlc/tvm/blob/master/src/relay/backend/compile_engine.cc#L111
I found which operator it is, but I haven't figured out why.

It would be great if you could also take a look.

@masahi (Member) commented Apr 4, 2019

OK, I'll take a look.

@masahi (Member) commented Apr 4, 2019

The error seems to happen at the fused function below. The sequence of adds at the beginning is not fused into this function if I disable this PR's change. I think the scalars associated with the add functions are causing the issue.

  %710 = fn (%p091: Tensor[(1, 6132, 21), float32], %p172: Tensor[(1, 100, 32, 32), float32], %p262: Tensor[(16, 1, 1), float32], %p329: Tensor[(1, 150, 16, 16), float32], %p418: Tensor[(24, 1, 1), float32], %p517: Tensor[(1, 150, 8, 8), float32], %p61: Tensor[(24, 1, 1), float32], %p71: Tensor[(1, 150, 4, 4), float32], %p81: Tensor[(24, 1, 1), float32], %p91: Tensor[(1, 100, 2, 2), float32], %p101: Tensor[(16, 1, 1), float32], %p1111: Tensor[(1, 100, 1, 1), float32], %p1210: Tensor[(16, 1, 1), float32], %p1310: Tensor[(1, 1, 128, 128, 16), float32], %p1410: Tensor[(1, 128, 32, 32, 8), float32], %p1510: Tensor[(1, 1, 64, 64, 24), float32], %p1610: Tensor[(1, 256, 16, 16, 8), float32], %p173: Tensor[(1, 1, 32, 32, 24), float32], %p181: Tensor[(1, 64, 8, 8, 8), float32], %p191: Tensor[(1, 1, 16, 16, 24), float32], %p20: Tensor[(1, 64, 4, 4, 8), float32], %p2110: Tensor[(1, 1, 16, 16, 16), float32], %p2210: Tensor[(1, 32, 2, 2, 8), float32], %p2310: Tensor[(1, 1, 16, 16, 16), float32], %p2410: Tensor[(1, 32, 1, 1, 8), float32], __dict__=meta[StrMap][91]) -> Tensor[(1, 122640, 6), float32] {
    %494 = strided_slice(%p091, begin=[0, 0, 1], end=[1, 6132, 21])
    %495 = greater(%494, 0.01f)
    %496 = cast(%495, dtype="float32")
    %497 = strided_slice(%p091, begin=[0, 0, 0], end=[1, 6132, 1])
    %498 = zeros_like(%497)
    %499 = add(%498, 0f)
    %500 = add(%498, 1f)
    %501 = add(%498, 2f)
    %502 = add(%498, 3f)
    %503 = add(%498, 4f)
    %504 = add(%498, 5f)
    %505 = add(%498, 6f)
    %506 = add(%498, 7f)
    %507 = add(%498, 8f)
    %508 = add(%498, 9f)
    %509 = add(%498, 10f)
    %510 = add(%498, 11f)
    %511 = add(%498, 12f)
    %512 = add(%498, 13f)
    %513 = add(%498, 14f)
    %514 = add(%498, 15f)
    %515 = add(%498, 16f)
    %516 = add(%498, 17f)
    %517 = add(%498, 18f)
    %518 = add(%498, 19f)
    %519 = (%499, %500, %501, %502, %503, %504, %505, %506, %507, %508, %509, %510, %511, %512, %513, %514, %515, %516, %517, %518)
    %520 = concatenate(%519, axis=-1)
    %521 = ones_like(%520)
    %522 = multiply(%521, -1f)
    %523 = where(%496, %520, %522)
    %524 = strided_slice(%523, begin=[0, 0, 0], end=[1, 6132, 1])
    %525 = zeros_like(%494)
    %526 = where(%496, %494, %525)
    %527 = strided_slice(%526, begin=[0, 0, 0], end=[1, 6132, 1])
    %528 = strided_slice(%p172, begin=[0, 84], end=[None, 100])
    %529 = add(%528, %p262)
    %530 = transpose(%529, axes=[0, 2, 3, 1])
    %531 = nn.batch_flatten(%530)
    %532 = strided_slice(%p329, begin=[0, 126], end=[None, 150])
    %533 = add(%532, %p418)
    %534 = transpose(%533, axes=[0, 2, 3, 1])
    %535 = nn.batch_flatten(%534)
    %536 = strided_slice(%p517, begin=[0, 126], end=[None, 150])
    %537 = add(%536, %p61)
    %538 = transpose(%537, axes=[0, 2, 3, 1])
    %539 = nn.batch_flatten(%538)
    %540 = strided_slice(%p71, begin=[0, 126], end=[None, 150])
    %541 = add(%540, %p81)
    %542 = transpose(%541, axes=[0, 2, 3, 1])
    %543 = nn.batch_flatten(%542)
    %544 = strided_slice(%p91, begin=[0, 84], end=[None, 100])
    %545 = add(%544, %p101)
    %546 = transpose(%545, axes=[0, 2, 3, 1])
    %547 = nn.batch_flatten(%546)
    %548 = strided_slice(%p1111, begin=[0, 84], end=[None, 100])
    %549 = add(%548, %p1210)
    %550 = transpose(%549, axes=[0, 2, 3, 1])
    %551 = nn.batch_flatten(%550)
    %552 = (%531, %535, %539, %543, %547, %551)
    %553 = concatenate(%552, axis=1)
    %554 = reshape(%553, newshape=[0, -1, 4])
    %555 = split(%554, indices_or_sections=int64(4), axis=-1)
    %556 = %555.0
    %557 = multiply(%556, 0.1f)
    %558 = add(%557, 0f)
    %559 = multiply(%p1410, 0f)
    %560 = layout_transform(%559, src_layout="NCHW8c", dst_layout="NCHW")
    %561 = slice_like(%p1310, %560, meta[relay.attrs.SliceLikeAttrs][0])
    %562 = reshape(%561, newshape=[1, -1, 4])
    %563 = reshape(%562, newshape=[1, -1, 4])
    %564 = reshape(%563, newshape=[1, -1])
    %565 = multiply(%p1610, 0f)
    %566 = layout_transform(%565, src_layout="NCHW8c", dst_layout="NCHW")
    %567 = slice_like(%p1510, %566, meta[relay.attrs.SliceLikeAttrs][1])
    %568 = reshape(%567, newshape=[1, -1, 4])
    %569 = reshape(%568, newshape=[1, -1, 4])
    %570 = reshape(%569, newshape=[1, -1])
    %571 = multiply(%p181, 0f)
    %572 = layout_transform(%571, src_layout="NCHW8c", dst_layout="NCHW")
    %573 = slice_like(%p173, %572, meta[relay.attrs.SliceLikeAttrs][2])
    %574 = reshape(%573, newshape=[1, -1, 4])
    %575 = reshape(%574, newshape=[1, -1, 4])
    %576 = reshape(%575, newshape=[1, -1])
    %577 = multiply(%p20, 0f)
    %578 = layout_transform(%577, src_layout="NCHW8c", dst_layout="NCHW")
    %579 = slice_like(%p191, %578, meta[relay.attrs.SliceLikeAttrs][3])
    %580 = reshape(%579, newshape=[1, -1, 4])
    %581 = reshape(%580, newshape=[1, -1, 4])
    %582 = reshape(%581, newshape=[1, -1])
    %583 = multiply(%p2210, 0f)
    %584 = layout_transform(%583, src_layout="NCHW8c", dst_layout="NCHW")
    %585 = slice_like(%p2110, %584, meta[relay.attrs.SliceLikeAttrs][4])
    %586 = reshape(%585, newshape=[1, -1, 4])
    %587 = reshape(%586, newshape=[1, -1, 4])
    %588 = reshape(%587, newshape=[1, -1])
    %589 = multiply(%p2410, 0f)
    %590 = layout_transform(%589, src_layout="NCHW8c", dst_layout="NCHW")
    %591 = slice_like(%p2310, %590, meta[relay.attrs.SliceLikeAttrs][5])
    %592 = reshape(%591, newshape=[1, -1, 4])
    %593 = reshape(%592, newshape=[1, -1, 4])
    %594 = reshape(%593, newshape=[1, -1])
    %595 = (%564, %570, %576, %582, %588, %594)
    %596 = concatenate(%595, axis=1)
    %597 = reshape(%596, newshape=[1, -1, 4])
    %598 = split(%597, indices_or_sections=int64(4), axis=-1)
    %599 = %598.2
    %600 = multiply(%558, %599)
    %601 = %598.0
    %602 = add(%600, %601)
    %603 = %555.2
    %604 = multiply(%603, 0.2f)
    %605 = add(%604, 0f)
    %606 = exp(%605)
    %607 = %598.2
    %608 = multiply(%606, %607)
    %609 = divide(%608, 2f)
    %610 = subtract(%602, %609)
    %611 = %555.1
    %612 = multiply(%611, 0.1f)
    %613 = add(%612, 0f)
    %614 = %598.3
    %615 = multiply(%613, %614)
    %616 = %598.1
    %617 = add(%615, %616)
    %618 = %555.3
    %619 = multiply(%618, 0.2f)
    %620 = add(%619, 0f)
    %621 = exp(%620)
    %622 = %598.3
    %623 = multiply(%621, %622)
    %624 = divide(%623, 2f)
    %625 = subtract(%617, %624)
    %626 = add(%602, %609)
    %627 = add(%617, %624)
    %628 = (%610, %625, %626, %627)
    %629 = concatenate(%628, axis=-1)
    %630 = (%524, %527, %629)
    %631 = concatenate(%630, axis=-1)
    %632 = strided_slice(%523, begin=[0, 0, 1], end=[1, 6132, 2])
    %633 = strided_slice(%526, begin=[0, 0, 1], end=[1, 6132, 2])
    %634 = (%632, %633, %629)
    %635 = concatenate(%634, axis=-1)
    %636 = strided_slice(%523, begin=[0, 0, 2], end=[1, 6132, 3])
    %637 = strided_slice(%526, begin=[0, 0, 2], end=[1, 6132, 3])
    %638 = (%636, %637, %629)
    %639 = concatenate(%638, axis=-1)
    %640 = strided_slice(%523, begin=[0, 0, 3], end=[1, 6132, 4])
    %641 = strided_slice(%526, begin=[0, 0, 3], end=[1, 6132, 4])
    %642 = (%640, %641, %629)
    %643 = concatenate(%642, axis=-1)
    %644 = strided_slice(%523, begin=[0, 0, 4], end=[1, 6132, 5])
    %645 = strided_slice(%526, begin=[0, 0, 4], end=[1, 6132, 5])
    %646 = (%644, %645, %629)
    %647 = concatenate(%646, axis=-1)
    %648 = strided_slice(%523, begin=[0, 0, 5], end=[1, 6132, 6])
    %649 = strided_slice(%526, begin=[0, 0, 5], end=[1, 6132, 6])
    %650 = (%648, %649, %629)
    %651 = concatenate(%650, axis=-1)
    %652 = strided_slice(%523, begin=[0, 0, 6], end=[1, 6132, 7])
    %653 = strided_slice(%526, begin=[0, 0, 6], end=[1, 6132, 7])
    %654 = (%652, %653, %629)
    %655 = concatenate(%654, axis=-1)
    %656 = strided_slice(%523, begin=[0, 0, 7], end=[1, 6132, 8])
    %657 = strided_slice(%526, begin=[0, 0, 7], end=[1, 6132, 8])
    %658 = (%656, %657, %629)
    %659 = concatenate(%658, axis=-1)
    %660 = strided_slice(%523, begin=[0, 0, 8], end=[1, 6132, 9])
    %661 = strided_slice(%526, begin=[0, 0, 8], end=[1, 6132, 9])
    %662 = (%660, %661, %629)
    %663 = concatenate(%662, axis=-1)
    %664 = strided_slice(%523, begin=[0, 0, 9], end=[1, 6132, 10])
    %665 = strided_slice(%526, begin=[0, 0, 9], end=[1, 6132, 10])
    %666 = (%664, %665, %629)
    %667 = concatenate(%666, axis=-1)
    %668 = strided_slice(%523, begin=[0, 0, 10], end=[1, 6132, 11])
    %669 = strided_slice(%526, begin=[0, 0, 10], end=[1, 6132, 11])
    %670 = (%668, %669, %629)
    %671 = concatenate(%670, axis=-1)
    %672 = strided_slice(%523, begin=[0, 0, 11], end=[1, 6132, 12])
    %673 = strided_slice(%526, begin=[0, 0, 11], end=[1, 6132, 12])
    %674 = (%672, %673, %629)
    %675 = concatenate(%674, axis=-1)
    %676 = strided_slice(%523, begin=[0, 0, 12], end=[1, 6132, 13])
    %677 = strided_slice(%526, begin=[0, 0, 12], end=[1, 6132, 13])
    %678 = (%676, %677, %629)
    %679 = concatenate(%678, axis=-1)
    %680 = strided_slice(%523, begin=[0, 0, 13], end=[1, 6132, 14])
    %681 = strided_slice(%526, begin=[0, 0, 13], end=[1, 6132, 14])
    %682 = (%680, %681, %629)
    %683 = concatenate(%682, axis=-1)
    %684 = strided_slice(%523, begin=[0, 0, 14], end=[1, 6132, 15])
    %685 = strided_slice(%526, begin=[0, 0, 14], end=[1, 6132, 15])
    %686 = (%684, %685, %629)
    %687 = concatenate(%686, axis=-1)
    %688 = strided_slice(%523, begin=[0, 0, 15], end=[1, 6132, 16])
    %689 = strided_slice(%526, begin=[0, 0, 15], end=[1, 6132, 16])
    %690 = (%688, %689, %629)
    %691 = concatenate(%690, axis=-1)
    %692 = strided_slice(%523, begin=[0, 0, 16], end=[1, 6132, 17])
    %693 = strided_slice(%526, begin=[0, 0, 16], end=[1, 6132, 17])
    %694 = (%692, %693, %629)
    %695 = concatenate(%694, axis=-1)
    %696 = strided_slice(%523, begin=[0, 0, 17], end=[1, 6132, 18])
    %697 = strided_slice(%526, begin=[0, 0, 17], end=[1, 6132, 18])
    %698 = (%696, %697, %629)
    %699 = concatenate(%698, axis=-1)
    %700 = strided_slice(%523, begin=[0, 0, 18], end=[1, 6132, 19])
    %701 = strided_slice(%526, begin=[0, 0, 18], end=[1, 6132, 19])
    %702 = (%700, %701, %629)
    %703 = concatenate(%702, axis=-1)
    %704 = strided_slice(%523, begin=[0, 0, 19], end=[1, 6132, 20])
    %705 = strided_slice(%526, begin=[0, 0, 19], end=[1, 6132, 20])
    %706 = (%704, %705, %629)
    %707 = concatenate(%706, axis=-1)
    %708 = (%631, %635, %639, %643, %647, %651, %655, %659, %663, %667, %671, %675, %679, %683, %687, %691, %695, %699, %703, %707)
    %709 = concatenate(%708, axis=1)
    %709
  }

    j++;
  }
  if (j != i) {
    auto copy = Copy(new_fields[i]);
vinx13 (Member) Apr 4, 2019:

I'm not sure this is the best way to fix it. Does it work if you remove this assertion?
The LLVM error could be fixed in codegen, but we might still have a memory planning issue; that's why I added the copy above.

zhiics (Member, Author) replied:

@vinx13 It works, but should we remove that assertion? In my case, memory planning passes.

vinx13 (Member) replied:

It is https://github.com/dmlc/tvm/blob/3441b95e8f1e0435060f2010c10496a4b84973ae/src/codegen/llvm/codegen_llvm.cc#L849
The handle case is not handled properly there. We should keep the assertion and fix the codegen.

@zhiics (Member, Author) commented Apr 4, 2019

@masahi yes, because concatenate is injective.

@zhiics (Member, Author) commented Apr 5, 2019

@masahi Does it make sense to add a function that checks whether a Tensor is included in the Schedule? compute_inline would be allowed only when the tensor is in the Schedule.

@masahi (Member) commented Apr 5, 2019

@zhiics Yeah, that would be a quick fix, but I want to understand why the scalar is not included in the schedule. The scalar seems to be the 0f associated with the multiplies by zero, such as %559 = multiply(%p1410, 0f).

I also wonder why we need compute_inline here in the first place; that should be handled in TOPI schedules. If I remove compute_inline, it seems to work.

@zhiics (Member, Author) commented Apr 5, 2019

@masahi Yes, I tried that, and I have the same question.

@vinx13 (Member) commented Apr 5, 2019

@masahi @zhiics A ConstantNode that has a scalar value is lowered as a tvm.compute. If we don't call compute_inline, a (1,) tensor will be created.
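
A small TE-level sketch of what that means (hypothetical tensors, 2019-era tvm API): the scalar constant becomes its own compute stage, and without compute_inline it is materialized as a separate (1,) buffer:

import tvm

# a scalar constant lowered as its own one-element compute stage
scalar = tvm.compute((1,), lambda i: tvm.const(0.0, "float32"), name="scalar")
A = tvm.placeholder((16,), name="A")
B = tvm.compute((16,), lambda i: A[i] * scalar[0], name="B")

s = tvm.create_schedule(B.op)
# without this line, a separate (1,) tensor is created for the scalar
s[scalar].compute_inline()
print(tvm.lower(s, [A, B], simple_mode=True))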

@masahi (Member) commented Apr 5, 2019

@tqchen can you review this PR? We have two problems and a workaround for each. I'm not sure whether we want to merge them.

  1. There is a codegen error when some of the tuple fields refer to the same value. I'm not sure whether it is a problem in memory planning or in LLVM codegen.

  2. Some scalar tensors are not included in the schedule of the fused function.

@masahi (Member) commented Apr 5, 2019

@vinx13 Isn't it the case that the tvm.compute corresponding to scalars will be visited in TOPI when we invoke fschedule[master_op_](...) on the master op?

@vinx13 (Member) commented Apr 5, 2019

@masahi Yes, it will be visited, but we need to explicitly call compute_inline in the TOPI schedule. For example, in add, we need to check whether lhs/rhs is a scalar and then inline it. We don't do this currently.

See the discussion in #2116 (comment).

@masahi (Member) commented Apr 5, 2019

@vinx13 Good to know. So I guess skipping compute_inline on scalars that are not in the schedule (for some reason) will cause problems if we are running on a GPU.

@vinx13 (Member) commented Apr 5, 2019

For the LLVM error, we need to handle the op->a.type().is_handle() || op->b.type().is_handle() case in
https://github.com/dmlc/tvm/blob/3441b95e8f1e0435060f2010c10496a4b84973ae/src/codegen/llvm/codegen_llvm.cc#L852
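
A minimal sketch of the kind of fix (the assumed shape of the comparison codegen, not the exact TVM source): route handle operands to a pointer compare instead of letting them fall through to FCmp:

// Sketch only: handles are pointers, so they must not reach the FCmp branch.
if (op->a.type().is_handle()) {
  // LLVM's icmp accepts pointer operands, so this covers pointer equality
  return builder_->CreateICmpEQ(a, b);
} else if (op->a.type().is_int() || op->a.type().is_uint()) {
  return builder_->CreateICmpEQ(a, b);
} else {
  return builder_->CreateFCmpOEQ(a, b);  // previously reached for handles, triggering the assert
}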

@zhiics (Member, Author) commented Apr 5, 2019

@vinx13 Thanks. But I am not sure we want to do that, because a handle should be a pointer, right? LLVM cmp instructions can support type*. If we want to do this, we probably need to add it to the other cases as well.

@vinx13 (Member) commented Apr 5, 2019

Yes, a handle should be a pointer; in the above case it goes to the else branch, which is a floating-point cmp, and that caused the error.
The assertion is that the two arguments should be pointers to the same address:
https://github.com/dmlc/tvm/blob/7cd986db0e67583bc347ed208c25be4c0d0c32a0/src/pass/make_api.cc#L130
because there are duplicated elements in api_args.

@zhiics (Member, Author) commented Apr 5, 2019

@vinx13 Thanks. I updated the codegen part for LLVM. Please take a look.

@masahi, may I ask why skipping an uncached scalar is problematic on GPU?

@masahi (Member) commented Apr 5, 2019

@zhiics That is just my rough guess, based on @vinx13's comment in #2116. I don't know what happens if unscheduled scalars are accessed from the GPU.

I'll try running the SSD GPU PR at #2784 with this change to see if it works.

@vinx13 (Member) commented Apr 5, 2019

The problem with unscheduled scalars on GPU is that the host code initializes the scalar tensor directly, instead of looking the tensor up in the params dict as in the non-scalar constant case. However, the host doesn't check the device of the scalar tensor, which causes a segmentation fault.

@vinx13 (Member) commented Apr 8, 2019

> @masahi @vinx13 It looks like we have to fix codegen for other targets as well.
> For CUDA, it gets stuck here:
> https://github.com/dmlc/tvm/blob/master/src/schedule/schedule_lang.cc#L30

@zhiics How can I reproduce this error?

@zhiics (Member, Author) commented Apr 8, 2019

@vinx13 This one should reproduce it:

import tvm
from tvm import relay

data = relay.var("data", relay.ty.TensorType((1, 32, 32, 3), "float32"))
log = relay.log(data)

func = relay.Function([data],  relay.Tuple(tvm.convert([log, log])))
func = relay.ir_pass.infer_type(func)

with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(func, target="cuda")

@vinx13 (Member) commented Apr 8, 2019

If the output is like (%1, %1), two tensors are passed to fschedule: https://github.com/dmlc/tvm/blob/552d4aa3587aa3a0443d050ceb324386489ff793/src/relay/backend/compile_engine.cc#L128
Here tensor_outs is an array of two elements, and the two elements are the same tensor. As a result, the same compute is scheduled twice. This also caused the LLVM codegen issue, because an assertion is added to the generated function that its two (duplicated) arguments must point to the same tensor.

@zhiics (Member, Author) commented Apr 8, 2019

@vinx13 Yeah, I am aware that the same compute is scheduled twice here as well. That's probably why copy works here.

@tqchen (Member) commented Apr 16, 2019

This PR has gone a bit stale. Please summarize the problems and list possible solutions.

@zhiics (Member, Author) commented Apr 17, 2019

@tqchen Sorry, I've been on vacation these past few days.

There are two problems/bugs triggered by fusing from elemwise and broadcast ops.

SliceLikeCompute takes two tensor inputs, data and shape_like. The second input is only used to calculate the attributes of the TOPI function; the tensor itself is never generated, only its shape is used. That's why the scalars are not found in the schedule.

But I am not sure what the best way to fix this is.
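
For illustration, a minimal slice_like example (hypothetical shapes) where the second input contributes only its shape:

import tvm
from tvm import relay

x = relay.var("x", relay.ty.TensorType((1, 3, 32, 32), "float32"))
y = relay.var("y", relay.ty.TensorType((1, 3, 16, 16), "float32"))
# only y's static shape is consumed; no compute is ever generated for y,
# so anything feeding y (e.g. a scalar constant) never appears in the schedule
z = relay.slice_like(x, y, axes=[2, 3])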

@masahi @vinx13, please comment if I've missed anything.

@vinx13 (Member) commented Apr 17, 2019

@zhiics Calling compute_inline only when the scalar exists in the schedule is the right fix. Previously I assumed that the scalars created would always exist in the schedule; however, SliceLike is an exception.
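
A sketch of that fix in compile_engine.cc, assuming a Schedule::Contain helper that reports whether an operation belongs to the schedule:

// Sketch only: inline each collected scalar stage only if the schedule
// actually contains it (SliceLike's shape-only input never makes it in).
for (const auto& scalar : scalars_) {
  if (schedule->Contain(scalar)) {
    schedule[scalar].compute_inline();
  }
}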

Regarding the problem of (%x, %x) and the LLVM codegen error, a possible solution is to make sure that cache_node->inputs and tensor_outs have only unique entries. But this will also affect the calling convention.
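
A Python-level sketch of the deduplication idea (hypothetical snippet; the real change would live in compile_engine.cc): keep only unique output ops before creating the schedule, so the same compute is never scheduled twice:

import tvm

A = tvm.placeholder((16,), name="A")
B = tvm.compute((16,), lambda i: tvm.log(A[i]), name="B")

tensor_outs = [B, B]  # duplicated outputs, as in the (%1, %1) case
unique_ops = []
for t in tensor_outs:
    # deduplicate by reference identity so each op is scheduled exactly once
    if not any(t.op.same_as(o) for o in unique_ops):
        unique_ops.append(t.op)
s = tvm.create_schedule(unique_ops)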

@zhiics (Member, Author) commented Apr 17, 2019

@vinx13 Thanks for the quick response. Yeah, I have a fix locally that takes only the unique tensors in tensor_outs. I agree with you that this will change the calling convention a bit. @tqchen, please advise.

BTW, the LLVM codegen problem should also be fixed (I've actually committed a fix here).

@tqchen (Member) commented Apr 17, 2019

At the moment, we do not pass tuples around as intermediate values. That means that unless a tuple is the final return value of the global function, the fusor should not stop at it. One way to solve the problem is to treat a tuple that is the return value of the function differently from intermediate tuples, and not fuse through such tuples when they are marked as extern_ref (by marking their children as Opaque).

This way we simplify the assumptions made about the primitive kernel. I also agree that it is useful to add checks for such invariants (so that a tuple is not used as the return value of an intermediate function).

@tqchen (Member) commented Apr 17, 2019

#3039

@tqchen (Member) commented Apr 29, 2019

@masahi @zhiics please follow up now that #3092 is merged

@tqchen tqchen closed this Apr 29, 2019
@tqchen tqchen reopened this Apr 29, 2019
@zhiics (Member, Author) commented May 1, 2019

@masahi @tqchen Updated. PTAL. Thanks.

@masahi masahi merged commit f88f458 into apache:master May 1, 2019
@masahi (Member) commented May 1, 2019

Thanks @zhiics @vinx13 @tqchen, this is finally merged.

@zhiics zhiics deleted the fuse branch May 1, 2019 11:44
wweic pushed a commit to wweic/tvm that referenced this pull request May 13, 2019
[RELAY][FUSION] Enhance fusion rule that starts from elemwise and broadcast (apache#2932)

* [relay][bugfix] fuse injective to elemwise and broadcast

* enhance fusion for prarllel injectiveOD

* check if tensor in schedule

* fix codegen

* fix lint

* update

* lint
wweic pushed a commit to neo-ai/tvm that referenced this pull request May 13, 2019, with the same commit list as above.