diff --git a/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc
index 464983fb3c5..6ea8b7bab21 100644
--- a/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc
@@ -571,7 +571,7 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
   buildFuseIntoContainingOp(b, loc, fill, forEachThreadLoop);
 
   // first level tile size for dimension m
-  int64_t M0 = 8;
+  int64_t M0 = 288, M1 = 8;
   // first/second level tile size for dimension n
   int64_t N0 = 204, N1 = 12;
   // first/second level tile size for dimension k
@@ -579,9 +579,9 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
   // TODO(wyzero): query cpuinfo.
   int64_t hardwareVectorSizeInBytes = 4;
 
-  auto tileOp0 = buildTileOp(b, loc, tiledMatmul, {0, N0, K0}, {0, 2, 1});
+  auto tileOp0 = buildTileOp(b, loc, tiledMatmul, {M0, N0, K0}, {0, 2, 1});
   auto tileOp1 =
-      buildTileOp(b, loc, tileOp0->getResult(0), {M0, N1, K1}, {0, 1, 2});
+      buildTileOp(b, loc, tileOp0->getResult(0), {M1, N1, K1}, {0, 1, 2});
 
   // fold extract_slice ops generated by two-level tiling. It's needed to
   // enable following pad and cache_read schedule.
@@ -596,8 +596,9 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
   Value padForWeight = buildGetProducerOfOperand(b, loc, padOp, 1);
 
   // Check if we need to pad dimension `m/n/k` if input or weight is packed
-  bool mIsPadded =
-      (M0 != 1) && (M == ShapedType::kDynamicSize || M > M0 && M % M0 != 0);
+  bool mIsPadded = (M1 != 1) && (M == ShapedType::kDynamicSize ||
+                                 (M > M0 && (M % M0 != 0 || M0 % M1 != 0)) ||
+                                 (M <= M0 && M > M1 && M % M1 != 0));
   bool nIsPadded = (N1 != 1) && (N == ShapedType::kDynamicSize ||
                                  (N > N0 && (N % N0 != 0 || N0 % N1 != 0)) ||
                                  (N <= N0 && N > N1 && N % N1 != 0));
@@ -606,24 +607,27 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
                                  (K <= K0 && K > K1 && K % K1 != 0));
 
   // Check if we need to pack the input:
-  bool packInput = ((M == ShapedType::kDynamicSize || M >= M0) &&
+  bool packInput = ((M == ShapedType::kDynamicSize || M >= M1) &&
                     (K == ShapedType::kDynamicSize || K > K0) &&
                     (N == ShapedType::kDynamicSize || N > N0));
   // supposed loop order:
-  //  loop_k0
-  //   loop_n0
-  //    loop_m0
-  //     loop_n1
-  //      loop_k1 {
-  //        inner_most_gemm
-  //      }
+  //  loop_m0
+  //   loop_k0
+  //    loop_n0
+  //     loop_m1
+  //      loop_n1
+  //       loop_k1 {
+  //         inner_most_gemm
+  //       }
   // in case:
+  // - the size of dimension M <= M0, then loop_m0 will be folded.
+  bool m0Skipped = (M != ShapedType::kDynamicSize && M <= M0);
   // - the size of dimension K <= K0, then loop_k0 will be folded.
   bool k0Skipped = (K != ShapedType::kDynamicSize && K <= K0);
   // - the size of dimension N <= N0, then loop_n0 will be folded.
   bool n0Skipped = (N != ShapedType::kDynamicSize && N <= N0);
-  // - the size of dimension M <= M0, then loop_m0 will be folded.
-  bool m0Skipped = (M != ShapedType::kDynamicSize && M <= M0);
+  // - the size of dimension M <= M1, then loop_m1 will be folded.
+  bool m1Skipped = (M != ShapedType::kDynamicSize && M <= M1);
   // - the size of dimension N <= N1, then loop_n1 will be folded.
   bool n1Skipped = (N != ShapedType::kDynamicSize && N <= N1);
   // - the size of dimension K <= K1, then loop_k1 will be folded.
@@ -631,7 +635,7 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
   if (packInput) {
     // We want to cache the packed A below loop_k0 and above loop_n0.
     // Thus the initial loop_level is 4.
-    int loopLevel = 4 - n0Skipped - m0Skipped - n1Skipped - k1Skipped;
+    int loopLevel = 4 - n0Skipped - m1Skipped - n1Skipped - k1Skipped;
     if (loopLevel <= 0) {
       return m->emitError()
              << "failed to cache the packed input due to loopLevel = "
@@ -642,10 +646,10 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
     SmallVector<int64_t> tileSizes;
     SmallVector<int64_t> permutation;
     if (lhsTranspose) {
-      tileSizes = {K1, M0};
+      tileSizes = {K1, M1};
       permutation = {2, 0, 1, 3};
     } else {
-      tileSizes = {M0, K1};
+      tileSizes = {M1, K1};
       permutation = {0, 2, 3, 1};
     }
     buildCacheRead(b, loc, padForInput, loopN0, {1, 1}, tileSizes,
@@ -690,15 +694,14 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
   if (!k0Skipped && !k1Skipped) {
     Value leftFillOp = buildMatchOp(b, loc, variant, {}, nameMap[dotOp]);
     Value contractOp = buildMatchOp(b, loc, variant, {"vector.contract"});
-    int outterMostLoopLevel =
-        5 - k0Skipped - n0Skipped - m0Skipped - n1Skipped - k1Skipped;
-    auto loop0 = buildGetParentForOp(b, loc, contractOp, outterMostLoopLevel);
+    int outterMostKLoopLevel =
+        5 - k0Skipped - n0Skipped - m1Skipped - n1Skipped - k1Skipped;
+    auto loop0 = buildGetParentForOp(b, loc, contractOp, outterMostKLoopLevel);
     auto loop1 = buildGetParentForOp(b, loc, contractOp, 2);
     auto readers = buildMatchOp(b, loc, loop1, {"vector.transfer_read"});
     auto splitedReaders = buildSplitHandlesOp(b, loc, readers, 3);
     buildInlineReductionInitializerOp(b, loc, leftFillOp, loop0,
                                       splitedReaders->getResult(0));
-    variant = buildDISCBufferize(b, loc, variant);
   }
 
   buildLowerVectors(b, loc, {0, 1, 2, 3, 4}, "outerproduct", "innerparallel",
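
Note (not part of the patch): the standalone C++ sketch below mirrors the new m padding predicate, the loop-folding flags, and the two loop-level computations introduced by this change, so the M0/M1 split can be sanity-checked without building the compiler. kDynamic stands in for ShapedType::kDynamicSize, and the K0/K1 values are placeholders because their definitions sit outside the hunks shown above; M0/M1/N0/N1 copy the tile sizes from this patch.

#include <cstdint>
#include <cstdio>

// Stand-in for ShapedType::kDynamicSize; any sentinel works for this sketch.
constexpr int64_t kDynamic = -1;

// Tile sizes as set in this patch; K0/K1 are placeholders only.
constexpr int64_t M0 = 288, M1 = 8;
constexpr int64_t N0 = 204, N1 = 12;
constexpr int64_t K0 = 512, K1 = 1;  // placeholder values

void report(int64_t M, int64_t N, int64_t K) {
  // Same shape as the new mIsPadded check: pad m unless each tile level
  // divides evenly or the dimension is small enough to skip that level.
  bool mIsPadded = (M1 != 1) && (M == kDynamic ||
                                 (M > M0 && (M % M0 != 0 || M0 % M1 != 0)) ||
                                 (M <= M0 && M > M1 && M % M1 != 0));
  // Loop-folding flags, one per level of the new loop order.
  bool m0Skipped = (M != kDynamic && M <= M0);
  bool k0Skipped = (K != kDynamic && K <= K0);
  bool n0Skipped = (N != kDynamic && N <= N0);
  bool m1Skipped = (M != kDynamic && M <= M1);
  bool n1Skipped = (N != kDynamic && N <= N1);
  bool k1Skipped = (K != kDynamic && K <= K1);
  // Loop level used to place the cached packed input, as in the patch.
  int loopLevel = 4 - n0Skipped - m1Skipped - n1Skipped - k1Skipped;
  // Mirrors outterMostKLoopLevel in the patch (outermost k parent loop).
  int outerKLoopLevel =
      5 - k0Skipped - n0Skipped - m1Skipped - n1Skipped - k1Skipped;
  std::printf("M=%lld N=%lld K=%lld -> mIsPadded=%d m0Skipped=%d "
              "loopLevel=%d outerKLoopLevel=%d\n",
              (long long)M, (long long)N, (long long)K, (int)mIsPadded,
              (int)m0Skipped, loopLevel, outerKLoopLevel);
}

int main() {
  report(304, 1024, 4096);  // M > M0 and M % M0 != 0 -> m is padded
  report(288, 1024, 4096);  // M == M0 -> loop_m0 folded, no m padding
  report(6, 1024, 4096);    // M <= M1 -> loop_m1 folded too, levels drop
  return 0;
}

With the placeholder K0/K1 the three calls show how the cache level for the packed input shrinks as m loops get folded; swapping in the real K0/K1 from the file reproduces the exact values used by the schedule.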