-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
transformations: (riscv_scf) add a pass to fuse perfectly nested loops (
- Loading branch information
1 parent
4a3a11d
commit ab2bc78
Showing
3 changed files
with
248 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
// RUN: xdsl-opt -p riscv-scf-loop-fusion %s | filecheck %s | ||
|
||
// CHECK: builtin.module { | ||
|
||
// Success case | ||
%c0 = riscv.li 0 : () -> !riscv.reg<> | ||
%c1 = riscv.li 1 : () -> !riscv.reg<> | ||
%c8 = riscv.li 8 : () -> !riscv.reg<> | ||
%c64 = riscv.li 64 : () -> !riscv.reg<> | ||
|
||
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 { | ||
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 { | ||
%18 = riscv.li 8 : () -> !riscv.reg<> | ||
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
"test.op"(%19) : (!riscv.reg<>) -> () | ||
} | ||
} | ||
|
||
// CHECK-NEXT: %c0 = riscv.li 0 : () -> !riscv.reg<> | ||
// CHECK-NEXT: %c1 = riscv.li 1 : () -> !riscv.reg<> | ||
// CHECK-NEXT: %c8 = riscv.li 8 : () -> !riscv.reg<> | ||
// CHECK-NEXT: %c64 = riscv.li 64 : () -> !riscv.reg<> | ||
// CHECK-NEXT: riscv_scf.for %0 : !riscv.reg<> = %c0 to %c64 step %c1 { | ||
// CHECK-NEXT: %1 = riscv.li 8 : () -> !riscv.reg<> | ||
// CHECK-NEXT: "test.op"(%0) : (!riscv.reg<>) -> () | ||
// CHECK-NEXT: } | ||
|
||
// Cannot fuse outer loop with iteration arguments | ||
%res0 = riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 iter_args(%arg0 = %c0) -> (!riscv.reg<>) { | ||
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 { | ||
%18 = riscv.li 8 : () -> !riscv.reg<> | ||
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
"test.op"(%19) : (!riscv.reg<>) -> () | ||
} | ||
riscv_scf.yield %arg0 : !riscv.reg<> | ||
} | ||
|
||
// CHECK-NEXT: %{{.*}} = riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (!riscv.reg<>) { | ||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<> | ||
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> () | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: riscv_scf.yield %{{.*}} : !riscv.reg<> | ||
// CHECK-NEXT: } | ||
|
||
// Inner loop must be the only operation in the outer loop, aside from yield | ||
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 { | ||
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 { | ||
%18 = riscv.li 8 : () -> !riscv.reg<> | ||
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
"test.op"(%19) : (!riscv.reg<>) -> () | ||
} | ||
%20 = riscv.li 42 : () -> !riscv.reg<> | ||
} | ||
|
||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<> | ||
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> () | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: %{{.*}} = riscv.li 42 : () -> !riscv.reg<> | ||
// CHECK-NEXT: } | ||
|
||
// Cannot fuse inner loop with iteration arguments | ||
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 { | ||
%res1 = riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 iter_args(%arg1 = %c0) -> (!riscv.reg<>) { | ||
%18 = riscv.li 8 : () -> !riscv.reg<> | ||
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
"test.op"(%19) : (!riscv.reg<>) -> () | ||
riscv_scf.yield %arg1 : !riscv.reg<> | ||
} | ||
} | ||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: %{{.*}} = riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (!riscv.reg<>) { | ||
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<> | ||
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> () | ||
// CHECK-NEXT: riscv_scf.yield %{{.*}} : !riscv.reg<> | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// Cannot fuse inner loop with non-zero lb | ||
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 { | ||
riscv_scf.for %17 : !riscv.reg<> = %c8 to %c8 step %c1 { | ||
%18 = riscv.li 8 : () -> !riscv.reg<> | ||
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
"test.op"(%19) : (!riscv.reg<>) -> () | ||
} | ||
} | ||
|
||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<> | ||
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> () | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
|
||
// Each iter arg must only be used once, in an add | ||
|
||
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 { | ||
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 { | ||
%18 = riscv.li 8 : () -> !riscv.reg<> | ||
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
"test.op"(%19, %16) : (!riscv.reg<>, !riscv.reg<>) -> () | ||
} | ||
} | ||
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 { | ||
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 { | ||
%18 = riscv.li 8 : () -> !riscv.reg<> | ||
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
"test.op"(%19, %17) : (!riscv.reg<>, !riscv.reg<>) -> () | ||
} | ||
} | ||
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 { | ||
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 { | ||
%18 = riscv.li 8 : () -> !riscv.reg<> | ||
%19 = riscv.mul %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
"test.op"(%19) : (!riscv.reg<>) -> () | ||
} | ||
} | ||
|
||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<> | ||
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
// CHECK-NEXT: "test.op"(%{{.*}}, %{{.*}}) : (!riscv.reg<>, !riscv.reg<>) -> () | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<> | ||
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
// CHECK-NEXT: "test.op"(%{{.*}}, %{{.*}}) : (!riscv.reg<>, !riscv.reg<>) -> () | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} { | ||
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<> | ||
// CHECK-NEXT: %{{.*}} = riscv.mul %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> | ||
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> () | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK-NEXT: } | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from typing import cast | ||
|
||
from xdsl.dialects import builtin, riscv, riscv_scf | ||
from xdsl.ir import MLContext | ||
from xdsl.passes import ModulePass | ||
from xdsl.pattern_rewriter import ( | ||
PatternRewriter, | ||
PatternRewriteWalker, | ||
RewritePattern, | ||
op_type_rewrite_pattern, | ||
) | ||
from xdsl.transforms.canonicalization_patterns.riscv import get_constant_value | ||
|
||
|
||
class FuseNestedLoopsPattern(RewritePattern): | ||
@op_type_rewrite_pattern | ||
def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None: | ||
if op.iter_args: | ||
return | ||
|
||
outer_body = op.body.block | ||
if not isinstance(inner_loop := outer_body.first_op, riscv_scf.ForOp): | ||
# Outer loop must contain inner loop | ||
return | ||
if inner_loop is not cast(riscv_scf.YieldOp, outer_body.last_op).prev_op: | ||
# Outer loop must contain only inner loop and yield | ||
return | ||
if inner_loop.iter_args: | ||
return | ||
|
||
if (inner_lb := get_constant_value(inner_loop.lb)) is None: | ||
return | ||
if inner_lb.value.data != 0: | ||
return | ||
|
||
if (inner_ub := get_constant_value(inner_loop.ub)) is None: | ||
return | ||
if (outer_step := get_constant_value(op.step)) is None: | ||
return | ||
if inner_ub != outer_step: | ||
return | ||
|
||
outer_index = outer_body.args[0] | ||
inner_index = inner_loop.body.block.args[0] | ||
|
||
if len(outer_index.uses) != 1 or len(inner_index.uses) != 1: | ||
# If the induction variable is used more than once, we can't fold it | ||
return | ||
|
||
outer_user = next(iter(outer_index.uses)).operation | ||
inner_user = next(iter(inner_index.uses)).operation | ||
if outer_user is not inner_user: | ||
return | ||
|
||
user = outer_user | ||
|
||
if not isinstance(user, riscv.AddOp): | ||
return | ||
|
||
# We can fuse | ||
user.rd.replace_by(inner_index) | ||
rewriter.erase_op(user) | ||
moved_region = rewriter.move_region_contents_to_new_regions(inner_loop.body) | ||
rewriter.erase_op(inner_loop) | ||
|
||
rewriter.replace_matched_op( | ||
riscv_scf.ForOp( | ||
op.lb, | ||
op.ub, | ||
inner_loop.step, | ||
(), | ||
moved_region, | ||
) | ||
) | ||
|
||
|
||
class RiscvScfLoopFusionPass(ModulePass): | ||
""" | ||
Folds perfect loop nests if they can be represented with a single loop. | ||
Currently does this by matching the inner loop range with the outer loop step. | ||
If the inner iteration space fits perfectly in the outer iteration step, then merge. | ||
Other conditions: | ||
- the only use of the induction arguments must be an add operation, this op is fused | ||
into a single induction argument, | ||
- the lower bound of the inner loop must be 0, | ||
- the loops must have no iteration arguments. | ||
""" | ||
|
||
name = "riscv-scf-loop-fusion" | ||
|
||
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: | ||
PatternRewriteWalker(FuseNestedLoopsPattern()).rewrite_module(op) |