diff --git a/tests/unit/compiler/venom/test_sccp.py b/tests/unit/compiler/venom/test_sccp.py index e65839136e..478acc1079 100644 --- a/tests/unit/compiler/venom/test_sccp.py +++ b/tests/unit/compiler/venom/test_sccp.py @@ -211,3 +211,34 @@ def test_cont_phi_const_case(): assert sccp.lattice[IRVariable("%5", version=1)].value == 106 assert sccp.lattice[IRVariable("%5", version=2)].value == 97 assert sccp.lattice[IRVariable("%5")].value == 2 + + +def test_phi_reduction_after_unreachable_block(): + ctx = IRContext() + fn = ctx.create_function("_global") + + bb = fn.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), fn) + fn.append_basic_block(br1) + join = IRBasicBlock(IRLabel("join"), fn) + fn.append_basic_block(join) + + op = bb.append_instruction("store", 1) + true = IRLiteral(1) + bb.append_instruction("jnz", true, br1.label, join.label) + + op1 = br1.append_instruction("store", 2) + + br1.append_instruction("jmp", join.label) + + join.append_instruction("phi", bb.label, op, br1.label, op1) + join.append_instruction("stop") + + ac = IRAnalysesCache(fn) + SCCP(ac, fn).run_pass() + + assert join.instructions[0].opcode == "store", join.instructions[0] + assert join.instructions[0].operands == [op1] + + assert join.instructions[1].opcode == "stop" diff --git a/tests/unit/compiler/venom/test_simplify_cfg.py b/tests/unit/compiler/venom/test_simplify_cfg.py new file mode 100644 index 0000000000..c4bdbb263b --- /dev/null +++ b/tests/unit/compiler/venom/test_simplify_cfg.py @@ -0,0 +1,49 @@ +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRLiteral +from vyper.venom.context import IRContext +from vyper.venom.passes.sccp import SCCP +from vyper.venom.passes.simplify_cfg import SimplifyCFGPass + + +def test_phi_reduction_after_block_pruning(): + ctx = IRContext() + fn = ctx.create_function("_global") + + bb = fn.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), fn) + fn.append_basic_block(br1) + br2 = IRBasicBlock(IRLabel("else"), fn) + fn.append_basic_block(br2) + + join = IRBasicBlock(IRLabel("join"), fn) + fn.append_basic_block(join) + + true = IRLiteral(1) + bb.append_instruction("jnz", true, br1.label, br2.label) + + op1 = br1.append_instruction("store", 1) + op2 = br2.append_instruction("store", 2) + + br1.append_instruction("jmp", join.label) + br2.append_instruction("jmp", join.label) + + join.append_instruction("phi", br1.label, op1, br2.label, op2) + join.append_instruction("stop") + + ac = IRAnalysesCache(fn) + SCCP(ac, fn).run_pass() + SimplifyCFGPass(ac, fn).run_pass() + + bbs = list(fn.get_basic_blocks()) + + assert len(bbs) == 1 + final_bb = bbs[0] + + inst0, inst1, inst2 = final_bb.instructions + + assert inst0.opcode == "store" + assert inst0.operands == [IRLiteral(1)] + assert inst1.opcode == "store" + assert inst1.operands == [inst0.output] + assert inst2.opcode == "stop" diff --git a/vyper/venom/analysis/dup_requirements.py b/vyper/venom/analysis/dup_requirements.py deleted file mode 100644 index 7afb315035..0000000000 --- a/vyper/venom/analysis/dup_requirements.py +++ /dev/null @@ -1,15 +0,0 @@ -from vyper.utils import OrderedSet -from vyper.venom.analysis.analysis import IRAnalysis - - -class DupRequirementsAnalysis(IRAnalysis): - def analyze(self): - for bb in self.function.get_basic_blocks(): - last_liveness = bb.out_vars - for inst in reversed(bb.instructions): - inst.dup_requirements = OrderedSet() - ops = inst.get_input_variables() - for op in ops: - if op in last_liveness: - inst.dup_requirements.add(op) - last_liveness = inst.liveness diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index d6fb9560cd..1199579b3f 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -209,7 +209,6 @@ class IRInstruction: output: Optional[IROperand] # set of live variables at this instruction liveness: OrderedSet[IRVariable] - dup_requirements: OrderedSet[IRVariable] parent: "IRBasicBlock" fence_id: int annotation: Optional[str] @@ -228,7 +227,6 @@ def __init__( self.operands = list(operands) # in case we get an iterator self.output = output self.liveness = OrderedSet() - self.dup_requirements = OrderedSet() self.fence_id = -1 self.annotation = None self.ast_source = None diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py index 164d8e241d..013583ec63 100644 --- a/vyper/venom/passes/sccp/sccp.py +++ b/vyper/venom/passes/sccp/sccp.py @@ -56,9 +56,10 @@ class SCCP(IRPass): uses: dict[IRVariable, OrderedSet[IRInstruction]] lattice: Lattice work_list: list[WorkListItem] - cfg_dirty: bool cfg_in_exec: dict[IRBasicBlock, OrderedSet[IRBasicBlock]] + cfg_dirty: bool + def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction): super().__init__(analyses_cache, function) self.lattice = {} @@ -72,9 +73,9 @@ def run_pass(self): self._calculate_sccp(self.fn.entry) self._propagate_constants() - # self._propagate_variables() - - self.analyses_cache.invalidate_analysis(CFGAnalysis) + if self.cfg_dirty: + self.analyses_cache.force_analysis(CFGAnalysis) + self._fix_phi_nodes() def _calculate_sccp(self, entry: IRBasicBlock): """ @@ -304,6 +305,7 @@ def _replace_constants(self, inst: IRInstruction): target = inst.operands[1] inst.opcode = "jmp" inst.operands = [target] + self.cfg_dirty = True elif inst.opcode in ("assert", "assert_unreachable"): @@ -329,6 +331,34 @@ def _replace_constants(self, inst: IRInstruction): if isinstance(lat, IRLiteral): inst.operands[i] = lat + def _fix_phi_nodes(self): + # fix basic blocks whose cfg in was changed + # maybe this should really be done in _visit_phi + needs_sort = False + + for bb in self.fn.get_basic_blocks(): + cfg_in_labels = OrderedSet(in_bb.label for in_bb in bb.cfg_in) + + for inst in bb.instructions: + if inst.opcode != "phi": + break + needs_sort |= self._fix_phi_inst(inst, cfg_in_labels) + + # move phi instructions to the top of the block + if needs_sort: + bb.instructions.sort(key=lambda inst: inst.opcode != "phi") + + def _fix_phi_inst(self, inst: IRInstruction, cfg_in_labels: OrderedSet): + operands = [op for label, op in inst.phi_operands if label in cfg_in_labels] + + if len(operands) != 1: + return False + + assert inst.output is not None + inst.opcode = "store" + inst.operands = operands + return True + def _meet(x: LatticeItem, y: LatticeItem) -> LatticeItem: if x == LatticeEnum.TOP: diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py index 08582fee96..1409f43947 100644 --- a/vyper/venom/passes/simplify_cfg.py +++ b/vyper/venom/passes/simplify_cfg.py @@ -9,23 +9,21 @@ class SimplifyCFGPass(IRPass): visited: OrderedSet def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock): - a.instructions.pop() + a.instructions.pop() # pop terminating instruction for inst in b.instructions: - assert inst.opcode != "phi", "Not implemented yet" - if inst.opcode == "phi": - a.instructions.insert(0, inst) - else: - inst.parent = a - a.instructions.append(inst) + assert inst.opcode != "phi", f"Instruction should never be phi {b}" + inst.parent = a + a.instructions.append(inst) # Update CFG a.cfg_out = b.cfg_out - if len(b.cfg_out) > 0: - next_bb = b.cfg_out.first() + + for next_bb in a.cfg_out: next_bb.remove_cfg_in(b) next_bb.add_cfg_in(a) for inst in next_bb.instructions: + # assume phi instructions are at beginning of bb if inst.opcode != "phi": break inst.operands[inst.operands.index(b.label)] = a.label diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 41a76319d7..390fab8e7c 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -12,7 +12,6 @@ ) from vyper.utils import MemoryPositions, OrderedSet from vyper.venom.analysis.analysis import IRAnalysesCache -from vyper.venom.analysis.dup_requirements import DupRequirementsAnalysis from vyper.venom.analysis.liveness import LivenessAnalysis from vyper.venom.basicblock import ( IRBasicBlock, @@ -153,7 +152,6 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]: NormalizationPass(ac, fn).run_pass() self.liveness_analysis = ac.request_analysis(LivenessAnalysis) - ac.request_analysis(DupRequirementsAnalysis) assert fn.normalized, "Non-normalized CFG!" @@ -231,7 +229,12 @@ def _stack_reorder( return cost def _emit_input_operands( - self, assembly: list, inst: IRInstruction, ops: list[IROperand], stack: StackModel + self, + assembly: list, + inst: IRInstruction, + ops: list[IROperand], + stack: StackModel, + next_liveness: OrderedSet[IRVariable], ) -> None: # PRE: we already have all the items on the stack that have # been scheduled to be killed. now it's just a matter of emitting @@ -241,7 +244,7 @@ def _emit_input_operands( # it with something that is wanted if ops and stack.height > 0 and stack.peek(0) not in ops: for op in ops: - if isinstance(op, IRVariable) and op not in inst.dup_requirements: + if isinstance(op, IRVariable) and op not in next_liveness: self.swap_op(assembly, stack, op) break @@ -264,7 +267,7 @@ def _emit_input_operands( stack.push(op) continue - if op in inst.dup_requirements and op not in emitted_ops: + if op in next_liveness and op not in emitted_ops: self.dup_op(assembly, stack, op) if op in emitted_ops: @@ -288,7 +291,9 @@ def _generate_evm_for_basicblock_r( all_insts = sorted(basicblock.instructions, key=lambda x: x.opcode != "param") for i, inst in enumerate(all_insts): - next_liveness = all_insts[i + 1].liveness if i + 1 < len(all_insts) else OrderedSet() + next_liveness = ( + all_insts[i + 1].liveness if i + 1 < len(all_insts) else basicblock.out_vars + ) asm.extend(self._generate_evm_for_instruction(inst, stack, next_liveness)) @@ -327,10 +332,9 @@ def clean_stack_from_cfg_in( self.pop(asm, stack) def _generate_evm_for_instruction( - self, inst: IRInstruction, stack: StackModel, next_liveness: OrderedSet = None + self, inst: IRInstruction, stack: StackModel, next_liveness: OrderedSet ) -> list[str]: assembly: list[str | int] = [] - next_liveness = next_liveness or OrderedSet() opcode = inst.opcode # @@ -375,7 +379,7 @@ def _generate_evm_for_instruction( # example, for `%56 = %label1 %13 %label2 %14`, we will # find an instance of %13 *or* %14 in the stack and replace it with %56. to_be_replaced = stack.peek(depth) - if to_be_replaced in inst.dup_requirements: + if to_be_replaced in next_liveness: # %13/%14 is still live(!), so we make a copy of it self.dup(assembly, stack, depth) stack.poke(0, ret) @@ -390,7 +394,7 @@ def _generate_evm_for_instruction( return apply_line_numbers(inst, assembly) # Step 2: Emit instruction's input operands - self._emit_input_operands(assembly, inst, operands, stack) + self._emit_input_operands(assembly, inst, operands, stack, next_liveness) # Step 3: Reorder stack before join points if opcode == "jmp":