diff --git a/tests/functional/codegen/features/test_assert.py b/tests/functional/codegen/features/test_assert.py index df379d3f16..de9dd17ef6 100644 --- a/tests/functional/codegen/features/test_assert.py +++ b/tests/functional/codegen/features/test_assert.py @@ -107,14 +107,6 @@ def test(): assert self.ret1() == 1 """, """ -@internal -def valid_address(sender: address) -> bool: - selfdestruct(sender) -@external -def test(): - assert self.valid_address(msg.sender) - """, - """ @external def test(): assert raw_call(msg.sender, b'', max_outsize=1, gas=10, value=1000*1000) == b'' diff --git a/tests/functional/codegen/features/test_conditionals.py b/tests/functional/codegen/features/test_conditionals.py index 15ccc40bdf..3b0e57eeca 100644 --- a/tests/functional/codegen/features/test_conditionals.py +++ b/tests/functional/codegen/features/test_conditionals.py @@ -7,7 +7,6 @@ def foo(i: bool) -> int128: else: assert 2 != 0 return 7 - return 11 """ c = get_contract_with_gas_estimation(conditional_return_code) diff --git a/tests/functional/syntax/test_unbalanced_return.py b/tests/functional/syntax/test_unbalanced_return.py index d1d9732777..d5754f0053 100644 --- a/tests/functional/syntax/test_unbalanced_return.py +++ b/tests/functional/syntax/test_unbalanced_return.py @@ -8,7 +8,7 @@ """ @external def foo() -> int128: - pass + pass # missing return """, FunctionDeclarationException, ), @@ -18,6 +18,7 @@ def foo() -> int128: def foo() -> int128: if False: return 123 + # missing return """, FunctionDeclarationException, ), @@ -27,19 +28,19 @@ def foo() -> int128: def test() -> int128: if 1 == 1 : return 1 - if True: + if True: # unreachable return 0 else: assert msg.sender != msg.sender """, - FunctionDeclarationException, + StructureException, ), ( """ @internal def valid_address(sender: address) -> bool: selfdestruct(sender) - return True + return True # unreachable """, StructureException, ), @@ -48,7 +49,7 @@ def valid_address(sender: address) -> bool: @internal def valid_address(sender: address) -> bool: selfdestruct(sender) - a: address = sender + a: address = sender # unreachable """, StructureException, ), @@ -58,7 +59,7 @@ def valid_address(sender: address) -> bool: def valid_address(sender: address) -> bool: if sender == empty(address): selfdestruct(sender) - _sender: address = sender + _sender: address = sender # unreachable else: return False """, @@ -69,7 +70,7 @@ def valid_address(sender: address) -> bool: @internal def foo() -> bool: raw_revert(b"vyper") - return True + return True # unreachable """, StructureException, ), @@ -78,7 +79,7 @@ def foo() -> bool: @internal def foo() -> bool: raw_revert(b"vyper") - x: uint256 = 3 + x: uint256 = 3 # unreachable """, StructureException, ), @@ -88,12 +89,35 @@ def foo() -> bool: def foo(x: uint256) -> bool: if x == 2: raw_revert(b"vyper") - a: uint256 = 3 + a: uint256 = 3 # unreachable else: return False """, StructureException, ), + ( + """ +@internal +def foo(): + return + return # unreachable + """, + StructureException, + ), + ( + """ +@internal +def foo() -> uint256: + if block.number % 2 == 0: + return 5 + elif block.number % 3 == 0: + return 6 + else: + return 10 + return 0 # unreachable + """, + StructureException, + ), ] @@ -154,7 +178,6 @@ def test() -> int128: else: x = keccak256(x) return 1 - return 1 """, """ @external diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index df419daa25..de15fb9075 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -237,8 +237,6 @@ class VyperNode: Field names that, if present, must be set to None or a `SyntaxException` is raised. This attribute is used to exclude syntax that is valid in Python but not in Vyper. - _is_terminus : bool, optional - If `True`, indicates that execution halts upon reaching this node. _translated_fields : Dict, optional Field names that are reassigned if encountered. Used to normalize fields across different Python versions. @@ -389,6 +387,13 @@ def is_literal_value(self): """ return False + @property + def is_terminus(self): + """ + Check if execution halts upon reaching this node. + """ + return False + @property def has_folded_value(self): """ @@ -711,12 +716,19 @@ class Stmt(VyperNode): class Return(Stmt): __slots__ = ("value",) - _is_terminus = True + + @property + def is_terminus(self): + return True class Expr(Stmt): __slots__ = ("value",) + @property + def is_terminus(self): + return self.value.is_terminus + class Log(Stmt): __slots__ = ("value",) @@ -1187,6 +1199,21 @@ def _op(self, left, right): class Call(ExprNode): __slots__ = ("func", "args", "keywords") + @property + def is_terminus(self): + # cursed import cycle! + from vyper.builtins.functions import get_builtin_functions + + if not isinstance(self.func, Name): + return False + + funcname = self.func.id + builtin_t = get_builtin_functions().get(funcname) + if builtin_t is None: + return False + + return builtin_t._is_terminus + class keyword(VyperNode): __slots__ = ("arg", "value") @@ -1322,7 +1349,10 @@ class AugAssign(Stmt): class Raise(Stmt): __slots__ = ("exc",) _only_empty_fields = ("cause",) - _is_terminus = True + + @property + def is_terminus(self): + return True class Assert(Stmt): diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index aac008ad1e..1a488f39e0 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -85,6 +85,7 @@ class BuiltinFunctionT(VyperType): _kwargs: dict[str, KwargSettings] = {} _modifiability: Modifiability = Modifiability.MODIFIABLE _return_type: Optional[VyperType] = None + _is_terminus = False # helper function to deal with TYPE_DEFINITIONs def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index c16de3c55a..c3215f8c16 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1,12 +1,11 @@ import contextlib from typing import Generator -from vyper import ast as vy_ast from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import OptimizationLevel from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT from vyper.evm.opcodes import version_check -from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch +from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch from vyper.semantics.types import ( AddressT, BoolT, @@ -1035,43 +1034,6 @@ def eval_seq(ir_node): return None -def is_return_from_function(node): - if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in ( - "raw_revert", - "selfdestruct", - ): - return True - if isinstance(node, (vy_ast.Return, vy_ast.Raise)): - return True - return False - - -# TODO this is almost certainly duplicated with check_terminus_node -# in vyper/semantics/analysis/local.py -def check_single_exit(fn_node): - _check_return_body(fn_node, fn_node.body) - for node in fn_node.get_descendants(vy_ast.If): - _check_return_body(node, node.body) - if node.orelse: - _check_return_body(node, node.orelse) - - -def _check_return_body(node, node_list): - return_count = len([n for n in node_list if is_return_from_function(n)]) - if return_count > 1: - raise StructureException( - "Too too many exit statements (return, raise or selfdestruct).", node - ) - # Check for invalid code after returns. - last_node_pos = len(node_list) - 1 - for idx, n in enumerate(node_list): - if is_return_from_function(n) and idx < last_node_pos: - # is not last statement in body. - raise StructureException( - "Exit statement with succeeding code (that will not execute).", node_list[idx + 1] - ) - - def mzero(dst, nbytes): # calldatacopy from past-the-end gives zero bytes. # cf. YP H.2 (ops section) with CALLDATACOPY spec. diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 454ba9c8cd..5877ff3d13 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -4,7 +4,6 @@ import vyper.ast as vy_ast from vyper.codegen.context import Constancy, Context -from vyper.codegen.core import check_single_exit from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function from vyper.codegen.ir_node import IRnode @@ -115,10 +114,6 @@ def generate_ir_for_function( # generate _FuncIRInfo func_t._ir_info = _FuncIRInfo(func_t) - # Validate return statements. - # XXX: This should really be in semantics pass. - check_single_exit(code) - callees = func_t.called_functions # we start our function frame from the largest callee frame diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index a47faefeb1..7d4938f287 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -15,7 +15,6 @@ get_dyn_array_count, get_element_ptr, getpos, - is_return_from_function, make_byte_array_copier, make_setter, pop_dyn_array, @@ -404,7 +403,7 @@ def parse_stmt(stmt, context): def _is_terminated(code): last_stmt = code[-1] - if is_return_from_function(last_stmt): + if last_stmt.is_terminus: return True if isinstance(last_stmt, vy_ast.If): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index cc8ddaf98d..c4af5b1e3a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -66,26 +66,28 @@ def validate_functions(vy_module: vy_ast.Module) -> None: err_list.raise_if_not_empty() -def _is_terminus_node(node: vy_ast.VyperNode) -> bool: - if getattr(node, "_is_terminus", None): - return True - if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call): - func = get_exact_type_from_node(node.value.func) - if getattr(func, "_is_terminus", None): - return True - return False - - -def check_for_terminus(node_list: list) -> bool: - if next((i for i in node_list if _is_terminus_node(i)), None): - return True - for node in [i for i in node_list if isinstance(i, vy_ast.If)][::-1]: - if not node.orelse or not check_for_terminus(node.orelse): - continue - if not check_for_terminus(node.body): - continue - return True - return False +# finds the terminus node for a list of nodes. +# raises an exception if any nodes are unreachable +def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: + ret = None + + for node in node_list: + if ret is not None: + raise StructureException("Unreachable code!", node) + if node.is_terminus: + ret = node + + if isinstance(node, vy_ast.If): + body_terminates = find_terminating_node(node.body) + + else_terminates = None + if node.orelse is not None: + else_terminates = find_terminating_node(node.orelse) + + if body_terminates is not None and else_terminates is not None: + ret = else_terminates + + return ret def _check_iterator_modification( @@ -201,11 +203,13 @@ def analyze(self): self.visit(node) if self.func.return_type: - if not check_for_terminus(self.fn_node.body): + if not find_terminating_node(self.fn_node.body): raise FunctionDeclarationException( - f"Missing or unmatched return statements in function '{self.fn_node.name}'", - self.fn_node, + f"Missing return statement in function '{self.fn_node.name}'", self.fn_node ) + else: + # call find_terminator for its unreachable code detection side effect + find_terminating_node(self.fn_node.body) # visit default args assert self.func.n_keyword_args == len(self.fn_node.args.defaults) @@ -468,7 +472,7 @@ def visit_Return(self, node): raise FunctionDeclarationException("Return statement is missing a value", node) return elif self.func.return_type is None: - raise FunctionDeclarationException("Function does not return any values", node) + raise FunctionDeclarationException("Function should not return any values", node) if isinstance(values, vy_ast.Tuple): values = values.elements