Skip to content

Commit

Permalink
fix[venom]: fix invalid phis after SCCP (#4181)
Browse files Browse the repository at this point in the history
This commit reduces `phi` nodes which, after SCCP, refer to a
basic block which is unreachable.  This could happen when the
SCCP would replace a `jnz` by a `jmp` instruction, resulting in a
`phi` instruction which refers to a non-existent block, or a `phi`
instruction in a basic block which only has one predecessor.  This
commit reduces such `phi` instructions into `store` instructions.

---------

Co-authored-by: Harry Kalogirou <[email protected]>
Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
3 people authored Sep 27, 2024
1 parent d60d31f commit e2f6001
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 11 deletions.
31 changes: 31 additions & 0 deletions tests/unit/compiler/venom/test_sccp.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,34 @@ def test_cont_phi_const_case():
assert sccp.lattice[IRVariable("%5", version=1)].value == 106
assert sccp.lattice[IRVariable("%5", version=2)].value == 97
assert sccp.lattice[IRVariable("%5")].value == 2


def test_phi_reduction_after_unreachable_block():
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
fn.append_basic_block(br1)
join = IRBasicBlock(IRLabel("join"), fn)
fn.append_basic_block(join)

op = bb.append_instruction("store", 1)
true = IRLiteral(1)
bb.append_instruction("jnz", true, br1.label, join.label)

op1 = br1.append_instruction("store", 2)

br1.append_instruction("jmp", join.label)

join.append_instruction("phi", bb.label, op, br1.label, op1)
join.append_instruction("stop")

ac = IRAnalysesCache(fn)
SCCP(ac, fn).run_pass()

assert join.instructions[0].opcode == "store", join.instructions[0]
assert join.instructions[0].operands == [op1]

assert join.instructions[1].opcode == "stop"
49 changes: 49 additions & 0 deletions tests/unit/compiler/venom/test_simplify_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from vyper.venom.analysis.analysis import IRAnalysesCache
from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRLiteral
from vyper.venom.context import IRContext
from vyper.venom.passes.sccp import SCCP
from vyper.venom.passes.simplify_cfg import SimplifyCFGPass


def test_phi_reduction_after_block_pruning():
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
fn.append_basic_block(br1)
br2 = IRBasicBlock(IRLabel("else"), fn)
fn.append_basic_block(br2)

join = IRBasicBlock(IRLabel("join"), fn)
fn.append_basic_block(join)

true = IRLiteral(1)
bb.append_instruction("jnz", true, br1.label, br2.label)

op1 = br1.append_instruction("store", 1)
op2 = br2.append_instruction("store", 2)

br1.append_instruction("jmp", join.label)
br2.append_instruction("jmp", join.label)

join.append_instruction("phi", br1.label, op1, br2.label, op2)
join.append_instruction("stop")

ac = IRAnalysesCache(fn)
SCCP(ac, fn).run_pass()
SimplifyCFGPass(ac, fn).run_pass()

bbs = list(fn.get_basic_blocks())

assert len(bbs) == 1
final_bb = bbs[0]

inst0, inst1, inst2 = final_bb.instructions

assert inst0.opcode == "store"
assert inst0.operands == [IRLiteral(1)]
assert inst1.opcode == "store"
assert inst1.operands == [inst0.output]
assert inst2.opcode == "stop"
42 changes: 40 additions & 2 deletions vyper/venom/passes/sccp/sccp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,18 @@ class SCCP(IRPass):
uses: dict[IRVariable, OrderedSet[IRInstruction]]
lattice: Lattice
work_list: list[WorkListItem]
cfg_dirty: bool
cfg_in_exec: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]

cfg_dirty: bool
# list of basic blocks whose cfg_in was modified
cfg_in_modified: OrderedSet[IRBasicBlock]

def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction):
super().__init__(analyses_cache, function)
self.lattice = {}
self.work_list: list[WorkListItem] = []
self.cfg_dirty = False
self.cfg_in_modified = OrderedSet()

def run_pass(self):
self.fn = self.function
Expand All @@ -74,7 +78,9 @@ def run_pass(self):

# self._propagate_variables()

self.analyses_cache.invalidate_analysis(CFGAnalysis)
if self.cfg_dirty:
self.analyses_cache.force_analysis(CFGAnalysis)
self._fix_phi_nodes()

def _calculate_sccp(self, entry: IRBasicBlock):
"""
Expand Down Expand Up @@ -304,7 +310,11 @@ def _replace_constants(self, inst: IRInstruction):
target = inst.operands[1]
inst.opcode = "jmp"
inst.operands = [target]

self.cfg_dirty = True
for bb in inst.parent.cfg_out:
if bb.label == target:
self.cfg_in_modified.add(bb)

elif inst.opcode in ("assert", "assert_unreachable"):
lat = self._eval_from_lattice(inst.operands[0])
Expand All @@ -329,6 +339,34 @@ def _replace_constants(self, inst: IRInstruction):
if isinstance(lat, IRLiteral):
inst.operands[i] = lat

def _fix_phi_nodes(self):
# fix basic blocks whose cfg in was changed
# maybe this should really be done in _visit_phi
needs_sort = False

for bb in self.fn.get_basic_blocks():
cfg_in_labels = OrderedSet(in_bb.label for in_bb in bb.cfg_in)

for inst in bb.instructions:
if inst.opcode != "phi":
break
needs_sort |= self._fix_phi_inst(inst, cfg_in_labels)

# move phi instructions to the top of the block
if needs_sort:
bb.instructions.sort(key=lambda inst: inst.opcode != "phi")

def _fix_phi_inst(self, inst: IRInstruction, cfg_in_labels: OrderedSet):
operands = [op for label, op in inst.phi_operands if label in cfg_in_labels]

if len(operands) != 1:
return False

assert inst.output is not None
inst.opcode = "store"
inst.operands = operands
return True


def _meet(x: LatticeItem, y: LatticeItem) -> LatticeItem:
if x == LatticeEnum.TOP:
Expand Down
16 changes: 7 additions & 9 deletions vyper/venom/passes/simplify_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,21 @@ class SimplifyCFGPass(IRPass):
visited: OrderedSet

def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock):
a.instructions.pop()
a.instructions.pop() # pop terminating instruction
for inst in b.instructions:
assert inst.opcode != "phi", "Not implemented yet"
if inst.opcode == "phi":
a.instructions.insert(0, inst)
else:
inst.parent = a
a.instructions.append(inst)
assert inst.opcode != "phi", f"Instruction should never be phi {b}"
inst.parent = a
a.instructions.append(inst)

# Update CFG
a.cfg_out = b.cfg_out
if len(b.cfg_out) > 0:
next_bb = b.cfg_out.first()

for next_bb in a.cfg_out:
next_bb.remove_cfg_in(b)
next_bb.add_cfg_in(a)

for inst in next_bb.instructions:
# assume phi instructions are at beginning of bb
if inst.opcode != "phi":
break
inst.operands[inst.operands.index(b.label)] = a.label
Expand Down

0 comments on commit e2f6001

Please sign in to comment.