Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[venom]: fix invalid phis after SCCP #4181

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/unit/compiler/venom/test_sccp.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,34 @@ def test_cont_phi_const_case():
assert sccp.lattice[IRVariable("%5", version=1)].value == 106
assert sccp.lattice[IRVariable("%5", version=2)].value == 97
assert sccp.lattice[IRVariable("%5")].value == 2


def test_phi_reduction_after_unreachable_block():
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
fn.append_basic_block(br1)
join = IRBasicBlock(IRLabel("join"), fn)
fn.append_basic_block(join)

op = bb.append_instruction("store", 1)
true = IRLiteral(1)
bb.append_instruction("jnz", true, br1.label, join.label)

op1 = br1.append_instruction("store", 2)

br1.append_instruction("jmp", join.label)

join.append_instruction("phi", bb.label, op, br1.label, op1)
join.append_instruction("stop")

ac = IRAnalysesCache(fn)
SCCP(ac, fn).run_pass()

assert join.instructions[0].opcode == "store", join.instructions[0]
assert join.instructions[0].operands == [op1]

assert join.instructions[1].opcode == "stop"
49 changes: 49 additions & 0 deletions tests/unit/compiler/venom/test_simplify_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from vyper.venom.analysis.analysis import IRAnalysesCache
from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRLiteral
from vyper.venom.context import IRContext
from vyper.venom.passes.sccp import SCCP
from vyper.venom.passes.simplify_cfg import SimplifyCFGPass


def test_phi_reduction_after_block_pruning():
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
fn.append_basic_block(br1)
br2 = IRBasicBlock(IRLabel("else"), fn)
fn.append_basic_block(br2)

join = IRBasicBlock(IRLabel("join"), fn)
fn.append_basic_block(join)

true = IRLiteral(1)
bb.append_instruction("jnz", true, br1.label, br2.label)

op1 = br1.append_instruction("store", 1)
op2 = br2.append_instruction("store", 2)

br1.append_instruction("jmp", join.label)
br2.append_instruction("jmp", join.label)

join.append_instruction("phi", br1.label, op1, br2.label, op2)
join.append_instruction("stop")

ac = IRAnalysesCache(fn)
SCCP(ac, fn).run_pass()
SimplifyCFGPass(ac, fn).run_pass()

bbs = list(fn.get_basic_blocks())

assert len(bbs) == 1
final_bb = bbs[0]

inst0, inst1, inst2 = final_bb.instructions

assert inst0.opcode == "store"
assert inst0.operands == [IRLiteral(1)]
assert inst1.opcode == "store"
assert inst1.operands == [inst0.output]
assert inst2.opcode == "stop"
42 changes: 40 additions & 2 deletions vyper/venom/passes/sccp/sccp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,18 @@ class SCCP(IRPass):
uses: dict[IRVariable, OrderedSet[IRInstruction]]
lattice: Lattice
work_list: list[WorkListItem]
cfg_dirty: bool
cfg_in_exec: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]

cfg_dirty: bool
# list of basic blocks whose cfg_in was modified
cfg_in_modified: OrderedSet[IRBasicBlock]

def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction):
super().__init__(analyses_cache, function)
self.lattice = {}
self.work_list: list[WorkListItem] = []
self.cfg_dirty = False
self.cfg_in_modified = OrderedSet()

def run_pass(self):
self.fn = self.function
Expand All @@ -74,7 +78,9 @@ def run_pass(self):

# self._propagate_variables()

self.analyses_cache.invalidate_analysis(CFGAnalysis)
if self.cfg_dirty:
self.analyses_cache.force_analysis(CFGAnalysis)
self._fix_phi_nodes()

def _calculate_sccp(self, entry: IRBasicBlock):
"""
Expand Down Expand Up @@ -304,7 +310,11 @@ def _replace_constants(self, inst: IRInstruction):
target = inst.operands[1]
inst.opcode = "jmp"
inst.operands = [target]

self.cfg_dirty = True
for bb in inst.parent.cfg_out:
if bb.label == target:
self.cfg_in_modified.add(bb)

elif inst.opcode in ("assert", "assert_unreachable"):
lat = self._eval_from_lattice(inst.operands[0])
Expand All @@ -329,6 +339,34 @@ def _replace_constants(self, inst: IRInstruction):
if isinstance(lat, IRLiteral):
inst.operands[i] = lat

def _fix_phi_nodes(self):
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
# fix basic blocks whose cfg in was changed
# maybe this should really be done in _visit_phi
needs_sort = False

for bb in self.fn.get_basic_blocks():
cfg_in_labels = OrderedSet(in_bb.label for in_bb in bb.cfg_in)

for inst in bb.instructions:
if inst.opcode != "phi":
break
needs_sort |= self._fix_phi_inst(inst, cfg_in_labels)

# move phi instructions to the top of the block
if needs_sort:
bb.instructions.sort(key=lambda inst: inst.opcode != "phi")

def _fix_phi_inst(self, inst: IRInstruction, cfg_in_labels: OrderedSet):
operands = [op for label, op in inst.phi_operands if label in cfg_in_labels]

if len(operands) != 1:
return False

assert inst.output is not None
inst.opcode = "store"
inst.operands = operands
return True


def _meet(x: LatticeItem, y: LatticeItem) -> LatticeItem:
if x == LatticeEnum.TOP:
Expand Down
16 changes: 7 additions & 9 deletions vyper/venom/passes/simplify_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,21 @@ class SimplifyCFGPass(IRPass):
visited: OrderedSet

def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock):
a.instructions.pop()
a.instructions.pop() # pop terminating instruction
for inst in b.instructions:
assert inst.opcode != "phi", "Not implemented yet"
if inst.opcode == "phi":
a.instructions.insert(0, inst)
else:
inst.parent = a
a.instructions.append(inst)
assert inst.opcode != "phi", f"Instruction should never be phi {b}"
inst.parent = a
a.instructions.append(inst)

# Update CFG
a.cfg_out = b.cfg_out
if len(b.cfg_out) > 0:
next_bb = b.cfg_out.first()

for next_bb in a.cfg_out:
next_bb.remove_cfg_in(b)
next_bb.add_cfg_in(a)

for inst in next_bb.instructions:
# assume phi instructions are at beginning of bb
if inst.opcode != "phi":
break
inst.operands[inst.operands.index(b.label)] = a.label
Expand Down
Loading