Skip to content

Commit

Permalink
[transform] two-level tilings for m dimension in large k schedule (#970)
Browse files Browse the repository at this point in the history
  • Loading branch information
wyzero authored Jan 19, 2023
1 parent 9340716 commit a21c0f7
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -571,17 +571,17 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
buildFuseIntoContainingOp(b, loc, fill, forEachThreadLoop);

// first level tile size for dimension m
int64_t M0 = 8;
int64_t M0 = 288, M1 = 8;
// first/second level tile size for dimension n
int64_t N0 = 204, N1 = 12;
// first/second level tile size for dimension k
int64_t K0 = 512, K1 = 1;
// TODO(wyzero): query cpuinfo.
int64_t hardwareVectorSizeInBytes = 4;

auto tileOp0 = buildTileOp(b, loc, tiledMatmul, {0, N0, K0}, {0, 2, 1});
auto tileOp0 = buildTileOp(b, loc, tiledMatmul, {M0, N0, K0}, {0, 2, 1});
auto tileOp1 =
buildTileOp(b, loc, tileOp0->getResult(0), {M0, N1, K1}, {0, 1, 2});
buildTileOp(b, loc, tileOp0->getResult(0), {M1, N1, K1}, {0, 1, 2});

// fold extract_slice ops generated by two-level tiling. It's needed to
// enable following pad and cache_read schedule.
Expand All @@ -596,8 +596,9 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
Value padForWeight = buildGetProducerOfOperand(b, loc, padOp, 1);

// Check if we need to pad dimension `m/n/k` if input or weight is packed
bool mIsPadded =
(M0 != 1) && (M == ShapedType::kDynamicSize || M > M0 && M % M0 != 0);
bool mIsPadded = (M1 != 1) && (M == ShapedType::kDynamicSize ||
(M > M0 && (M % M0 != 0 || M0 % M1 != 0)) ||
(M <= M0 && M > M1 && M % M1 != 0));
bool nIsPadded = (N1 != 1) && (N == ShapedType::kDynamicSize ||
(N > N0 && (N % N0 != 0 || N0 % N1 != 0)) ||
(N <= N0 && N > N1 && N % N1 != 0));
Expand All @@ -606,32 +607,35 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
(K <= K0 && K > K1 && K % K1 != 0));

// Check if we need to pack the input:
bool packInput = ((M == ShapedType::kDynamicSize || M >= M0) &&
bool packInput = ((M == ShapedType::kDynamicSize || M >= M1) &&
(K == ShapedType::kDynamicSize || K > K0) &&
(N == ShapedType::kDynamicSize || N > N0));
// supposed loop order:
// loop_k0
// loop_n0
// loop_m0
// loop_n1
// loop_k1 {
// inner_most_gemm
// }
// loop_m0
// loop_k0
// loop_n0
// loop_m1
// loop_n1
// loop_k1 {
// inner_most_gemm
// }
// in case:
// - the size of dimension K <= K0, then loop_k0 will be folded.
bool m0Skipped = (M != ShapedType::kDynamicSize && M <= M0);
// - the size of dimension K <= K0, then loop_k0 will be folded.
bool k0Skipped = (K != ShapedType::kDynamicSize && K <= K0);
// - the size of dimension N <= N0, then loop_n0 will be folded.
bool n0Skipped = (N != ShapedType::kDynamicSize && N <= N0);
// - the size of dimension M <= M1, then loop_m1 will be folded.
bool m0Skipped = (M != ShapedType::kDynamicSize && M <= M0);
bool m1Skipped = (M != ShapedType::kDynamicSize && M <= M1);
// - the size of dimension N <= N1, then loop_n1 will be folded.
bool n1Skipped = (N != ShapedType::kDynamicSize && N <= N1);
// - the size of dimension K <= K1, then loop_k1 will be folded.
bool k1Skipped = (K != ShapedType::kDynamicSize && K <= K1);
if (packInput) {
// We want to cache the packed A below loop_k0 and above loop_n0.
// Thus the initial loop_level is 4.
int loopLevel = 4 - n0Skipped - m0Skipped - n1Skipped - k1Skipped;
int loopLevel = 4 - n0Skipped - m1Skipped - n1Skipped - k1Skipped;
if (loopLevel <= 0) {
return m->emitError()
<< "failed to cache the packed input due to loopLevel = "
Expand All @@ -642,10 +646,10 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
SmallVector<int64_t> tileSizes;
SmallVector<int64_t> permutation;
if (lhsTranspose) {
tileSizes = {K1, M0};
tileSizes = {K1, M1};
permutation = {2, 0, 1, 3};
} else {
tileSizes = {M0, K1};
tileSizes = {M1, K1};
permutation = {0, 2, 3, 1};
}
buildCacheRead(b, loc, padForInput, loopN0, {1, 1}, tileSizes,
Expand Down Expand Up @@ -690,15 +694,14 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
if (!k0Skipped && !k1Skipped) {
Value leftFillOp = buildMatchOp(b, loc, variant, {}, nameMap[dotOp]);
Value contractOp = buildMatchOp(b, loc, variant, {"vector.contract"});
int outterMostLoopLevel =
5 - k0Skipped - n0Skipped - m0Skipped - n1Skipped - k1Skipped;
auto loop0 = buildGetParentForOp(b, loc, contractOp, outterMostLoopLevel);
int outterMostKLoopLevel =
5 - k0Skipped - n0Skipped - m1Skipped - n1Skipped - k1Skipped;
auto loop0 = buildGetParentForOp(b, loc, contractOp, outterMostKLoopLevel);
auto loop1 = buildGetParentForOp(b, loc, contractOp, 2);
auto readers = buildMatchOp(b, loc, loop1, {"vector.transfer_read"});
auto splitedReaders = buildSplitHandlesOp(b, loc, readers, 3);
buildInlineReductionInitializerOp(b, loc, leftFillOp, loop0,
splitedReaders->getResult(0));
variant = buildDISCBufferize(b, loc, variant);
}

buildLowerVectors(b, loc, {0, 1, 2, 3, 4}, "outerproduct", "innerparallel",
Expand Down

0 comments on commit a21c0f7

Please sign in to comment.