diff --git a/docs/resources.rst b/docs/resources.rst index a3dfa480ed..7bb3c99df4 100644 --- a/docs/resources.rst +++ b/docs/resources.rst @@ -24,6 +24,7 @@ Frameworks and tooling - `🐍 snekmate – Vyper smart contract building blocks `_ - `Serpentor – A set of smart contracts tools for governance `_ - `Smart contract development frameworks and tools for Vyper on Ethreum.org `_ +- `Vyper Online Compiler - an online platform for compiling and deploying Vyper smart contracts `_ Security -------- diff --git a/requirements-docs.txt b/requirements-docs.txt index d33eae62af..157d7bcab5 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,3 +1,3 @@ -sphinx==4.5.0 +sphinx==5.0.0 recommonmark==0.6.0 sphinx_rtd_theme==0.5.2 diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 5c7b5c6b1b..7f5658e485 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -11,7 +11,9 @@ NamespaceCollision, StateAccessViolation, StructureException, + SyntaxException, TypeMismatch, + UnknownType, ) BASIC_FOR_LOOP_CODE = [ @@ -803,6 +805,33 @@ def test_for() -> int128: """, TypeMismatch, ), + ( + """ +@external +def foo(): + for i in [1, 2, 3]: + pass + """, + SyntaxException, + ), + ( + """ +@external +def foo(): + for i: $$$ in [1, 2, 3]: + pass + """, + SyntaxException, + ), + ( + """ +@external +def foo(): + for i: uint9 in [1, 2, 3]: + pass + """, + UnknownType, + ), ] BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE] diff --git a/tests/functional/codegen/features/test_assert.py b/tests/functional/codegen/features/test_assert.py index df379d3f16..de9dd17ef6 100644 --- a/tests/functional/codegen/features/test_assert.py +++ b/tests/functional/codegen/features/test_assert.py @@ -107,14 +107,6 @@ def test(): assert self.ret1() == 1 """, """ -@internal -def valid_address(sender: address) -> bool: - selfdestruct(sender) -@external -def test(): - assert self.valid_address(msg.sender) - """, - """ @external def test(): assert raw_call(msg.sender, b'', max_outsize=1, gas=10, value=1000*1000) == b'' diff --git a/tests/functional/codegen/features/test_conditionals.py b/tests/functional/codegen/features/test_conditionals.py index 15ccc40bdf..3b0e57eeca 100644 --- a/tests/functional/codegen/features/test_conditionals.py +++ b/tests/functional/codegen/features/test_conditionals.py @@ -7,7 +7,6 @@ def foo(i: bool) -> int128: else: assert 2 != 0 return 7 - return 11 """ c = get_contract_with_gas_estimation(conditional_return_code) diff --git a/tests/functional/codegen/modules/test_interface_imports.py b/tests/functional/codegen/modules/test_interface_imports.py new file mode 100644 index 0000000000..084ad26e6b --- /dev/null +++ b/tests/functional/codegen/modules/test_interface_imports.py @@ -0,0 +1,31 @@ +def test_import_interface_types(make_input_bundle, get_contract): + ifaces = """ +interface IFoo: + def foo() -> uint256: nonpayable + """ + + foo_impl = """ +import ifaces + +implements: ifaces.IFoo + +@external +def foo() -> uint256: + return block.number + """ + + contract = """ +import ifaces + +@external +def test_foo(s: ifaces.IFoo) -> bool: + assert s.foo() == block.number + return True + """ + + input_bundle = make_input_bundle({"ifaces.vy": ifaces}) + + foo = get_contract(foo_impl, input_bundle=input_bundle) + c = get_contract(contract, input_bundle=input_bundle) + + assert c.test_foo(foo.address) is True diff --git a/tests/functional/codegen/modules/test_module_constants.py b/tests/functional/codegen/modules/test_module_constants.py new file mode 100644 index 0000000000..aafbb69252 --- /dev/null +++ b/tests/functional/codegen/modules/test_module_constants.py @@ -0,0 +1,78 @@ +def test_module_constant(make_input_bundle, get_contract): + mod1 = """ +X: constant(uint256) = 12345 + """ + contract = """ +import mod1 + +@external +def foo() -> uint256: + return mod1.X + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo() == 12345 + + +def test_nested_module_constant(make_input_bundle, get_contract): + # test nested module constants + # test at least 3 modules deep to test the `path.reverse()` gizmo + # in ConstantFolder.visit_Attribute() + mod1 = """ +X: constant(uint256) = 12345 + """ + mod2 = """ +import mod1 +X: constant(uint256) = 54321 + """ + mod3 = """ +import mod2 +X: constant(uint256) = 98765 + """ + + contract = """ +import mod1 +import mod2 +import mod3 + +@external +def test_foo() -> bool: + assert mod1.X == 12345 + assert mod2.X == 54321 + assert mod3.X == 98765 + assert mod2.mod1.X == mod1.X + assert mod3.mod2.mod1.X == mod1.X + assert mod3.mod2.X == mod2.X + return True + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1, "mod2.vy": mod2, "mod3.vy": mod3}) + + c = get_contract(contract, input_bundle=input_bundle) + assert c.test_foo() is True + + +def test_import_constant_array(make_input_bundle, get_contract, tx_failed): + mod1 = """ +X: constant(uint256[3]) = [1,2,3] + """ + contract = """ +import mod1 + +@external +def foo(ix: uint256) -> uint256: + return mod1.X[ix] + """ + + input_bundle = make_input_bundle({"mod1.vy": mod1}) + + c = get_contract(contract, input_bundle=input_bundle) + + assert c.foo(0) == 1 + assert c.foo(1) == 2 + assert c.foo(2) == 3 + with tx_failed(): + c.foo(3) diff --git a/tests/functional/codegen/test_stateless_modules.py b/tests/functional/codegen/modules/test_stateless_functions.py similarity index 100% rename from tests/functional/codegen/test_stateless_modules.py rename to tests/functional/codegen/modules/test_stateless_functions.py diff --git a/tests/functional/syntax/exceptions/test_syntax_exception.py b/tests/functional/syntax/exceptions/test_syntax_exception.py index 9ab9b6c677..53a9550a7d 100644 --- a/tests/functional/syntax/exceptions/test_syntax_exception.py +++ b/tests/functional/syntax/exceptions/test_syntax_exception.py @@ -86,6 +86,18 @@ def f(a:uint256,/): # test posonlyargs blocked def g(): self.f() """, + """ +@external +def foo(): + for i in range(0, 10): + pass + """, + """ +@external +def foo(): + for i: $$$ in range(0, 10): + pass + """, ] diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index 66981a90de..a486d11738 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -8,6 +8,7 @@ StateAccessViolation, StructureException, TypeMismatch, + UnknownType, ) fail_list = [ @@ -235,9 +236,54 @@ def foo(): "Bound must be at least 1", "FOO", ), + ( + """ +@external +def foo(): + for i: DynArra[uint256, 3] in [1, 2, 3]: + pass + """, + UnknownType, + "No builtin or user-defined type named 'DynArra'. Did you mean 'DynArray'?", + "DynArra", + ), + ( + # test for loop target broken into multiple lines + """ +@external +def foo(): + for i: \\ + \\ + \\ + \\ + \\ + \\ + uint9 in [1,2,3]: + pass + """, + UnknownType, + "No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?", + "uint9", + ), + ( + # test an even more deranged example + """ +@external +def foo(): + for i: \\ + \\ + DynArray[\\ + uint9, 3\\ + ] in [1,2,3]: + pass + """, + UnknownType, + "No builtin or user-defined type named 'uint9'. Did you mean 'uint96', or maybe 'uint8'?", + "uint9", + ), ] -for_code_regex = re.compile(r"for .+ in (.*):") +for_code_regex = re.compile(r"for .+ in (.*):", re.DOTALL) fail_test_names = [ ( f"{i:02d}: {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index a672ed7b88..ca96adca91 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -90,7 +90,7 @@ def foo(): nonpayable """ implements: self.x """, - StructureException, + InvalidType, ), ( """ diff --git a/tests/functional/syntax/test_unbalanced_return.py b/tests/functional/syntax/test_unbalanced_return.py index d1d9732777..d5754f0053 100644 --- a/tests/functional/syntax/test_unbalanced_return.py +++ b/tests/functional/syntax/test_unbalanced_return.py @@ -8,7 +8,7 @@ """ @external def foo() -> int128: - pass + pass # missing return """, FunctionDeclarationException, ), @@ -18,6 +18,7 @@ def foo() -> int128: def foo() -> int128: if False: return 123 + # missing return """, FunctionDeclarationException, ), @@ -27,19 +28,19 @@ def foo() -> int128: def test() -> int128: if 1 == 1 : return 1 - if True: + if True: # unreachable return 0 else: assert msg.sender != msg.sender """, - FunctionDeclarationException, + StructureException, ), ( """ @internal def valid_address(sender: address) -> bool: selfdestruct(sender) - return True + return True # unreachable """, StructureException, ), @@ -48,7 +49,7 @@ def valid_address(sender: address) -> bool: @internal def valid_address(sender: address) -> bool: selfdestruct(sender) - a: address = sender + a: address = sender # unreachable """, StructureException, ), @@ -58,7 +59,7 @@ def valid_address(sender: address) -> bool: def valid_address(sender: address) -> bool: if sender == empty(address): selfdestruct(sender) - _sender: address = sender + _sender: address = sender # unreachable else: return False """, @@ -69,7 +70,7 @@ def valid_address(sender: address) -> bool: @internal def foo() -> bool: raw_revert(b"vyper") - return True + return True # unreachable """, StructureException, ), @@ -78,7 +79,7 @@ def foo() -> bool: @internal def foo() -> bool: raw_revert(b"vyper") - x: uint256 = 3 + x: uint256 = 3 # unreachable """, StructureException, ), @@ -88,12 +89,35 @@ def foo() -> bool: def foo(x: uint256) -> bool: if x == 2: raw_revert(b"vyper") - a: uint256 = 3 + a: uint256 = 3 # unreachable else: return False """, StructureException, ), + ( + """ +@internal +def foo(): + return + return # unreachable + """, + StructureException, + ), + ( + """ +@internal +def foo() -> uint256: + if block.number % 2 == 0: + return 5 + elif block.number % 3 == 0: + return 6 + else: + return 10 + return 0 # unreachable + """, + StructureException, + ), ] @@ -154,7 +178,6 @@ def test() -> int128: else: x = keccak256(x) return 1 - return 1 """, """ @external diff --git a/tests/utils.py b/tests/utils.py index b8a6b493d8..25dad818ca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,7 +2,7 @@ import os from vyper import ast as vy_ast -from vyper.semantics.analysis.pre_typecheck import pre_typecheck +from vyper.semantics.analysis.constant_folding import constant_fold @contextlib.contextmanager @@ -17,5 +17,5 @@ def working_directory(directory): def parse_and_fold(source_code): ast = vy_ast.parse_to_ast(source_code) - pre_typecheck(ast) + constant_fold(ast) return ast diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 90365c63d5..de15fb9075 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -237,8 +237,6 @@ class VyperNode: Field names that, if present, must be set to None or a `SyntaxException` is raised. This attribute is used to exclude syntax that is valid in Python but not in Vyper. - _is_terminus : bool, optional - If `True`, indicates that execution halts upon reaching this node. _translated_fields : Dict, optional Field names that are reassigned if encountered. Used to normalize fields across different Python versions. @@ -389,6 +387,13 @@ def is_literal_value(self): """ return False + @property + def is_terminus(self): + """ + Check if execution halts upon reaching this node. + """ + return False + @property def has_folded_value(self): """ @@ -396,7 +401,7 @@ def has_folded_value(self): """ return "folded_value" in self._metadata - def get_folded_value(self) -> "VyperNode": + def get_folded_value(self) -> "ExprNode": """ Attempt to get the folded value, bubbling up UnfoldableNode if the node is not foldable. @@ -711,12 +716,19 @@ class Stmt(VyperNode): class Return(Stmt): __slots__ = ("value",) - _is_terminus = True + + @property + def is_terminus(self): + return True class Expr(Stmt): __slots__ = ("value",) + @property + def is_terminus(self): + return self.value.is_terminus + class Log(Stmt): __slots__ = ("value",) @@ -1187,6 +1199,21 @@ def _op(self, left, right): class Call(ExprNode): __slots__ = ("func", "args", "keywords") + @property + def is_terminus(self): + # cursed import cycle! + from vyper.builtins.functions import get_builtin_functions + + if not isinstance(self.func, Name): + return False + + funcname = self.func.id + builtin_t = get_builtin_functions().get(funcname) + if builtin_t is None: + return False + + return builtin_t._is_terminus + class keyword(VyperNode): __slots__ = ("arg", "value") @@ -1322,7 +1349,10 @@ class AugAssign(Stmt): class Raise(Stmt): __slots__ = ("exc",) _only_empty_fields = ("cause",) - _is_terminus = True + + @property + def is_terminus(self): + return True class Assert(Stmt): @@ -1372,8 +1402,8 @@ class ImplementsDecl(Stmt): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if not isinstance(self.annotation, Name): - raise StructureException("not an identifier", self.annotation) + if not isinstance(self.annotation, (Name, Attribute)): + raise StructureException("invalid implements", self.annotation) class If(Stmt): diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 4a5bc0d001..7f8c902d45 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -30,8 +30,8 @@ class VyperNode: def has_folded_value(self): ... @classmethod def get_fields(cls: Any) -> set: ... - def get_folded_value(self) -> VyperNode: ... - def _set_folded_value(self, node: VyperNode) -> None: ... + def get_folded_value(self) -> ExprNode: ... + def _set_folded_value(self, node: ExprNode) -> None: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... def to_dict(self) -> dict: ... diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index b657cf2245..cc0a47824c 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -114,6 +114,50 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: raise CompilerPanic(f'Unknown ast_struct provided: "{type(ast_struct)}".') +def annotate_python_ast( + parsed_ast: python_ast.AST, + source_code: str, + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, + source_id: int = 0, + module_path: Optional[str] = None, + resolved_path: Optional[str] = None, +) -> python_ast.AST: + """ + Annotate and optimize a Python AST in preparation conversion to a Vyper AST. + + Parameters + ---------- + parsed_ast : AST + The AST to be annotated and optimized. + source_code : str + The originating source code of the AST. + loop_var_annotations: dict + A mapping of line numbers of `For` nodes to the tokens of the type + annotation of the iterator extracted during pre-parsing. + modification_offsets : dict + A mapping of class names to their original class types. + + Returns + ------- + The annotated and optimized AST. + """ + + tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) + visitor = AnnotatingVisitor( + source_code, + modification_offsets, + for_loop_annotations, + tokens, + source_id, + module_path=module_path, + resolved_path=resolved_path, + ) + visitor.visit(parsed_ast) + + return parsed_ast + + class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets @@ -150,7 +194,9 @@ def generic_visit(self, node): self.counter += 1 # Decorate every node with source end offsets - start = node.first_token.start if hasattr(node, "first_token") else (None, None) + start = (None, None) + if hasattr(node, "first_token"): + start = node.first_token.start end = (None, None) if hasattr(node, "last_token"): end = node.last_token.end @@ -167,6 +213,7 @@ def generic_visit(self, node): if hasattr(node, "last_token"): start_pos = node.first_token.startpos end_pos = node.last_token.endpos + if node.last_token.type == 4: # ignore trailing newline once more end_pos -= 1 @@ -224,9 +271,9 @@ def visit_For(self, node): Visit a For node, splicing in the loop variable annotation provided by the pre-parser """ - raw_annotation = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + annotation_tokens = self._for_loop_annotations.pop((node.lineno, node.col_offset)) - if not raw_annotation: + if not annotation_tokens: # a common case for people migrating to 0.4.0, provide a more # specific error message than "invalid type annotation" raise SyntaxException( @@ -238,27 +285,42 @@ def visit_For(self, node): node.col_offset, ) + # some kind of black magic. untokenize preserves the line and column + # offsets, giving us something like `\ + # \ + # \ + # uint8` + # that's not a valid python Expr because it is indented. + # but it's good because the code is indented to exactly the same + # offset as it did in the original source! + # (to best understand this, print out annotation_str and + # self._source_code and compare them side-by-side). + # + # what we do here is add in a dummy target which we will remove + # in a bit, but for now lets us keep the line/col offset, and + # *also* gives us a valid AST. it doesn't matter what the dummy + # target name is, since it gets removed in a few lines. + annotation_str = tokenize.untokenize(annotation_tokens) + annotation_str = "dummy_target:" + annotation_str + try: - annotation = python_ast.parse(raw_annotation, mode="eval") - # annotate with token and source code information. `first_token` - # and `last_token` attributes are accessed in `generic_visit`. - tokens = asttokens.ASTTokens(raw_annotation) - tokens.mark_tokens(annotation) + fake_node = python_ast.parse(annotation_str).body[0] except SyntaxError as e: raise SyntaxException( "invalid type annotation", self._source_code, node.lineno, node.col_offset ) from e - assert isinstance(annotation, python_ast.Expression) - annotation = annotation.body + # fill in with asttokens info. note we can use `self._tokens` because + # it is indented to exactly the same position where it appeared + # in the original source! + self._tokens.mark_tokens(fake_node) - old_target = node.target - new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1) - node.target = new_target + # replace the dummy target name with the real target name. + fake_node.target = node.target + # replace the For node target with the new ann_assign + node.target = fake_node - self.generic_visit(node) - - return node + return self.generic_visit(node) def visit_Expr(self, node): """ @@ -397,47 +459,3 @@ def visit_UnaryOp(self, node): return node.operand else: return node - - -def annotate_python_ast( - parsed_ast: python_ast.AST, - source_code: str, - modification_offsets: ModificationOffsets, - for_loop_annotations: dict, - source_id: int = 0, - module_path: Optional[str] = None, - resolved_path: Optional[str] = None, -) -> python_ast.AST: - """ - Annotate and optimize a Python AST in preparation conversion to a Vyper AST. - - Parameters - ---------- - parsed_ast : AST - The AST to be annotated and optimized. - source_code : str - The originating source code of the AST. - loop_var_annotations: dict, optional - A mapping of line numbers of `For` nodes to the type annotation of the iterator - extracted during pre-parsing. - modification_offsets : dict, optional - A mapping of class names to their original class types. - - Returns - ------- - The annotated and optimized AST. - """ - - tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) - visitor = AnnotatingVisitor( - source_code, - modification_offsets, - for_loop_annotations, - tokens, - source_id, - module_path=module_path, - resolved_path=resolved_path, - ) - visitor.visit(parsed_ast) - - return parsed_ast diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index c7e6f3698f..159dfc0ace 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -136,7 +136,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: Compilation settings based on the directives in the source code ModificationOffsets A mapping of class names to their original class types. - dict[tuple[int, int], str] + dict[tuple[int, int], list[TokenInfo]] A mapping of line/column offsets of `For` nodes to the annotation of the for loop target str Reformatted python source string. @@ -220,9 +220,6 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: for_loop_annotations = {} for k, v in for_parser.annotations.items(): - v_source = untokenize(v) - # untokenize adds backslashes and whitespace, strip them. - v_source = v_source.replace("\\", "").strip() - for_loop_annotations[k] = v_source + for_loop_annotations[k] = v.copy() return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index af92f86a44..d2aefb2fd4 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -85,6 +85,7 @@ class BuiltinFunctionT(VyperType): _kwargs: dict[str, KwargSettings] = {} _modifiability: Modifiability = Modifiability.MODIFIABLE _return_type: Optional[VyperType] = None + _is_terminus = False # helper function to deal with TYPE_DEFINITIONs def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index c16de3c55a..c3215f8c16 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -1,12 +1,11 @@ import contextlib from typing import Generator -from vyper import ast as vy_ast from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import OptimizationLevel from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT from vyper.evm.opcodes import version_check -from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch +from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch from vyper.semantics.types import ( AddressT, BoolT, @@ -1035,43 +1034,6 @@ def eval_seq(ir_node): return None -def is_return_from_function(node): - if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in ( - "raw_revert", - "selfdestruct", - ): - return True - if isinstance(node, (vy_ast.Return, vy_ast.Raise)): - return True - return False - - -# TODO this is almost certainly duplicated with check_terminus_node -# in vyper/semantics/analysis/local.py -def check_single_exit(fn_node): - _check_return_body(fn_node, fn_node.body) - for node in fn_node.get_descendants(vy_ast.If): - _check_return_body(node, node.body) - if node.orelse: - _check_return_body(node, node.orelse) - - -def _check_return_body(node, node_list): - return_count = len([n for n in node_list if is_return_from_function(n)]) - if return_count > 1: - raise StructureException( - "Too too many exit statements (return, raise or selfdestruct).", node - ) - # Check for invalid code after returns. - last_node_pos = len(node_list) - 1 - for idx, n in enumerate(node_list): - if is_return_from_function(n) and idx < last_node_pos: - # is not last statement in body. - raise StructureException( - "Exit statement with succeeding code (that will not execute).", node_list[idx + 1] - ) - - def mzero(dst, nbytes): # calldatacopy from past-the-end gives zero bytes. # cf. YP H.2 (ops section) with CALLDATACOPY spec. diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 454ba9c8cd..5877ff3d13 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -4,7 +4,6 @@ import vyper.ast as vy_ast from vyper.codegen.context import Constancy, Context -from vyper.codegen.core import check_single_exit from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function from vyper.codegen.ir_node import IRnode @@ -115,10 +114,6 @@ def generate_ir_for_function( # generate _FuncIRInfo func_t._ir_info = _FuncIRInfo(func_t) - # Validate return statements. - # XXX: This should really be in semantics pass. - check_single_exit(code) - callees = func_t.called_functions # we start our function frame from the largest callee frame diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index a47faefeb1..7d4938f287 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -15,7 +15,6 @@ get_dyn_array_count, get_element_ptr, getpos, - is_return_from_function, make_byte_array_copier, make_setter, pop_dyn_array, @@ -404,7 +403,7 @@ def parse_stmt(stmt, context): def _is_terminated(code): last_stmt = code[-1] - if is_return_from_function(last_stmt): + if last_stmt.is_terminus: return True if isinstance(last_stmt, vy_ast.If): diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/constant_folding.py similarity index 89% rename from vyper/semantics/analysis/pre_typecheck.py rename to vyper/semantics/analysis/constant_folding.py index 1c2a5392c3..b165a6dae9 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/constant_folding.py @@ -1,11 +1,11 @@ from vyper import ast as vy_ast -from vyper.exceptions import InvalidLiteral, UnfoldableNode +from vyper.exceptions import InvalidLiteral, UnfoldableNode, VyperException from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.namespace import get_namespace -def pre_typecheck(module_ast: vy_ast.Module): +def constant_fold(module_ast: vy_ast.Module): ConstantFolder(module_ast).run() @@ -89,6 +89,33 @@ def visit_Name(self, node) -> vy_ast.ExprNode: except KeyError: raise UnfoldableNode("unknown name", node) + def visit_Attribute(self, node) -> vy_ast.ExprNode: + namespace = get_namespace() + path = [] + value = node.value + while isinstance(value, vy_ast.Attribute): + path.append(value.attr) + value = value.value + + path.reverse() + + if not isinstance(value, vy_ast.Name): + raise UnfoldableNode("not a module", value) + + # not super type-safe but we don't care. just catch AttributeErrors + # and move on + try: + module_t = namespace[value.id].module_t + + for module_name in path: + module_t = module_t.members[module_name].module_t + + varinfo = module_t.get_member(node.attr, node) + + return varinfo.decl_node.value.get_folded_value() + except (VyperException, AttributeError): + raise UnfoldableNode("not a module") + def visit_UnaryOp(self, node): operand = node.operand.get_folded_value() diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index cc8ddaf98d..c4af5b1e3a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -66,26 +66,28 @@ def validate_functions(vy_module: vy_ast.Module) -> None: err_list.raise_if_not_empty() -def _is_terminus_node(node: vy_ast.VyperNode) -> bool: - if getattr(node, "_is_terminus", None): - return True - if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call): - func = get_exact_type_from_node(node.value.func) - if getattr(func, "_is_terminus", None): - return True - return False - - -def check_for_terminus(node_list: list) -> bool: - if next((i for i in node_list if _is_terminus_node(i)), None): - return True - for node in [i for i in node_list if isinstance(i, vy_ast.If)][::-1]: - if not node.orelse or not check_for_terminus(node.orelse): - continue - if not check_for_terminus(node.body): - continue - return True - return False +# finds the terminus node for a list of nodes. +# raises an exception if any nodes are unreachable +def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]: + ret = None + + for node in node_list: + if ret is not None: + raise StructureException("Unreachable code!", node) + if node.is_terminus: + ret = node + + if isinstance(node, vy_ast.If): + body_terminates = find_terminating_node(node.body) + + else_terminates = None + if node.orelse is not None: + else_terminates = find_terminating_node(node.orelse) + + if body_terminates is not None and else_terminates is not None: + ret = else_terminates + + return ret def _check_iterator_modification( @@ -201,11 +203,13 @@ def analyze(self): self.visit(node) if self.func.return_type: - if not check_for_terminus(self.fn_node.body): + if not find_terminating_node(self.fn_node.body): raise FunctionDeclarationException( - f"Missing or unmatched return statements in function '{self.fn_node.name}'", - self.fn_node, + f"Missing return statement in function '{self.fn_node.name}'", self.fn_node ) + else: + # call find_terminator for its unreachable code detection side effect + find_terminating_node(self.fn_node.body) # visit default args assert self.func.n_keyword_args == len(self.fn_node.args.defaults) @@ -468,7 +472,7 @@ def visit_Return(self, node): raise FunctionDeclarationException("Return statement is missing a value", node) return elif self.func.return_type is None: - raise FunctionDeclarationException("Function does not return any values", node) + raise FunctionDeclarationException("Function should not return any values", node) if isinstance(values, vy_ast.Tuple): values = values.elements diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 4a7e33e848..100819526b 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -23,9 +23,9 @@ ) from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions -from vyper.semantics.analysis.pre_typecheck import pre_typecheck from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace @@ -51,8 +51,6 @@ def validate_semantics_r( """ validate_literal_nodes(module_ast) - pre_typecheck(module_ast) - # validate semantics and annotate AST with type/semantics information namespace = get_namespace() @@ -140,6 +138,9 @@ def analyze(self) -> ModuleT: self.visit(node) to_visit.remove(node) + # we can resolve constants after imports are handled. + constant_fold(self.ast) + # keep trying to process all the nodes until we finish or can # no longer progress. this makes it so we don't need to # calculate a dependency tree between top-level items. @@ -383,8 +384,9 @@ def visit_ImportFrom(self, node): self._add_import(node, node.level, qualified_module_name, alias) def visit_InterfaceDef(self, node): - obj = InterfaceT.from_InterfaceDef(node) - self.namespace[node.name] = obj + interface_t = InterfaceT.from_InterfaceDef(node) + node._metadata["interface_type"] = interface_t + self.namespace[node.name] = interface_t def visit_StructDef(self, node): struct_t = StructT.from_StructDef(node) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 4063b8e162..ee1da22a87 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -43,6 +43,7 @@ def __init__(self, _id: str, functions: dict, events: dict, structs: dict) -> No self._helper = VyperType(events | structs) self._id = _id + self._helper._id = _id self.functions = functions self.events = events self.structs = structs @@ -274,6 +275,8 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": # Datatype to store all module information. class ModuleT(VyperType): + _attribute_in_annotation = True + def __init__(self, module: vy_ast.Module, name: Optional[str] = None): super().__init__() @@ -283,7 +286,10 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): # compute the interface, note this has the side effect of checking # for function collisions - self._helper = self.interface + _ = self.interface + + self._helper = VyperType() + self._helper._id = self._id for f in self.function_defs: # note: this checks for collisions @@ -296,6 +302,12 @@ def __init__(self, module: vy_ast.Module, name: Optional[str] = None): for s in self.struct_defs: # add the type of the struct so it can be used in call position self.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + self._helper.add_member(s.name, TYPE_T(s._metadata["struct_type"])) # type: ignore + + for i in self.interface_defs: + # add the type of the interface so it can be used in call position + self.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore + self._helper.add_member(i.name, TYPE_T(i._metadata["interface_type"])) # type: ignore for v in self.variable_decls: self.add_member(v.target.id, v.target._metadata["varinfo"]) @@ -329,6 +341,10 @@ def event_defs(self): def struct_defs(self): return self._module.get_children(vy_ast.StructDef) + @property + def interface_defs(self): + return self._module.get_children(vy_ast.InterfaceDef) + @property def import_stmts(self): return self._module.get_children((vy_ast.Import, vy_ast.ImportFrom)) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index eb96375404..c82eb73afc 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -127,14 +127,16 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: except UndeclaredDefinition: raise InvalidType(err_msg, node) from None - interface = module_or_interface if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo - interface = module_or_interface.module_t.interface + module_or_interface = module_or_interface.module_t - if not interface._attribute_in_annotation: + if not isinstance(module_or_interface, VyperType): raise InvalidType(err_msg, node) - type_t = interface.get_type_member(node.attr, node) + if not module_or_interface._attribute_in_annotation: + raise InvalidType(err_msg, node) + + type_t = module_or_interface.get_type_member(node.attr, node) # type: ignore assert isinstance(type_t, TYPE_T) # sanity check return type_t.typedef