Skip to content

Commit

Permalink
refactor: remove duplicate terminus checking code (vyperlang#3541)
Browse files Browse the repository at this point in the history
remove `check_single_exit` and `is_return_from_function` which duplicate
functionality in `is_terminus_node`/`check_for_terminus`.

additionally rewrite termination checking routine to be simpler, and
also fix an outstanding analysis bug where the following program would
not be rejected:

```vyper
@external
def foo(a: bool) -> uint256:
    if a:
        return 1
    else:
        return 2
    pass  # unreachable
```

---------

Co-authored-by: Charles Cooper <cooper.charles.m@gmail.com>
  • Loading branch information
tserg and charles-cooper authored Jan 14, 2024
1 parent 9cf66c9 commit af5c49f
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 93 deletions.
8 changes: 0 additions & 8 deletions tests/functional/codegen/features/test_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,6 @@ def test():
assert self.ret1() == 1
""",
"""
@internal
def valid_address(sender: address) -> bool:
selfdestruct(sender)
@external
def test():
assert self.valid_address(msg.sender)
""",
"""
@external
def test():
assert raw_call(msg.sender, b'', max_outsize=1, gas=10, value=1000*1000) == b''
Expand Down
1 change: 0 additions & 1 deletion tests/functional/codegen/features/test_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ def foo(i: bool) -> int128:
else:
assert 2 != 0
return 7
return 11
"""

c = get_contract_with_gas_estimation(conditional_return_code)
Expand Down
43 changes: 33 additions & 10 deletions tests/functional/syntax/test_unbalanced_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
@external
def foo() -> int128:
pass
pass # missing return
""",
FunctionDeclarationException,
),
Expand All @@ -18,6 +18,7 @@ def foo() -> int128:
def foo() -> int128:
if False:
return 123
# missing return
""",
FunctionDeclarationException,
),
Expand All @@ -27,19 +28,19 @@ def foo() -> int128:
def test() -> int128:
if 1 == 1 :
return 1
if True:
if True: # unreachable
return 0
else:
assert msg.sender != msg.sender
""",
FunctionDeclarationException,
StructureException,
),
(
"""
@internal
def valid_address(sender: address) -> bool:
selfdestruct(sender)
return True
return True # unreachable
""",
StructureException,
),
Expand All @@ -48,7 +49,7 @@ def valid_address(sender: address) -> bool:
@internal
def valid_address(sender: address) -> bool:
selfdestruct(sender)
a: address = sender
a: address = sender # unreachable
""",
StructureException,
),
Expand All @@ -58,7 +59,7 @@ def valid_address(sender: address) -> bool:
def valid_address(sender: address) -> bool:
if sender == empty(address):
selfdestruct(sender)
_sender: address = sender
_sender: address = sender # unreachable
else:
return False
""",
Expand All @@ -69,7 +70,7 @@ def valid_address(sender: address) -> bool:
@internal
def foo() -> bool:
raw_revert(b"vyper")
return True
return True # unreachable
""",
StructureException,
),
Expand All @@ -78,7 +79,7 @@ def foo() -> bool:
@internal
def foo() -> bool:
raw_revert(b"vyper")
x: uint256 = 3
x: uint256 = 3 # unreachable
""",
StructureException,
),
Expand All @@ -88,12 +89,35 @@ def foo() -> bool:
def foo(x: uint256) -> bool:
if x == 2:
raw_revert(b"vyper")
a: uint256 = 3
a: uint256 = 3 # unreachable
else:
return False
""",
StructureException,
),
(
"""
@internal
def foo():
return
return # unreachable
""",
StructureException,
),
(
"""
@internal
def foo() -> uint256:
if block.number % 2 == 0:
return 5
elif block.number % 3 == 0:
return 6
else:
return 10
return 0 # unreachable
""",
StructureException,
),
]


Expand Down Expand Up @@ -154,7 +178,6 @@ def test() -> int128:
else:
x = keccak256(x)
return 1
return 1
""",
"""
@external
Expand Down
38 changes: 34 additions & 4 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,6 @@ class VyperNode:
Field names that, if present, must be set to None or a `SyntaxException`
is raised. This attribute is used to exclude syntax that is valid in Python
but not in Vyper.
_is_terminus : bool, optional
If `True`, indicates that execution halts upon reaching this node.
_translated_fields : Dict, optional
Field names that are reassigned if encountered. Used to normalize fields
across different Python versions.
Expand Down Expand Up @@ -389,6 +387,13 @@ def is_literal_value(self):
"""
return False

@property
def is_terminus(self):
"""
Check if execution halts upon reaching this node.
"""
return False

@property
def has_folded_value(self):
"""
Expand Down Expand Up @@ -711,12 +716,19 @@ class Stmt(VyperNode):

class Return(Stmt):
__slots__ = ("value",)
_is_terminus = True

@property
def is_terminus(self):
return True


class Expr(Stmt):
__slots__ = ("value",)

@property
def is_terminus(self):
return self.value.is_terminus


class Log(Stmt):
__slots__ = ("value",)
Expand Down Expand Up @@ -1187,6 +1199,21 @@ def _op(self, left, right):
class Call(ExprNode):
__slots__ = ("func", "args", "keywords")

@property
def is_terminus(self):
# cursed import cycle!
from vyper.builtins.functions import get_builtin_functions

if not isinstance(self.func, Name):
return False

funcname = self.func.id
builtin_t = get_builtin_functions().get(funcname)
if builtin_t is None:
return False

return builtin_t._is_terminus


class keyword(VyperNode):
__slots__ = ("arg", "value")
Expand Down Expand Up @@ -1322,7 +1349,10 @@ class AugAssign(Stmt):
class Raise(Stmt):
__slots__ = ("exc",)
_only_empty_fields = ("cause",)
_is_terminus = True

@property
def is_terminus(self):
return True


class Assert(Stmt):
Expand Down
1 change: 1 addition & 0 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class BuiltinFunctionT(VyperType):
_kwargs: dict[str, KwargSettings] = {}
_modifiability: Modifiability = Modifiability.MODIFIABLE
_return_type: Optional[VyperType] = None
_is_terminus = False

# helper function to deal with TYPE_DEFINITIONs
def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None:
Expand Down
40 changes: 1 addition & 39 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import contextlib
from typing import Generator

from vyper import ast as vy_ast
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.compiler.settings import OptimizationLevel
from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch
from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch
from vyper.semantics.types import (
AddressT,
BoolT,
Expand Down Expand Up @@ -1035,43 +1034,6 @@ def eval_seq(ir_node):
return None


def is_return_from_function(node):
if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in (
"raw_revert",
"selfdestruct",
):
return True
if isinstance(node, (vy_ast.Return, vy_ast.Raise)):
return True
return False


# TODO this is almost certainly duplicated with check_terminus_node
# in vyper/semantics/analysis/local.py
def check_single_exit(fn_node):
_check_return_body(fn_node, fn_node.body)
for node in fn_node.get_descendants(vy_ast.If):
_check_return_body(node, node.body)
if node.orelse:
_check_return_body(node, node.orelse)


def _check_return_body(node, node_list):
return_count = len([n for n in node_list if is_return_from_function(n)])
if return_count > 1:
raise StructureException(
"Too too many exit statements (return, raise or selfdestruct).", node
)
# Check for invalid code after returns.
last_node_pos = len(node_list) - 1
for idx, n in enumerate(node_list):
if is_return_from_function(n) and idx < last_node_pos:
# is not last statement in body.
raise StructureException(
"Exit statement with succeeding code (that will not execute).", node_list[idx + 1]
)


def mzero(dst, nbytes):
# calldatacopy from past-the-end gives zero bytes.
# cf. YP H.2 (ops section) with CALLDATACOPY spec.
Expand Down
5 changes: 0 additions & 5 deletions vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import vyper.ast as vy_ast
from vyper.codegen.context import Constancy, Context
from vyper.codegen.core import check_single_exit
from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function
from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function
from vyper.codegen.ir_node import IRnode
Expand Down Expand Up @@ -115,10 +114,6 @@ def generate_ir_for_function(
# generate _FuncIRInfo
func_t._ir_info = _FuncIRInfo(func_t)

# Validate return statements.
# XXX: This should really be in semantics pass.
check_single_exit(code)

callees = func_t.called_functions

# we start our function frame from the largest callee frame
Expand Down
3 changes: 1 addition & 2 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
get_dyn_array_count,
get_element_ptr,
getpos,
is_return_from_function,
make_byte_array_copier,
make_setter,
pop_dyn_array,
Expand Down Expand Up @@ -404,7 +403,7 @@ def parse_stmt(stmt, context):
def _is_terminated(code):
last_stmt = code[-1]

if is_return_from_function(last_stmt):
if last_stmt.is_terminus:
return True

if isinstance(last_stmt, vy_ast.If):
Expand Down
52 changes: 28 additions & 24 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,28 @@ def validate_functions(vy_module: vy_ast.Module) -> None:
err_list.raise_if_not_empty()


def _is_terminus_node(node: vy_ast.VyperNode) -> bool:
if getattr(node, "_is_terminus", None):
return True
if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call):
func = get_exact_type_from_node(node.value.func)
if getattr(func, "_is_terminus", None):
return True
return False


def check_for_terminus(node_list: list) -> bool:
if next((i for i in node_list if _is_terminus_node(i)), None):
return True
for node in [i for i in node_list if isinstance(i, vy_ast.If)][::-1]:
if not node.orelse or not check_for_terminus(node.orelse):
continue
if not check_for_terminus(node.body):
continue
return True
return False
# finds the terminus node for a list of nodes.
# raises an exception if any nodes are unreachable
def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]:
ret = None

for node in node_list:
if ret is not None:
raise StructureException("Unreachable code!", node)
if node.is_terminus:
ret = node

if isinstance(node, vy_ast.If):
body_terminates = find_terminating_node(node.body)

else_terminates = None
if node.orelse is not None:
else_terminates = find_terminating_node(node.orelse)

if body_terminates is not None and else_terminates is not None:
ret = else_terminates

return ret


def _check_iterator_modification(
Expand Down Expand Up @@ -201,11 +203,13 @@ def analyze(self):
self.visit(node)

if self.func.return_type:
if not check_for_terminus(self.fn_node.body):
if not find_terminating_node(self.fn_node.body):
raise FunctionDeclarationException(
f"Missing or unmatched return statements in function '{self.fn_node.name}'",
self.fn_node,
f"Missing return statement in function '{self.fn_node.name}'", self.fn_node
)
else:
# call find_terminator for its unreachable code detection side effect
find_terminating_node(self.fn_node.body)

# visit default args
assert self.func.n_keyword_args == len(self.fn_node.args.defaults)
Expand Down Expand Up @@ -468,7 +472,7 @@ def visit_Return(self, node):
raise FunctionDeclarationException("Return statement is missing a value", node)
return
elif self.func.return_type is None:
raise FunctionDeclarationException("Function does not return any values", node)
raise FunctionDeclarationException("Function should not return any values", node)

if isinstance(values, vy_ast.Tuple):
values = values.elements
Expand Down

0 comments on commit af5c49f

Please sign in to comment.