Skip to content

Commit

Permalink
transformations: (riscv_scf) add a pass to fuse perfectly nested loops (
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh authored May 6, 2024
1 parent 4a3a11d commit ab2bc78
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 0 deletions.
150 changes: 150 additions & 0 deletions tests/filecheck/dialects/riscv_scf/loop_fusion.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// RUN: xdsl-opt -p riscv-scf-loop-fusion %s | filecheck %s

// CHECK: builtin.module {

// Success case
%c0 = riscv.li 0 : () -> !riscv.reg<>
%c1 = riscv.li 1 : () -> !riscv.reg<>
%c8 = riscv.li 8 : () -> !riscv.reg<>
%c64 = riscv.li 64 : () -> !riscv.reg<>

riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
}
}

// CHECK-NEXT: %c0 = riscv.li 0 : () -> !riscv.reg<>
// CHECK-NEXT: %c1 = riscv.li 1 : () -> !riscv.reg<>
// CHECK-NEXT: %c8 = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %c64 = riscv.li 64 : () -> !riscv.reg<>
// CHECK-NEXT: riscv_scf.for %0 : !riscv.reg<> = %c0 to %c64 step %c1 {
// CHECK-NEXT: %1 = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%0) : (!riscv.reg<>) -> ()
// CHECK-NEXT: }

// Cannot fuse outer loop with iteration arguments
%res0 = riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 iter_args(%arg0 = %c0) -> (!riscv.reg<>) {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
}
riscv_scf.yield %arg0 : !riscv.reg<>
}

// CHECK-NEXT: %{{.*}} = riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (!riscv.reg<>) {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: riscv_scf.yield %{{.*}} : !riscv.reg<>
// CHECK-NEXT: }

// Inner loop must be the only operation in the outer loop, aside from yield
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
}
%20 = riscv.li 42 : () -> !riscv.reg<>
}

// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: %{{.*}} = riscv.li 42 : () -> !riscv.reg<>
// CHECK-NEXT: }

// Cannot fuse inner loop with iteration arguments
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
%res1 = riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 iter_args(%arg1 = %c0) -> (!riscv.reg<>) {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
riscv_scf.yield %arg1 : !riscv.reg<>
}
}
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (!riscv.reg<>) {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> ()
// CHECK-NEXT: riscv_scf.yield %{{.*}} : !riscv.reg<>
// CHECK-NEXT: }
// CHECK-NEXT: }

// Cannot fuse inner loop with non-zero lb
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c8 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
}
}

// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }


// Each iter arg must only be used once, in an add

riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19, %16) : (!riscv.reg<>, !riscv.reg<>) -> ()
}
}
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19, %17) : (!riscv.reg<>, !riscv.reg<>) -> ()
}
}
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.mul %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
}
}

// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}, %{{.*}}) : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}, %{{.*}}) : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.mul %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK-NEXT: }


6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,11 @@ def get_riscv_register_allocation():

return riscv_register_allocation.RISCVRegisterAllocation

def get_riscv_scf_loop_fusion():
from xdsl.transforms import riscv_scf_loop_fusion

return riscv_scf_loop_fusion.RiscvScfLoopFusionPass

def get_riscv_scf_loop_range_folding():
from xdsl.transforms import riscv_scf_loop_range_folding

Expand Down Expand Up @@ -603,6 +608,7 @@ def get_test_lower_snitch_stream_to_asm():
"replace-incompatible-fpga": get_replace_incompatible_fpga,
"riscv-allocate-registers": get_riscv_register_allocation,
"riscv-cse": get_riscv_cse,
"riscv-scf-loop-fusion": get_riscv_scf_loop_fusion,
"riscv-scf-loop-range-folding": get_riscv_scf_loop_range_folding,
"scf-parallel-loop-tiling": get_scf_parallel_loop_tiling,
"snitch-allocate-registers": get_snitch_register_allocation,
Expand Down
92 changes: 92 additions & 0 deletions xdsl/transforms/riscv_scf_loop_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import cast

from xdsl.dialects import builtin, riscv, riscv_scf
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.transforms.canonicalization_patterns.riscv import get_constant_value


class FuseNestedLoopsPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None:
if op.iter_args:
return

outer_body = op.body.block
if not isinstance(inner_loop := outer_body.first_op, riscv_scf.ForOp):
# Outer loop must contain inner loop
return
if inner_loop is not cast(riscv_scf.YieldOp, outer_body.last_op).prev_op:
# Outer loop must contain only inner loop and yield
return
if inner_loop.iter_args:
return

if (inner_lb := get_constant_value(inner_loop.lb)) is None:
return
if inner_lb.value.data != 0:
return

if (inner_ub := get_constant_value(inner_loop.ub)) is None:
return
if (outer_step := get_constant_value(op.step)) is None:
return
if inner_ub != outer_step:
return

outer_index = outer_body.args[0]
inner_index = inner_loop.body.block.args[0]

if len(outer_index.uses) != 1 or len(inner_index.uses) != 1:
# If the induction variable is used more than once, we can't fold it
return

outer_user = next(iter(outer_index.uses)).operation
inner_user = next(iter(inner_index.uses)).operation
if outer_user is not inner_user:
return

user = outer_user

if not isinstance(user, riscv.AddOp):
return

# We can fuse
user.rd.replace_by(inner_index)
rewriter.erase_op(user)
moved_region = rewriter.move_region_contents_to_new_regions(inner_loop.body)
rewriter.erase_op(inner_loop)

rewriter.replace_matched_op(
riscv_scf.ForOp(
op.lb,
op.ub,
inner_loop.step,
(),
moved_region,
)
)


class RiscvScfLoopFusionPass(ModulePass):
"""
Folds perfect loop nests if they can be represented with a single loop.
Currently does this by matching the inner loop range with the outer loop step.
If the inner iteration space fits perfectly in the outer iteration step, then merge.
Other conditions:
- the only use of the induction arguments must be an add operation, this op is fused
into a single induction argument,
- the lower bound of the inner loop must be 0,
- the loops must have no iteration arguments.
"""

name = "riscv-scf-loop-fusion"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(FuseNestedLoopsPattern()).rewrite_module(op)

0 comments on commit ab2bc78

Please sign in to comment.