-
-
Notifications
You must be signed in to change notification settings - Fork 791
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[venom]: add loop invariant hoisting pass #4175
base: master
Are you sure you want to change the base?
Changes from all commits
13944eb
f626b68
942c731
399a0a4
2a8dd4a
302aa21
013840c
1ee0068
f703408
c5dbf05
cd655fc
ecb272a
18c7610
5583cc5
bdb2896
788bd0d
715b128
8edce11
cf6b25e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
import pytest | ||
|
||
from vyper.venom.analysis.analysis import IRAnalysesCache | ||
from vyper.venom.analysis.loop_detection import NaturalLoopDetectionAnalysis | ||
from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable | ||
from vyper.venom.context import IRContext | ||
from vyper.venom.function import IRFunction | ||
from vyper.venom.passes.loop_invariant_hosting import LoopInvariantHoisting | ||
|
||
|
||
def _create_loops(fn, depth, loop_id, body_fn=lambda *_: (), top=True):
    """
    Append `depth` nested loops to `fn`.

    Each nesting level adds a `cond` block (loop header), a `body` block and
    an exit block. The exit block of the outermost level is labelled
    `exit_top...`, the inner ones `exit...`. Labels and variable names are
    suffixed with `{loop_id}{depth}`.

    `body_fn(fn, loop_id, depth)` is called once per level to fill the loop
    body. The default accepts any arguments and does nothing.
    """
    bb = fn.get_basic_block()
    cond = IRBasicBlock(IRLabel(f"cond{loop_id}{depth}"), fn)
    body = IRBasicBlock(IRLabel(f"body{loop_id}{depth}"), fn)
    exit_prefix = "exit_top" if top else "exit"
    exit_block = IRBasicBlock(IRLabel(f"{exit_prefix}{loop_id}{depth}"), fn)
    fn.append_basic_block(cond)
    fn.append_basic_block(body)

    # enter the loop from the block that was current before the new blocks
    bb.append_instruction("jmp", cond.label)

    cond_var = IRVariable(f"cond_var{loop_id}{depth}")
    cond.append_instruction("iszero", 0, ret=cond_var)
    cond.append_instruction("jnz", cond_var, body.label, exit_block.label)
    body_fn(fn, loop_id, depth)
    if depth > 1:
        _create_loops(fn, depth - 1, loop_id, body_fn, top=False)

    # close the loop: jump back to the header from the block that is current
    # once the body (and any nested loops) have been emitted
    bb = fn.get_basic_block()
    bb.append_instruction("jmp", cond.label)
    # the exit block is appended last so that later code continues from it
    fn.append_basic_block(exit_block)
|
||
|
||
def _simple_body(fn, loop_id, depth):
    """Loop body holding a single `add` of two literals into `add_var{loop_id}{depth}`."""
    assert isinstance(fn, IRFunction)
    result_var = IRVariable(f"add_var{loop_id}{depth}")
    fn.get_basic_block().append_instruction("add", 1, 2, ret=result_var)
|
||
|
||
def _hoistable_body(fn, loop_id, depth):
    """
    Loop body with a chain of dependent instructions: a literal `store`,
    an `add` that uses the stored value, and a second `add` that uses both
    previous results.
    """
    assert isinstance(fn, IRFunction)
    suffix = f"{loop_id}{depth}"
    stored = IRVariable(f"store_var{suffix}")
    first_sum = IRVariable(f"add_var_a{suffix}")
    second_sum = IRVariable(f"add_var_b{suffix}")

    block = fn.get_basic_block()
    block.append_instruction("store", 1, ret=stored)
    block.append_instruction("add", 1, stored, ret=first_sum)
    block.append_instruction("add", stored, first_sum, ret=second_sum)
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4))
@pytest.mark.parametrize("count", range(1, 4))
def test_loop_detection_analysis(depth, count):
    """`count` sequential loop nests of `depth` levels must yield `depth * count` loops."""
    fn = IRContext().create_function("_global")

    for loop_id in range(count):
        _create_loops(fn, depth, loop_id, _simple_body)

    fn.get_basic_block().append_instruction("ret")

    detection = IRAnalysesCache(fn).request_analysis(NaturalLoopDetectionAnalysis)
    assert len(detection.loops) == depth * count
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4))
@pytest.mark.parametrize("count", range(1, 4))
def test_loop_invariant_hoisting_simple(depth, count):
    """
    After the pass, every `add_var` and `cond_var` must be assigned in the
    entry block or in one of the `exit_top` blocks.
    """
    fn = IRContext().create_function("_global")

    for loop_id in range(count):
        _create_loops(fn, depth, loop_id, _simple_body)

    fn.get_basic_block().append_instruction("ret")

    LoopInvariantHoisting(IRAnalysesCache(fn), fn).run_pass()

    # blocks expected to receive the hoisted code: the entry block first,
    # then every top-level exit block in function order
    landing_blocks = [fn.entry]
    landing_blocks += [
        bb for bb in fn.get_basic_blocks() if bb.label.name.startswith("exit_top")
    ]
    assignments = [var.value for bb in landing_blocks for var in bb.get_assignments()]

    # two variables per loop level: add_var and cond_var
    assert len(assignments) == depth * count * 2
    for loop_id in range(count):
        for d in range(1, depth + 1):
            assert f"%add_var{loop_id}{d}" in assignments, repr(fn)
            assert f"%cond_var{loop_id}{d}" in assignments, repr(fn)
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4))
@pytest.mark.parametrize("count", range(1, 4))
def test_loop_invariant_hoisting_dependant(depth, count):
    """
    Instructions that depend on each other (store -> add -> add) must all
    end up assigned in the entry block or in one of the `exit_top` blocks.
    """
    fn = IRContext().create_function("_global")

    for loop_id in range(count):
        _create_loops(fn, depth, loop_id, _hoistable_body)

    fn.get_basic_block().append_instruction("ret")

    LoopInvariantHoisting(IRAnalysesCache(fn), fn).run_pass()

    # blocks expected to receive the hoisted code: the entry block first,
    # then every top-level exit block in function order
    landing_blocks = [fn.entry]
    landing_blocks += [
        bb for bb in fn.get_basic_blocks() if bb.label.name.startswith("exit_top")
    ]
    assignments = [var.value for bb in landing_blocks for var in bb.get_assignments()]

    # four variables per loop level: store_var, add_var_a, add_var_b, cond_var
    assert len(assignments) == depth * count * 4
    for loop_id in range(count):
        for d in range(1, depth + 1):
            assert f"%store_var{loop_id}{d}" in assignments, repr(fn)
            assert f"%add_var_a{loop_id}{d}" in assignments, repr(fn)
            assert f"%add_var_b{loop_id}{d}" in assignments, repr(fn)
            assert f"%cond_var{loop_id}{d}" in assignments, repr(fn)
|
||
|
||
def _unhoistable_body(fn, loop_id, depth):
    """Loop body made of an `mload` followed by an `add` that uses the loaded value."""
    assert isinstance(fn, IRFunction)
    suffix = f"{loop_id}{depth}"
    loaded = IRVariable(f"add_var_a{suffix}")
    summed = IRVariable(f"add_var_b{suffix}")

    block = fn.get_basic_block()
    block.append_instruction("mload", 64, ret=loaded)
    block.append_instruction("add", loaded, 2, ret=summed)
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4))
@pytest.mark.parametrize("count", range(1, 4))
def test_loop_invariant_hoisting_unhoistable(depth, count):
    """
    With an `mload`-based body only the `cond_var` assignments are expected
    in the entry block and the `exit_top` blocks after the pass.
    """
    fn = IRContext().create_function("_global")

    for loop_id in range(count):
        _create_loops(fn, depth, loop_id, _unhoistable_body)

    fn.get_basic_block().append_instruction("ret")

    LoopInvariantHoisting(IRAnalysesCache(fn), fn).run_pass()

    # blocks expected to receive the hoisted code: the entry block first,
    # then every top-level exit block in function order
    landing_blocks = [fn.entry]
    landing_blocks += [
        bb for bb in fn.get_basic_blocks() if bb.label.name.startswith("exit_top")
    ]
    assignments = [var.value for bb in landing_blocks for var in bb.get_assignments()]

    # one variable per loop level: cond_var
    assert len(assignments) == depth * count
    for loop_id in range(count):
        for d in range(1, depth + 1):
            assert f"%cond_var{loop_id}{d}" in assignments, repr(fn)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from vyper.utils import OrderedSet | ||
from vyper.venom.analysis.analysis import IRAnalysis | ||
from vyper.venom.analysis.cfg import CFGAnalysis | ||
from vyper.venom.basicblock import IRBasicBlock | ||
|
||
|
||
class NaturalLoopDetectionAnalysis(IRAnalysis):
    """
    Detects the natural loops of a function.

    For every loop it computes the basic blocks the loop contains
    and the block which is before the loop.
    """

    # key = the block before the loop: the first predecessor of the loop
    #       header that is not itself part of the loop
    # value = all the blocks that the loop contains (header included)
    loops: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]

    def analyze(self):
        # the traversals below read cfg_in / cfg_out, which the CFG analysis
        # is expected to provide
        self.analyses_cache.request_analysis(CFGAnalysis)
        self.loops = self._find_natural_loops(self.function.entry)

    # Could possibly reuse the dominator tree algorithm to find the back edges
    # if it is already cached it will be faster. Still might need to separate the
    # various extra information that the dominator analysis provides
    # (like frontiers and immediate dominators)
    def _find_back_edges(self, entry: IRBasicBlock) -> list[tuple[IRBasicBlock, IRBasicBlock]]:
        """
        Depth-first walk over `cfg_out` starting at `entry`.

        Returns every edge `(bb, succ)` whose target `succ` was already
        visited and is still on the current DFS path.
        """
        back_edges: list[tuple[IRBasicBlock, IRBasicBlock]] = []
        visited: OrderedSet[IRBasicBlock] = OrderedSet()
        # blocks on the current DFS path; a set keeps the membership test
        # cheap compared to scanning a list
        on_path: set[IRBasicBlock] = set()

        def dfs(bb: IRBasicBlock):
            visited.add(bb)
            on_path.add(bb)

            for succ in bb.cfg_out:
                if succ not in visited:
                    dfs(succ)
                elif succ in on_path:
                    back_edges.append((bb, succ))

            on_path.discard(bb)

        dfs(entry)

        return back_edges

    def _find_natural_loops(
        self, entry: IRBasicBlock
    ) -> dict[IRBasicBlock, OrderedSet[IRBasicBlock]]:
        """
        Builds the `loops` mapping from the back edges reachable from `entry`.

        Back edges that share a header are merged into a single loop.
        """
        # header -> blocks of the loop, merged over all back edges to it
        by_header: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] = {}

        for u, v in self._find_back_edges(entry):
            # back edge: u -> v (v is the loop header)
            loop: OrderedSet[IRBasicBlock] = OrderedSet()
            # a self-loop (u == v) consists of the header alone; walking the
            # predecessors of the header would collect blocks outside the loop
            stack = [] if u == v else [u]

            while stack:
                bb = stack.pop()
                if bb in loop:
                    continue
                loop.add(bb)
                for pred in bb.cfg_in:
                    # never walk past the header
                    if pred != v:
                        stack.append(pred)

            loop.add(v)

            if v in by_header:
                for bb in loop:
                    by_header[v].add(bb)
            else:
                by_header[v] = loop

        natural_loops: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] = {}
        for header, blocks in by_header.items():
            # prefer a predecessor outside the loop; fall back to the first
            # predecessor when every predecessor is part of the loop
            # NOTE(review): a header with several outside predecessors has
            # no single "block before the loop" - only the first is used
            before = next(
                (pred for pred in header.cfg_in if pred not in blocks), header.cfg_in.first()
            )
            if before in natural_loops:
                # do not drop a loop when two headers share the same key
                for bb in blocks:
                    natural_loops[before].add(bb)
            else:
                natural_loops[before] = blocks

        return natural_loops
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from vyper.utils import OrderedSet | ||
from vyper.venom.analysis.cfg import CFGAnalysis | ||
from vyper.venom.analysis.dfg import DFGAnalysis | ||
from vyper.venom.analysis.liveness import LivenessAnalysis | ||
from vyper.venom.analysis.loop_detection import NaturalLoopDetectionAnalysis | ||
from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRVariable, IRLiteral | ||
from vyper.venom.function import IRFunction | ||
from vyper.venom.passes.base_pass import IRPass | ||
|
||
|
||
def _ignore_instruction(inst: IRInstruction) -> bool:
    """
    Returns True for instructions that the pass does not hoist by itself:
    volatile instructions, basic block terminators, `returndatasize`, `phi`,
    `store` and `add` whose second operand is a label.
    """
    if inst.is_volatile or inst.is_bb_terminator:
        return True
    if inst.opcode in ("returndatasize", "phi", "store"):
        return True
    return inst.opcode == "add" and isinstance(inst.operands[1], IRLabel)
|
||
|
||
# The operand must be checked to be a literal, because
# there are cases where a store just moves a value
# from one variable to another
def _is_correct_store(inst: IRInstruction) -> bool:
    if inst.opcode != "store":
        return False
    return isinstance(inst.operands[0], IRLiteral)
|
||
|
||
class LoopInvariantHoisting(IRPass):
    """
    This pass detects invariants in loops and hoists them above the loop body.

    Instructions for which `_ignore_instruction` returns True (volatile
    instructions, basic block terminators, `returndatasize`, `phi`, `add`
    with a label operand and stores) are never hoisted on their own. Stores
    of a literal are moved together with a hoisted instruction that uses them.
    """

    function: IRFunction
    # key = block that receives the hoisted instructions
    # value = all the blocks of the loop
    # (taken as-is from NaturalLoopDetectionAnalysis)
    loops: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]
    dfg: DFGAnalysis

    def run_pass(self):
        self.analyses_cache.request_analysis(CFGAnalysis)
        self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
        loops = self.analyses_cache.request_analysis(NaturalLoopDetectionAnalysis)
        self.loops = loops.loops

        hoisted_any = False
        # repeat until nothing moves: hoisting an instruction can make the
        # instructions that use it hoistable in the next round
        while True:
            changed = False
            for from_bb, loop in self.loops.items():
                hoistable = self._get_hoistable_loop(from_bb, loop)
                if not hoistable:
                    continue
                changed = True
                self._hoist(from_bb, hoistable)
            if not changed:
                break
            hoisted_any = True

        # only need to invalidate if you did some hoisting
        if hoisted_any:
            self.analyses_cache.invalidate_analysis(LivenessAnalysis)

    def _hoist(self, target_bb: IRBasicBlock, hoistable: list[IRInstruction]):
        """Moves every instruction of `hoistable`, in order, to the end of `target_bb`."""
        for inst in hoistable:
            bb = inst.parent
            bb.remove_instruction(inst)
            # target_bb is already terminated, so the index has to be given
            # explicitly: insert right before the terminator
            target_bb.insert_instruction(inst, index=len(target_bb.instructions) - 1)

    def _get_hoistable_loop(
        self, from_bb: IRBasicBlock, loop: OrderedSet[IRBasicBlock]
    ) -> list[IRInstruction]:
        """
        Collects the hoistable instructions of all blocks of `loop`.

        Every instruction is listed at most once, at its first occurrence.
        """
        result: list[IRInstruction] = []
        # A literal store is reported once per instruction that uses it.
        # `_hoist` moves every list entry, so a repeated entry would be moved
        # again, after the instruction that needs it. Track by identity.
        seen: set[int] = set()
        for bb in loop:
            for inst in self._get_hoistable_bb(bb, from_bb):
                if id(inst) in seen:
                    continue
                seen.add(id(inst))
                result.append(inst)
        return result

    def _get_hoistable_bb(self, bb: IRBasicBlock, loop_idx: IRBasicBlock) -> list[IRInstruction]:
        """
        Returns the hoistable instructions of `bb`, each one preceded by the
        literal stores from the loop that it uses.
        """
        result: list[IRInstruction] = []
        for inst in bb.instructions:
            if self._can_hoist_instruction_ignore_stores(inst, self.loops[loop_idx]):
                result.extend(self._store_dependencies(inst, loop_idx))
                result.append(inst)

        return result

    # query store dependencies of instruction (they are not handled otherwise)
    def _store_dependencies(
        self, inst: IRInstruction, loop_idx: IRBasicBlock
    ) -> list[IRInstruction]:
        loop = self.loops[loop_idx]
        result: list[IRInstruction] = []
        for var in inst.get_input_variables():
            source_inst = self.dfg.get_producing_instruction(var)
            assert isinstance(source_inst, IRInstruction)
            if not _is_correct_store(source_inst):
                continue
            # stores that already live outside of the loop stay where they are
            if source_inst.parent in loop:
                result.append(source_inst)
        return result

    # `_ignore_instruction` rejects every store, so stores are never picked on
    # their own. A store of a literal does not depend on anything computed in
    # the loop, therefore it is ignored as a dependency here and hoisted only
    # when some hoisted instruction depends on it (see `_store_dependencies`)
    def _can_hoist_instruction_ignore_stores(
        self, inst: IRInstruction, loop: OrderedSet[IRBasicBlock]
    ) -> bool:
        if _ignore_instruction(inst):
            return False
        return not any(self._dependent_in_bb(inst, bb) for bb in loop)

    def _dependent_in_bb(self, inst: IRInstruction, bb: IRBasicBlock) -> bool:
        """Returns True if an input of `inst` is produced in `bb` by anything but a literal store."""
        for in_var in inst.get_input_variables():
            assert isinstance(in_var, IRVariable)
            source_ins = self.dfg.get_producing_instruction(in_var)
            assert isinstance(source_ins, IRInstruction)

            # this holds when the input was produced by a store of a literal;
            # such stores are independent and can always be hoisted
            if _is_correct_store(source_ins):
                continue

            if source_ins.parent == bb:
                return True
        return False
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's bring this outside the loop, and add a flag inside the loop to detect if `invalidate_analysis` should be called