[transform] two-level tilings for m dimension in large k schedule #970

Merged · 1 commit · Jan 19, 2023
tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc (24 additions, 21 deletions)
@@ -571,17 +571,17 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
   buildFuseIntoContainingOp(b, loc, fill, forEachThreadLoop);
 
   // first level tile size for dimension m
-  int64_t M0 = 8;
+  int64_t M0 = 288, M1 = 8;
   // first/second level tile size for dimension n
   int64_t N0 = 204, N1 = 12;
   // first/second level tile size for dimension k
   int64_t K0 = 512, K1 = 1;
   // TODO(wyzero): query cpuinfo.
   int64_t hardwareVectorSizeInBytes = 4;
 
-  auto tileOp0 = buildTileOp(b, loc, tiledMatmul, {0, N0, K0}, {0, 2, 1});
+  auto tileOp0 = buildTileOp(b, loc, tiledMatmul, {M0, N0, K0}, {0, 2, 1});
   auto tileOp1 =
-      buildTileOp(b, loc, tileOp0->getResult(0), {M0, N1, K1}, {0, 1, 2});
+      buildTileOp(b, loc, tileOp0->getResult(0), {M1, N1, K1}, {0, 1, 2});
 
   // fold extract_slice ops generated by two-level tiling. It's needed to
   // enable following pad and cache_read schedule.
@@ -596,8 +596,9 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
   Value padForWeight = buildGetProducerOfOperand(b, loc, padOp, 1);
 
   // Check if we need to pad dimension `m/n/k` if input or weight is packed
-  bool mIsPadded =
-      (M0 != 1) && (M == ShapedType::kDynamicSize || M > M0 && M % M0 != 0);
+  bool mIsPadded = (M1 != 1) && (M == ShapedType::kDynamicSize ||
+                                 (M > M0 && (M % M0 != 0 || M0 % M1 != 0)) ||
+                                 (M <= M0 && M > M1 && M % M1 != 0));
   bool nIsPadded = (N1 != 1) && (N == ShapedType::kDynamicSize ||
                                  (N > N0 && (N % N0 != 0 || N0 % N1 != 0)) ||
                                  (N <= N0 && N > N1 && N % N1 != 0));
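
The new mIsPadded predicate mirrors the existing n/k checks: with two m tile levels, a static M needs padding whenever it does not divide evenly through both levels. A minimal standalone sketch (not DISC code; it drops the dynamic-shape branch and hard-codes the schedule's M0 = 288, M1 = 8) of how the static-shape branch evaluates:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int64_t M0 = 288, M1 = 8;  // first/second level tile sizes for m
  for (int64_t M : {4, 202, 576, 1000}) {
    // Static-shape branch of the mIsPadded check from the diff above.
    bool mIsPadded = (M1 != 1) && ((M > M0 && (M % M0 != 0 || M0 % M1 != 0)) ||
                                   (M <= M0 && M > M1 && M % M1 != 0));
    std::cout << "M = " << M << " -> mIsPadded = " << std::boolalpha
              << mIsPadded << "\n";
  }
  // Expected: 4 -> false (fits in a single M1 tile), 202 -> true (202 % 8 != 0),
  // 576 -> false (multiple of M0, and M0 is a multiple of M1), 1000 -> true.
  return 0;
}
```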
@@ -606,32 +607,35 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
                                  (K <= K0 && K > K1 && K % K1 != 0));
 
   // Check if we need to pack the input:
-  bool packInput = ((M == ShapedType::kDynamicSize || M >= M0) &&
+  bool packInput = ((M == ShapedType::kDynamicSize || M >= M1) &&
                     (K == ShapedType::kDynamicSize || K > K0) &&
                     (N == ShapedType::kDynamicSize || N > N0));
   // supposed loop order:
-  //  loop_k0
-  //   loop_n0
-  //    loop_m0
-  //     loop_n1
-  //      loop_k1 {
-  //        inner_most_gemm
-  //      }
+  //  loop_m0
+  //   loop_k0
+  //    loop_n0
+  //     loop_m1
+  //      loop_n1
+  //       loop_k1 {
+  //         inner_most_gemm
+  //       }
   // in case:
   // - the size of dimension K <= K0, then loop_k0 will be folded.
-  bool m0Skipped = (M != ShapedType::kDynamicSize && M <= M0);
   // - the size of dimension K <= K0, then loop_k0 will be folded.
   bool k0Skipped = (K != ShapedType::kDynamicSize && K <= K0);
   // - the size of dimension N <= N0, then loop_n0 will be folded.
   bool n0Skipped = (N != ShapedType::kDynamicSize && N <= N0);
+  // - the size of dimension M <= M1, then loop_m1 will be folded.
+  bool m0Skipped = (M != ShapedType::kDynamicSize && M <= M0);
+  bool m1Skipped = (M != ShapedType::kDynamicSize && M <= M1);
   // - the size of dimension N <= N1, then loop_n1 will be folded.
   bool n1Skipped = (N != ShapedType::kDynamicSize && N <= N1);
   // - the size of dimension K <= K0, then loop_k0 will be folded.
   bool k1Skipped = (K != ShapedType::kDynamicSize && K <= K1);
   if (packInput) {
     // We want to cache the packed A below loop_k0 and above loop_n0.
     // Thus the initial loop_level is 4.
-    int loopLevel = 4 - n0Skipped - m0Skipped - n1Skipped - k1Skipped;
+    int loopLevel = 4 - n0Skipped - m1Skipped - n1Skipped - k1Skipped;
     if (loopLevel <= 0) {
       return m->emitError()
              << "failed to cache the packed input due to loopLevel = "
@@ -642,10 +646,10 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
     SmallVector<int64_t> tileSizes;
     SmallVector<int64_t> permutation;
     if (lhsTranspose) {
-      tileSizes = {K1, M0};
+      tileSizes = {K1, M1};
       permutation = {2, 0, 1, 3};
     } else {
-      tileSizes = {M0, K1};
+      tileSizes = {M1, K1};
       permutation = {0, 2, 3, 1};
     }
     buildCacheRead(b, loc, padForInput, loopN0, {1, 1}, tileSizes,
@@ -690,15 +694,14 @@ LogicalResult Aarch64GEMMLargeKScheduleFactory::assignSchedule(
   if (!k0Skipped && !k1Skipped) {
     Value leftFillOp = buildMatchOp(b, loc, variant, {}, nameMap[dotOp]);
     Value contractOp = buildMatchOp(b, loc, variant, {"vector.contract"});
-    int outterMostLoopLevel =
-        5 - k0Skipped - n0Skipped - m0Skipped - n1Skipped - k1Skipped;
-    auto loop0 = buildGetParentForOp(b, loc, contractOp, outterMostLoopLevel);
+    int outterMostKLoopLevel =
+        5 - k0Skipped - n0Skipped - m1Skipped - n1Skipped - k1Skipped;
+    auto loop0 = buildGetParentForOp(b, loc, contractOp, outterMostKLoopLevel);
     auto loop1 = buildGetParentForOp(b, loc, contractOp, 2);
     auto readers = buildMatchOp(b, loc, loop1, {"vector.transfer_read"});
     auto splitedReaders = buildSplitHandlesOp(b, loc, readers, 3);
     buildInlineReductionInitializerOp(b, loc, leftFillOp, loop0,
                                       splitedReaders->getResult(0));
-    variant = buildDISCBufferize(b, loc, variant);
   }
 
   buildLowerVectors(b, loc, {0, 1, 2, 3, 4}, "outerproduct", "innerparallel",
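
Putting the pieces together, below is a minimal standalone C++ sketch of the loop nest this schedule aims to produce after the change: two tiling levels for m (M0 = 288, M1 = 8) wrapped around the existing n/k levels, in the loop order documented in the diff. It is only an illustration with the plain GEMM written out; it is not DISC code and it ignores packing, padding, and vectorization.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// C (M x N) += A (M x K) * B (K x N), tiled in the order
// loop_m0 > loop_k0 > loop_n0 > loop_m1 > loop_n1 > loop_k1 > inner_most_gemm.
void tiled_gemm(const std::vector<float>& A, const std::vector<float>& B,
                std::vector<float>& C, int64_t M, int64_t N, int64_t K) {
  const int64_t M0 = 288, M1 = 8;   // first/second level tile sizes for m (this PR)
  const int64_t N0 = 204, N1 = 12;  // first/second level tile sizes for n
  const int64_t K0 = 512, K1 = 1;   // first/second level tile sizes for k
  for (int64_t m0 = 0; m0 < M; m0 += M0)                                 // loop_m0
    for (int64_t k0 = 0; k0 < K; k0 += K0)                               // loop_k0
      for (int64_t n0 = 0; n0 < N; n0 += N0)                             // loop_n0
        for (int64_t m1 = m0; m1 < std::min(m0 + M0, M); m1 += M1)       // loop_m1
          for (int64_t n1 = n0; n1 < std::min(n0 + N0, N); n1 += N1)     // loop_n1
            for (int64_t k1 = k0; k1 < std::min(k0 + K0, K); k1 += K1)   // loop_k1
              // inner_most_gemm: an (M1 x N1 x K1) micro-kernel tile.
              for (int64_t i = m1; i < std::min({m1 + M1, m0 + M0, M}); ++i)
                for (int64_t j = n1; j < std::min({n1 + N1, n0 + N0, N}); ++j)
                  for (int64_t k = k1; k < std::min({k1 + K1, k0 + K0, K}); ++k)
                    C[i * N + j] += A[i * K + k] * B[k * N + j];
}
```

Note that with K1 = 1 the innermost k tile covers a single element, so loop_k1 in effect steps through k one element at a time inside each K0 block.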