diff --git a/tests/unit/compiler/venom/test_sccp.py b/tests/unit/compiler/venom/test_sccp.py index e65839136e..478acc1079 100644 --- a/tests/unit/compiler/venom/test_sccp.py +++ b/tests/unit/compiler/venom/test_sccp.py @@ -211,3 +211,34 @@ def test_cont_phi_const_case(): assert sccp.lattice[IRVariable("%5", version=1)].value == 106 assert sccp.lattice[IRVariable("%5", version=2)].value == 97 assert sccp.lattice[IRVariable("%5")].value == 2 + + +def test_phi_reduction_after_unreachable_block(): + ctx = IRContext() + fn = ctx.create_function("_global") + + bb = fn.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), fn) + fn.append_basic_block(br1) + join = IRBasicBlock(IRLabel("join"), fn) + fn.append_basic_block(join) + + op = bb.append_instruction("store", 1) + true = IRLiteral(1) + bb.append_instruction("jnz", true, br1.label, join.label) + + op1 = br1.append_instruction("store", 2) + + br1.append_instruction("jmp", join.label) + + join.append_instruction("phi", bb.label, op, br1.label, op1) + join.append_instruction("stop") + + ac = IRAnalysesCache(fn) + SCCP(ac, fn).run_pass() + + assert join.instructions[0].opcode == "store", join.instructions[0] + assert join.instructions[0].operands == [op1] + + assert join.instructions[1].opcode == "stop" diff --git a/tests/unit/compiler/venom/test_simplify_cfg.py b/tests/unit/compiler/venom/test_simplify_cfg.py new file mode 100644 index 0000000000..c4bdbb263b --- /dev/null +++ b/tests/unit/compiler/venom/test_simplify_cfg.py @@ -0,0 +1,49 @@ +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRLiteral +from vyper.venom.context import IRContext +from vyper.venom.passes.sccp import SCCP +from vyper.venom.passes.simplify_cfg import SimplifyCFGPass + + +def test_phi_reduction_after_block_pruning(): + ctx = IRContext() + fn = ctx.create_function("_global") + + bb = fn.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), fn) + fn.append_basic_block(br1) + br2 = IRBasicBlock(IRLabel("else"), fn) + fn.append_basic_block(br2) + + join = IRBasicBlock(IRLabel("join"), fn) + fn.append_basic_block(join) + + true = IRLiteral(1) + bb.append_instruction("jnz", true, br1.label, br2.label) + + op1 = br1.append_instruction("store", 1) + op2 = br2.append_instruction("store", 2) + + br1.append_instruction("jmp", join.label) + br2.append_instruction("jmp", join.label) + + join.append_instruction("phi", br1.label, op1, br2.label, op2) + join.append_instruction("stop") + + ac = IRAnalysesCache(fn) + SCCP(ac, fn).run_pass() + SimplifyCFGPass(ac, fn).run_pass() + + bbs = list(fn.get_basic_blocks()) + + assert len(bbs) == 1 + final_bb = bbs[0] + + inst0, inst1, inst2 = final_bb.instructions + + assert inst0.opcode == "store" + assert inst0.operands == [IRLiteral(1)] + assert inst1.opcode == "store" + assert inst1.operands == [inst0.output] + assert inst2.opcode == "stop" diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py index 164d8e241d..cfac6794f8 100644 --- a/vyper/venom/passes/sccp/sccp.py +++ b/vyper/venom/passes/sccp/sccp.py @@ -56,14 +56,18 @@ class SCCP(IRPass): uses: dict[IRVariable, OrderedSet[IRInstruction]] lattice: Lattice work_list: list[WorkListItem] - cfg_dirty: bool cfg_in_exec: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + cfg_dirty: bool + # list of basic blocks whose cfg_in was modified + cfg_in_modified: OrderedSet[IRBasicBlock] + def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction): super().__init__(analyses_cache, function) self.lattice = {} self.work_list: list[WorkListItem] = [] self.cfg_dirty = False + self.cfg_in_modified = OrderedSet() def run_pass(self): self.fn = self.function @@ -74,7 +78,9 @@ def run_pass(self): # self._propagate_variables() - self.analyses_cache.invalidate_analysis(CFGAnalysis) + if self.cfg_dirty: + self.analyses_cache.force_analysis(CFGAnalysis) + self._fix_phi_nodes() def _calculate_sccp(self, entry: IRBasicBlock): """ @@ -304,7 +310,11 @@ def _replace_constants(self, inst: IRInstruction): target = inst.operands[1] inst.opcode = "jmp" inst.operands = [target] + self.cfg_dirty = True + for bb in inst.parent.cfg_out: + if bb.label == target: + self.cfg_in_modified.add(bb) elif inst.opcode in ("assert", "assert_unreachable"): lat = self._eval_from_lattice(inst.operands[0]) @@ -329,6 +339,34 @@ def _replace_constants(self, inst: IRInstruction): if isinstance(lat, IRLiteral): inst.operands[i] = lat + def _fix_phi_nodes(self): + # fix basic blocks whose cfg in was changed + # maybe this should really be done in _visit_phi + needs_sort = False + + for bb in self.fn.get_basic_blocks(): + cfg_in_labels = OrderedSet(in_bb.label for in_bb in bb.cfg_in) + + for inst in bb.instructions: + if inst.opcode != "phi": + break + needs_sort |= self._fix_phi_inst(inst, cfg_in_labels) + + # move phi instructions to the top of the block + if needs_sort: + bb.instructions.sort(key=lambda inst: inst.opcode != "phi") + + def _fix_phi_inst(self, inst: IRInstruction, cfg_in_labels: OrderedSet): + operands = [op for label, op in inst.phi_operands if label in cfg_in_labels] + + if len(operands) != 1: + return False + + assert inst.output is not None + inst.opcode = "store" + inst.operands = operands + return True + def _meet(x: LatticeItem, y: LatticeItem) -> LatticeItem: if x == LatticeEnum.TOP: diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py index 08582fee96..1409f43947 100644 --- a/vyper/venom/passes/simplify_cfg.py +++ b/vyper/venom/passes/simplify_cfg.py @@ -9,23 +9,21 @@ class SimplifyCFGPass(IRPass): visited: OrderedSet def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock): - a.instructions.pop() + a.instructions.pop() # pop terminating instruction for inst in b.instructions: - assert inst.opcode != "phi", "Not implemented yet" - if inst.opcode == "phi": - a.instructions.insert(0, inst) - else: - inst.parent = a - a.instructions.append(inst) + assert inst.opcode != "phi", f"Instruction should never be phi {b}" + inst.parent = a + a.instructions.append(inst) # Update CFG a.cfg_out = b.cfg_out - if len(b.cfg_out) > 0: - next_bb = b.cfg_out.first() + + for next_bb in a.cfg_out: next_bb.remove_cfg_in(b) next_bb.add_cfg_in(a) for inst in next_bb.instructions: + # assume phi instructions are at beginning of bb if inst.opcode != "phi": break inst.operands[inst.operands.index(b.label)] = a.label