From 3ba14124602b673d45b86bae7ff90a01d782acb5 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 20 Oct 2023 01:26:40 +0800 Subject: [PATCH] refactor: merge `annotation.py` and `local.py` (#3456) This commit merges two phases of analysis: typechecking (`vyper/semantics/analysis/local.py`) and annotation of ast nodes with types (`vyper/semantics/analysis/annotation.py`). This is both for consistency with how it is done for `analysis/module.py`, and also because it increases internal consistency, as some typechecking was being done in `annotation.py`. It is also easier to maintain this way, since bugfixes or modifications to typechecking need to only be done in one place, rather than two passes. Lastly, it also probably improves performance, because it collapses two passes into one (and calls `get_*_type_from_node` less often). This commit also fixes a bug with accessing the iterator when the iterator is an empty list literal. --- tests/parser/syntax/test_list.py | 3 +- vyper/ast/nodes.pyi | 5 + vyper/builtins/_signatures.py | 2 +- vyper/builtins/functions.py | 6 + vyper/semantics/analysis/annotation.py | 283 ------------ vyper/semantics/analysis/common.py | 19 +- vyper/semantics/analysis/local.py | 607 ++++++++++++++++--------- vyper/semantics/analysis/utils.py | 13 +- 8 files changed, 420 insertions(+), 518 deletions(-) delete mode 100644 vyper/semantics/analysis/annotation.py diff --git a/tests/parser/syntax/test_list.py b/tests/parser/syntax/test_list.py index 3f81b911c8..db41de5526 100644 --- a/tests/parser/syntax/test_list.py +++ b/tests/parser/syntax/test_list.py @@ -305,8 +305,9 @@ def foo(): """ @external def foo(): + x: DynArray[uint256, 3] = [1, 2, 3] for i in [[], []]: - pass + x = i """, ] diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 0d59a2fa63..47c9af8526 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -142,6 +142,7 @@ class Expr(VyperNode): class UnaryOp(ExprNode): op: VyperNode = ... + operand: VyperNode = ... class USub(VyperNode): ... class Not(VyperNode): ... @@ -165,12 +166,15 @@ class BitXor(VyperNode): ... class BoolOp(ExprNode): op: VyperNode = ... + values: list[VyperNode] = ... class And(VyperNode): ... class Or(VyperNode): ... class Compare(ExprNode): op: VyperNode = ... + left: VyperNode = ... + right: VyperNode = ... class Eq(VyperNode): ... class NotEq(VyperNode): ... @@ -179,6 +183,7 @@ class LtE(VyperNode): ... class Gt(VyperNode): ... class GtE(VyperNode): ... class In(VyperNode): ... +class NotIn(VyperNode): ... class Call(ExprNode): args: list = ... diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d39a4a085f..2802421129 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -74,7 +74,7 @@ def decorator_fn(self, node, context): return decorator_fn -class BuiltinFunction: +class BuiltinFunction(VyperType): _has_varargs = False _kwargs: Dict[str, KwargSettings] = {} diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index f07202831d..001939638b 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -475,6 +475,12 @@ def evaluate(self, node): return vy_ast.Int.from_node(node, value=length) + def infer_arg_types(self, node): + self._validate_arg_types(node) + # return a concrete type + typ = get_possible_types_from_node(node.args[0]).pop() + return [typ] + def build_IR(self, node, context): arg = Expr(node.args[0], context).ir_node if arg.value == "~calldata": diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py deleted file mode 100644 index 01ca51d7f4..0000000000 --- a/vyper/semantics/analysis/annotation.py +++ /dev/null @@ -1,283 +0,0 @@ -from vyper import ast as vy_ast -from vyper.exceptions import StructureException, TypeCheckFailure -from vyper.semantics.analysis.utils import ( - get_common_types, - get_exact_type_from_node, - get_possible_types_from_node, -) -from vyper.semantics.types import TYPE_T, BoolT, EnumT, EventT, SArrayT, StructT, is_type_t -from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT - - -class _AnnotationVisitorBase: - - """ - Annotation visitor base class. - - Annotation visitors apply metadata (such as type information) to vyper AST objects. - Immediately after type checking a statement-level node, that node is passed to - `StatementAnnotationVisitor`. Some expression nodes are then passed onward to - `ExpressionAnnotationVisitor` for additional annotation. - """ - - def visit(self, node, *args): - if isinstance(node, self.ignored_types): - return - # iterate over the MRO until we find a matching visitor function - # this lets us use a single function to broadly target several - # node types with a shared parent - for class_ in node.__class__.mro(): - ast_type = class_.__name__ - visitor_fn = getattr(self, f"visit_{ast_type}", None) - if visitor_fn: - visitor_fn(node, *args) - return - raise StructureException(f"Cannot annotate: {node.ast_type}", node) - - -class StatementAnnotationVisitor(_AnnotationVisitorBase): - ignored_types = (vy_ast.Break, vy_ast.Continue, vy_ast.Pass, vy_ast.Raise) - - def __init__(self, fn_node: vy_ast.FunctionDef, namespace: dict) -> None: - self.func = fn_node._metadata["type"] - self.namespace = namespace - self.expr_visitor = ExpressionAnnotationVisitor(self.func) - - assert self.func.n_keyword_args == len(fn_node.args.defaults) - for kwarg in self.func.keyword_args: - self.expr_visitor.visit(kwarg.default_value, kwarg.typ) - - def visit(self, node): - super().visit(node) - - def visit_AnnAssign(self, node): - type_ = get_exact_type_from_node(node.target) - self.expr_visitor.visit(node.target, type_) - self.expr_visitor.visit(node.value, type_) - - def visit_Assert(self, node): - self.expr_visitor.visit(node.test) - - def visit_Assign(self, node): - type_ = get_exact_type_from_node(node.target) - self.expr_visitor.visit(node.target, type_) - self.expr_visitor.visit(node.value, type_) - - def visit_AugAssign(self, node): - type_ = get_exact_type_from_node(node.target) - self.expr_visitor.visit(node.target, type_) - self.expr_visitor.visit(node.value, type_) - - def visit_Expr(self, node): - self.expr_visitor.visit(node.value) - - def visit_If(self, node): - self.expr_visitor.visit(node.test) - - def visit_Log(self, node): - node._metadata["type"] = self.namespace[node.value.func.id] - self.expr_visitor.visit(node.value) - - def visit_Return(self, node): - if node.value is not None: - self.expr_visitor.visit(node.value, self.func.return_type) - - def visit_For(self, node): - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - self.expr_visitor.visit(node.iter) - - iter_type = node.target._metadata["type"] - if isinstance(node.iter, vy_ast.List): - # typecheck list literal as static array - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) - - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, iter_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, iter_type) - - -class ExpressionAnnotationVisitor(_AnnotationVisitorBase): - ignored_types = () - - def __init__(self, fn_node: ContractFunctionT): - self.func = fn_node - - def visit(self, node, type_=None): - # the statement visitor sometimes passes type information about expressions - super().visit(node, type_) - - def visit_Attribute(self, node, type_): - type_ = get_exact_type_from_node(node) - node._metadata["type"] = type_ - self.visit(node.value, None) - - def visit_BinOp(self, node, type_): - if type_ is None: - type_ = get_common_types(node.left, node.right) - if len(type_) == 1: - type_ = type_.pop() - node._metadata["type"] = type_ - - self.visit(node.left, type_) - self.visit(node.right, type_) - - def visit_BoolOp(self, node, type_): - for value in node.values: - self.visit(value) - - def visit_Call(self, node, type_): - call_type = get_exact_type_from_node(node.func) - node_type = type_ or call_type.fetch_call_return(node) - node._metadata["type"] = node_type - self.visit(node.func) - - if isinstance(call_type, ContractFunctionT): - # function calls - if call_type.is_internal: - self.func.called_functions.add(call_type) - for arg, typ in zip(node.args, call_type.argument_types): - self.visit(arg, typ) - for kwarg in node.keywords: - # We should only see special kwargs - self.visit(kwarg.value, call_type.call_site_kwargs[kwarg.arg].typ) - - elif is_type_t(call_type, EventT): - # events have no kwargs - for arg, typ in zip(node.args, list(call_type.typedef.arguments.values())): - self.visit(arg, typ) - elif is_type_t(call_type, StructT): - # struct ctors - # ctors have no kwargs - for value, arg_type in zip( - node.args[0].values, list(call_type.typedef.members.values()) - ): - self.visit(value, arg_type) - elif isinstance(call_type, MemberFunctionT): - assert len(node.args) == len(call_type.arg_types) - for arg, arg_type in zip(node.args, call_type.arg_types): - self.visit(arg, arg_type) - else: - # builtin functions - arg_types = call_type.infer_arg_types(node) - for arg, arg_type in zip(node.args, arg_types): - self.visit(arg, arg_type) - kwarg_types = call_type.infer_kwarg_types(node) - for kwarg in node.keywords: - self.visit(kwarg.value, kwarg_types[kwarg.arg]) - - def visit_Compare(self, node, type_): - if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): - if isinstance(node.right, vy_ast.List): - type_ = get_common_types(node.left, *node.right.elements).pop() - self.visit(node.left, type_) - rlen = len(node.right.elements) - self.visit(node.right, SArrayT(type_, rlen)) - else: - type_ = get_exact_type_from_node(node.right) - self.visit(node.right, type_) - if isinstance(type_, EnumT): - self.visit(node.left, type_) - else: - # array membership - self.visit(node.left, type_.value_type) - else: - type_ = get_common_types(node.left, node.right).pop() - self.visit(node.left, type_) - self.visit(node.right, type_) - - def visit_Constant(self, node, type_): - if type_ is None: - possible_types = get_possible_types_from_node(node) - if len(possible_types) == 1: - type_ = possible_types.pop() - node._metadata["type"] = type_ - - def visit_Dict(self, node, type_): - node._metadata["type"] = type_ - - def visit_Index(self, node, type_): - self.visit(node.value, type_) - - def visit_List(self, node, type_): - if type_ is None: - type_ = get_possible_types_from_node(node) - # CMC 2022-04-14 this seems sus. try to only annotate - # if get_possible_types only returns 1 type - if len(type_) >= 1: - type_ = type_.pop() - node._metadata["type"] = type_ - for element in node.elements: - self.visit(element, type_.value_type) - - def visit_Name(self, node, type_): - if isinstance(type_, TYPE_T): - node._metadata["type"] = type_ - else: - node._metadata["type"] = get_exact_type_from_node(node) - - def visit_Subscript(self, node, type_): - node._metadata["type"] = type_ - - if isinstance(type_, TYPE_T): - # don't recurse; can't annotate AST children of type definition - return - - if isinstance(node.value, vy_ast.List): - possible_base_types = get_possible_types_from_node(node.value) - - if len(possible_base_types) == 1: - base_type = possible_base_types.pop() - - elif type_ is not None and len(possible_base_types) > 1: - for possible_type in possible_base_types: - if type_.compare_type(possible_type.value_type): - base_type = possible_type - break - else: - # this should have been caught in - # `get_possible_types_from_node` but wasn't. - raise TypeCheckFailure(f"Expected {type_} but it is not a possible type", node) - - else: - base_type = get_exact_type_from_node(node.value) - - # get the correct type for the index, it might - # not be base_type.key_type - index_types = get_possible_types_from_node(node.slice.value) - index_type = index_types.pop() - - self.visit(node.slice, index_type) - self.visit(node.value, base_type) - - def visit_Tuple(self, node, type_): - node._metadata["type"] = type_ - - if isinstance(type_, TYPE_T): - # don't recurse; can't annotate AST children of type definition - return - - for element, subtype in zip(node.elements, type_.member_types): - self.visit(element, subtype) - - def visit_UnaryOp(self, node, type_): - if type_ is None: - type_ = get_possible_types_from_node(node.operand) - if len(type_) == 1: - type_ = type_.pop() - node._metadata["type"] = type_ - self.visit(node.operand, type_) - - def visit_IfExp(self, node, type_): - if type_ is None: - ts = get_common_types(node.body, node.orelse) - if len(ts) == 1: - type_ = ts.pop() - - node._metadata["type"] = type_ - self.visit(node.test, BoolT()) - self.visit(node.body, type_) - self.visit(node.orelse, type_) diff --git a/vyper/semantics/analysis/common.py b/vyper/semantics/analysis/common.py index 193d1892e1..507eb0a570 100644 --- a/vyper/semantics/analysis/common.py +++ b/vyper/semantics/analysis/common.py @@ -10,10 +10,17 @@ class VyperNodeVisitorBase: def visit(self, node, *args): if isinstance(node, self.ignored_types): return + + # iterate over the MRO until we find a matching visitor function + # this lets us use a single function to broadly target several + # node types with a shared parent + for class_ in node.__class__.mro(): + ast_type = class_.__name__ + visitor_fn = getattr(self, f"visit_{ast_type}", None) + if visitor_fn: + return visitor_fn(node, *args) + node_type = type(node).__name__ - visitor_fn = getattr(self, f"visit_{node_type}", None) - if visitor_fn is None: - raise StructureException( - f"Unsupported syntax for {self.scope_name} namespace: {node_type}", node - ) - visitor_fn(node, *args) + raise StructureException( + f"Unsupported syntax for {self.scope_name} namespace: {node_type}", node + ) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index b391b33953..647f01c299 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -14,11 +14,11 @@ NonPayableViolation, StateAccessViolation, StructureException, + TypeCheckFailure, TypeMismatch, VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.annotation import StatementAnnotationVisitor from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( @@ -34,9 +34,11 @@ from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS, MUTABLE_ENVIRONMENT_VARS from vyper.semantics.namespace import get_namespace from vyper.semantics.types import ( + TYPE_T, AddressT, BoolT, DArrayT, + EnumT, EventT, HashMapT, IntegerT, @@ -44,6 +46,8 @@ StringT, StructT, TupleT, + VyperType, + _BytestringT, is_type_t, ) from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability @@ -117,20 +121,8 @@ def _check_iterator_modification( return None -def _validate_revert_reason(msg_node: vy_ast.VyperNode) -> None: - if msg_node: - if isinstance(msg_node, vy_ast.Str): - if not msg_node.value.strip(): - raise StructureException("Reason string cannot be empty", msg_node) - elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): - try: - validate_expected_type(msg_node, StringT(1024)) - except TypeMismatch as e: - raise InvalidType("revert reason must fit within String[1024]") from e - - -def _validate_address_code_attribute(node: vy_ast.Attribute) -> None: - value_type = get_exact_type_from_node(node.value) +# helpers +def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> None: if isinstance(value_type, AddressT) and node.attr == "code": # Validate `slice(
.code, start, length)` where `length` is constant parent = node.get_ancestor() @@ -139,6 +131,7 @@ def _validate_address_code_attribute(node: vy_ast.Attribute) -> None: ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int) if ok_func and ok_args: return + raise StructureException( "(address).code is only allowed inside of a slice function with a constant length", node ) @@ -160,8 +153,30 @@ def _validate_msg_data_attribute(node: vy_ast.Attribute) -> None: ) +def _validate_msg_value_access(node: vy_ast.Attribute) -> None: + if isinstance(node.value, vy_ast.Name) and node.attr == "value" and node.value.id == "msg": + raise NonPayableViolation("msg.value is not allowed in non-payable functions", node) + + +def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None: + env_vars = set(CONSTANT_ENVIRONMENT_VARS.keys()) | set(MUTABLE_ENVIRONMENT_VARS.keys()) + if isinstance(node.value, vy_ast.Name) and node.value.id in env_vars: + if isinstance(typ, ContractFunctionT) and typ.mutability == StateMutability.PURE: + return + + raise StateAccessViolation( + "not allowed to query contract or environment variables in pure functions", node + ) + + +def _validate_self_reference(node: vy_ast.Name) -> None: + # CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through + if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute): + raise StateAccessViolation("not allowed to query self in pure functions", node) + + class FunctionNodeVisitor(VyperNodeVisitorBase): - ignored_types = (vy_ast.Constant, vy_ast.Pass) + ignored_types = (vy_ast.Pass,) scope_name = "function" def __init__( @@ -171,8 +186,7 @@ def __init__( self.fn_node = fn_node self.namespace = namespace self.func = fn_node._metadata["type"] - self.annotation_visitor = StatementAnnotationVisitor(fn_node, namespace) - self.expr_visitor = _LocalExpressionVisitor() + self.expr_visitor = _ExprVisitor(self.func) # allow internal function params to be mutable location, is_immutable = ( @@ -189,44 +203,13 @@ def __init__( f"Missing or unmatched return statements in function '{fn_node.name}'", fn_node ) - if self.func.mutability == StateMutability.PURE: - node_list = fn_node.get_descendants( - vy_ast.Attribute, - { - "value.id": set(CONSTANT_ENVIRONMENT_VARS.keys()).union( - set(MUTABLE_ENVIRONMENT_VARS.keys()) - ) - }, - ) - - # Add references to `self` as standalone address - self_references = fn_node.get_descendants(vy_ast.Name, {"id": "self"}) - standalone_self = [ - n for n in self_references if not isinstance(n.get_ancestor(), vy_ast.Attribute) - ] - node_list.extend(standalone_self) # type: ignore - - for node in node_list: - t = node._metadata.get("type") - if isinstance(t, ContractFunctionT) and t.mutability == StateMutability.PURE: - # allowed - continue - raise StateAccessViolation( - "not allowed to query contract or environment variables in pure functions", - node_list[0], - ) - if self.func.mutability is not StateMutability.PAYABLE: - node_list = fn_node.get_descendants( - vy_ast.Attribute, {"value.id": "msg", "attr": "value"} - ) - if node_list: - raise NonPayableViolation( - "msg.value is not allowed in non-payable functions", node_list[0] - ) + # visit default args + assert self.func.n_keyword_args == len(fn_node.args.defaults) + for kwarg in self.func.keyword_args: + self.expr_visitor.visit(kwarg.default_value, kwarg.typ) def visit(self, node): super().visit(node) - self.annotation_visitor.visit(node) def visit_AnnAssign(self, node): name = node.get("target.id") @@ -238,16 +221,42 @@ def visit_AnnAssign(self, node): "Memory variables must be declared with an initial value", node ) - type_ = type_from_annotation(node.annotation, DataLocation.MEMORY) - validate_expected_type(node.value, type_) + typ = type_from_annotation(node.annotation, DataLocation.MEMORY) + validate_expected_type(node.value, typ) try: - self.namespace[name] = VarInfo(type_, location=DataLocation.MEMORY) + self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) except VyperException as exc: raise exc.with_annotation(node) from None - self.expr_visitor.visit(node.value) - def visit_Assign(self, node): + self.expr_visitor.visit(node.target, typ) + self.expr_visitor.visit(node.value, typ) + + def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: + if isinstance(msg_node, vy_ast.Str): + if not msg_node.value.strip(): + raise StructureException("Reason string cannot be empty", msg_node) + self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) + elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): + try: + validate_expected_type(msg_node, StringT(1024)) + except TypeMismatch as e: + raise InvalidType("revert reason must fit within String[1024]") from e + self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) + # CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type + + def visit_Assert(self, node): + if node.msg: + self._validate_revert_reason(node.msg) + + try: + validate_expected_type(node.test, BoolT()) + except InvalidType: + raise InvalidType("Assertion test value must be a boolean", node.test) + self.expr_visitor.visit(node.test, BoolT()) + + # repeated code for Assign and AugAssign + def _assign_helper(self, node): if isinstance(node.value, vy_ast.Tuple): raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) @@ -260,81 +269,71 @@ def visit_Assign(self, node): validate_expected_type(node.value, target.typ) target.validate_modification(node, self.func.mutability) - self.expr_visitor.visit(node.value) - self.expr_visitor.visit(node.target) + self.expr_visitor.visit(node.value, target.typ) + self.expr_visitor.visit(node.target, target.typ) - def visit_AugAssign(self, node): - if isinstance(node.value, vy_ast.Tuple): - raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) - - lhs_info = get_expr_info(node.target) - - validate_expected_type(node.value, lhs_info.typ) - lhs_info.validate_modification(node, self.func.mutability) - - self.expr_visitor.visit(node.value) - self.expr_visitor.visit(node.target) - - def visit_Raise(self, node): - if node.exc: - _validate_revert_reason(node.exc) - self.expr_visitor.visit(node.exc) + def visit_Assign(self, node): + self._assign_helper(node) - def visit_Assert(self, node): - if node.msg: - _validate_revert_reason(node.msg) - self.expr_visitor.visit(node.msg) + def visit_AugAssign(self, node): + self._assign_helper(node) - try: - validate_expected_type(node.test, BoolT()) - except InvalidType: - raise InvalidType("Assertion test value must be a boolean", node.test) - self.expr_visitor.visit(node.test) + def visit_Break(self, node): + for_node = node.get_ancestor(vy_ast.For) + if for_node is None: + raise StructureException("`break` must be enclosed in a `for` loop", node) def visit_Continue(self, node): + # TODO: use context/state instead of ast search for_node = node.get_ancestor(vy_ast.For) if for_node is None: raise StructureException("`continue` must be enclosed in a `for` loop", node) - def visit_Break(self, node): - for_node = node.get_ancestor(vy_ast.For) - if for_node is None: - raise StructureException("`break` must be enclosed in a `for` loop", node) + def visit_Expr(self, node): + if not isinstance(node.value, vy_ast.Call): + raise StructureException("Expressions without assignment are disallowed", node) - def visit_Return(self, node): - values = node.value - if values is None: - if self.func.return_type: - raise FunctionDeclarationException("Return statement is missing a value", node) - return - elif self.func.return_type is None: - raise FunctionDeclarationException("Function does not return any values", node) + fn_type = get_exact_type_from_node(node.value.func) + if is_type_t(fn_type, EventT): + raise StructureException("To call an event you must use the `log` statement", node) - if isinstance(values, vy_ast.Tuple): - values = values.elements - if not isinstance(self.func.return_type, TupleT): - raise FunctionDeclarationException("Function only returns a single value", node) - if self.func.return_type.length != len(values): - raise FunctionDeclarationException( - f"Incorrect number of return values: " - f"expected {self.func.return_type.length}, got {len(values)}", + if is_type_t(fn_type, StructT): + raise StructureException("Struct creation without assignment is disallowed", node) + + if isinstance(fn_type, ContractFunctionT): + if ( + fn_type.mutability > StateMutability.VIEW + and self.func.mutability <= StateMutability.VIEW + ): + raise StateAccessViolation( + f"Cannot call a mutating function from a {self.func.mutability.value} function", node, ) - for given, expected in zip(values, self.func.return_type.member_types): - validate_expected_type(given, expected) - else: - validate_expected_type(values, self.func.return_type) - self.expr_visitor.visit(node.value) - def visit_If(self, node): - validate_expected_type(node.test, BoolT()) - self.expr_visitor.visit(node.test) - with self.namespace.enter_scope(): - for n in node.body: - self.visit(n) - with self.namespace.enter_scope(): - for n in node.orelse: - self.visit(n) + if ( + self.func.mutability == StateMutability.PURE + and fn_type.mutability != StateMutability.PURE + ): + raise StateAccessViolation( + "Cannot call non-pure function from a pure function", node + ) + + if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: + # it's a dotted function call like dynarray.pop() + expr_info = get_expr_info(node.value.func.value) + expr_info.validate_modification(node, self.func.mutability) + + # NOTE: fetch_call_return validates call args. + return_value = fn_type.fetch_call_return(node.value) + if ( + return_value + and not isinstance(fn_type, MemberFunctionT) + and not isinstance(fn_type, ContractFunctionT) + ): + raise StructureException( + f"Function '{fn_type._id}' cannot be called without assigning the result", node + ) + self.expr_visitor.visit(node.value, fn_type) def visit_For(self, node): if isinstance(node.iter, vy_ast.Subscript): @@ -463,19 +462,18 @@ def visit_For(self, node): f"which potentially modifies iterated storage variable '{iter_name}'", call_node, ) - self.expr_visitor.visit(node.iter) if not isinstance(node.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target) for_loop_exceptions = [] iter_name = node.target.id - for type_ in type_list: + for possible_target_type in type_list: # type check the for loop body using each possible type for iterator value with self.namespace.enter_scope(): try: - self.namespace[iter_name] = VarInfo(type_, is_constant=True) + self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) except VyperException as exc: raise exc.with_annotation(node) from None @@ -486,17 +484,27 @@ def visit_For(self, node): except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) else: - # type information is applied directly here because the - # scope is closed prior to the call to - # `StatementAnnotationVisitor` - node.target._metadata["type"] = type_ - - # success -- bail out instead of error handling. + self.expr_visitor.visit(node.target, possible_target_type) + + if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): + iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) + if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + for a in node.iter.args: + self.expr_visitor.visit(a, possible_target_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, possible_target_type) + + # success -- do not enter error handling section return - # if we have gotten here, there was an error for - # every type tried for the iterator - + # failed to find a good type. bail out if len(set(str(i) for i in for_loop_exceptions)) == 1: # if every attempt at type checking raised the same exception raise for_loop_exceptions[0] @@ -510,56 +518,20 @@ def visit_For(self, node): "but type checking fails with all possible types:", node, *( - (f"Casting '{iter_name}' as {type_}: {exc.message}", exc.annotations[0]) - for type_, exc in zip(type_list, for_loop_exceptions) + (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) + for typ, exc in zip(type_list, for_loop_exceptions) ), ) - def visit_Expr(self, node): - if not isinstance(node.value, vy_ast.Call): - raise StructureException("Expressions without assignment are disallowed", node) - - fn_type = get_exact_type_from_node(node.value.func) - if is_type_t(fn_type, EventT): - raise StructureException("To call an event you must use the `log` statement", node) - - if is_type_t(fn_type, StructT): - raise StructureException("Struct creation without assignment is disallowed", node) - - if isinstance(fn_type, ContractFunctionT): - if ( - fn_type.mutability > StateMutability.VIEW - and self.func.mutability <= StateMutability.VIEW - ): - raise StateAccessViolation( - f"Cannot call a mutating function from a {self.func.mutability.value} function", - node, - ) - - if ( - self.func.mutability == StateMutability.PURE - and fn_type.mutability != StateMutability.PURE - ): - raise StateAccessViolation( - "Cannot call non-pure function from a pure function", node - ) - - if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying: - # it's a dotted function call like dynarray.pop() - expr_info = get_expr_info(node.value.func.value) - expr_info.validate_modification(node, self.func.mutability) - - # NOTE: fetch_call_return validates call args. - return_value = fn_type.fetch_call_return(node.value) - if ( - return_value - and not isinstance(fn_type, MemberFunctionT) - and not isinstance(fn_type, ContractFunctionT) - ): - raise StructureException( - f"Function '{fn_type._id}' cannot be called without assigning the result", node - ) - self.expr_visitor.visit(node.value) + def visit_If(self, node): + validate_expected_type(node.test, BoolT()) + self.expr_visitor.visit(node.test, BoolT()) + with self.namespace.enter_scope(): + for n in node.body: + self.visit(n) + with self.namespace.enter_scope(): + for n in node.orelse: + self.visit(n) def visit_Log(self, node): if not isinstance(node.value, vy_ast.Call): @@ -572,62 +544,249 @@ def visit_Log(self, node): f"Cannot emit logs from {self.func.mutability.value.lower()} functions", node ) f.fetch_call_return(node.value) - self.expr_visitor.visit(node.value) + node._metadata["type"] = f.typedef + self.expr_visitor.visit(node.value, f.typedef) + + def visit_Raise(self, node): + if node.exc: + self._validate_revert_reason(node.exc) + def visit_Return(self, node): + values = node.value + if values is None: + if self.func.return_type: + raise FunctionDeclarationException("Return statement is missing a value", node) + return + elif self.func.return_type is None: + raise FunctionDeclarationException("Function does not return any values", node) -class _LocalExpressionVisitor(VyperNodeVisitorBase): - ignored_types = (vy_ast.Constant, vy_ast.Name) + if isinstance(values, vy_ast.Tuple): + values = values.elements + if not isinstance(self.func.return_type, TupleT): + raise FunctionDeclarationException("Function only returns a single value", node) + if self.func.return_type.length != len(values): + raise FunctionDeclarationException( + f"Incorrect number of return values: " + f"expected {self.func.return_type.length}, got {len(values)}", + node, + ) + for given, expected in zip(values, self.func.return_type.member_types): + validate_expected_type(given, expected) + else: + validate_expected_type(values, self.func.return_type) + self.expr_visitor.visit(node.value, self.func.return_type) + + +class _ExprVisitor(VyperNodeVisitorBase): scope_name = "function" - def visit_Attribute(self, node: vy_ast.Attribute) -> None: - self.visit(node.value) + def __init__(self, fn_node: ContractFunctionT): + self.func = fn_node + + def visit(self, node, typ): + # recurse and typecheck in case we are being fed the wrong type for + # some reason. note that `validate_expected_type` is unnecessary + # for nodes that already call `get_exact_type_from_node` and + # `get_possible_types_from_node` because `validate_expected_type` + # would be calling the same function again. + # CMC 2023-06-27 would be cleanest to call validate_expected_type() + # before recursing but maybe needs some refactoring before that + # can happen. + super().visit(node, typ) + + # annotate + node._metadata["type"] = typ + + def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) - _validate_address_code_attribute(node) - - def visit_BinOp(self, node: vy_ast.BinOp) -> None: - self.visit(node.left) - self.visit(node.right) - - def visit_BoolOp(self, node: vy_ast.BoolOp) -> None: - for value in node.values: # type: ignore[attr-defined] - self.visit(value) - - def visit_Call(self, node: vy_ast.Call) -> None: - self.visit(node.func) - for arg in node.args: - self.visit(arg) - for kwarg in node.keywords: - self.visit(kwarg.value) - - def visit_Compare(self, node: vy_ast.Compare) -> None: - self.visit(node.left) # type: ignore[attr-defined] - self.visit(node.right) # type: ignore[attr-defined] - - def visit_Dict(self, node: vy_ast.Dict) -> None: - for key in node.keys: - self.visit(key) + + # CMC 2023-10-19 TODO generalize this to mutability check on every node. + # something like, + # if self.func.mutability < expr_info.mutability: + # raise ... + + if self.func.mutability != StateMutability.PAYABLE: + _validate_msg_value_access(node) + + if self.func.mutability == StateMutability.PURE: + _validate_pure_access(node, typ) + + value_type = get_exact_type_from_node(node.value) + _validate_address_code(node, value_type) + + self.visit(node.value, value_type) + + def visit_BinOp(self, node: vy_ast.BinOp, typ: VyperType) -> None: + validate_expected_type(node.left, typ) + self.visit(node.left, typ) + + rtyp = typ + if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)): + rtyp = get_possible_types_from_node(node.right).pop() + + validate_expected_type(node.right, rtyp) + + self.visit(node.right, rtyp) + + def visit_BoolOp(self, node: vy_ast.BoolOp, typ: VyperType) -> None: + assert typ == BoolT() # sanity check for value in node.values: - self.visit(value) + validate_expected_type(value, BoolT()) + self.visit(value, BoolT()) + + def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: + call_type = get_exact_type_from_node(node.func) + # except for builtin functions, `get_exact_type_from_node` + # already calls `validate_expected_type` on the call args + # and kwargs via `call_type.fetch_call_return` + self.visit(node.func, call_type) + + if isinstance(call_type, ContractFunctionT): + # function calls + if call_type.is_internal: + self.func.called_functions.add(call_type) + for arg, typ in zip(node.args, call_type.argument_types): + self.visit(arg, typ) + for kwarg in node.keywords: + # We should only see special kwargs + typ = call_type.call_site_kwargs[kwarg.arg].typ + self.visit(kwarg.value, typ) + + elif is_type_t(call_type, EventT): + # events have no kwargs + expected_types = call_type.typedef.arguments.values() + for arg, typ in zip(node.args, expected_types): + self.visit(arg, typ) + elif is_type_t(call_type, StructT): + # struct ctors + # ctors have no kwargs + expected_types = call_type.typedef.members.values() + for value, arg_type in zip(node.args[0].values, expected_types): + self.visit(value, arg_type) + elif isinstance(call_type, MemberFunctionT): + assert len(node.args) == len(call_type.arg_types) + for arg, arg_type in zip(node.args, call_type.arg_types): + self.visit(arg, arg_type) + else: + # builtin functions + arg_types = call_type.infer_arg_types(node) + # `infer_arg_types` already calls `validate_expected_type` + for arg, arg_type in zip(node.args, arg_types): + self.visit(arg, arg_type) + kwarg_types = call_type.infer_kwarg_types(node) + for kwarg in node.keywords: + self.visit(kwarg.value, kwarg_types[kwarg.arg]) + + def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: + if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): + # membership in list literal - `x in [a, b, c]` + # needle: ltyp, haystack: rtyp + if isinstance(node.right, vy_ast.List): + ltyp = get_common_types(node.left, *node.right.elements).pop() + + rlen = len(node.right.elements) + rtyp = SArrayT(ltyp, rlen) + validate_expected_type(node.right, rtyp) + else: + rtyp = get_exact_type_from_node(node.right) + if isinstance(rtyp, EnumT): + # enum membership - `some_enum in other_enum` + ltyp = rtyp + else: + # array membership - `x in my_list_variable` + assert isinstance(rtyp, (SArrayT, DArrayT)) + ltyp = rtyp.value_type - def visit_Index(self, node: vy_ast.Index) -> None: - self.visit(node.value) + validate_expected_type(node.left, ltyp) - def visit_List(self, node: vy_ast.List) -> None: - for element in node.elements: - self.visit(element) + self.visit(node.left, ltyp) + self.visit(node.right, rtyp) + + else: + # ex. a < b + cmp_typ = get_common_types(node.left, node.right).pop() + if isinstance(cmp_typ, _BytestringT): + # for bytestrings, get_common_types automatically downcasts + # to the smaller common type - that will annotate with the + # wrong type, instead use get_exact_type_from_node (which + # resolves to the right type for bytestrings anyways). + ltyp = get_exact_type_from_node(node.left) + rtyp = get_exact_type_from_node(node.right) + else: + ltyp = rtyp = cmp_typ + validate_expected_type(node.left, ltyp) + validate_expected_type(node.right, rtyp) + + self.visit(node.left, ltyp) + self.visit(node.right, rtyp) + + def visit_Constant(self, node: vy_ast.Constant, typ: VyperType) -> None: + validate_expected_type(node, typ) - def visit_Subscript(self, node: vy_ast.Subscript) -> None: - self.visit(node.value) - self.visit(node.slice) + def visit_Index(self, node: vy_ast.Index, typ: VyperType) -> None: + validate_expected_type(node.value, typ) + self.visit(node.value, typ) - def visit_Tuple(self, node: vy_ast.Tuple) -> None: + def visit_List(self, node: vy_ast.List, typ: VyperType) -> None: + assert isinstance(typ, (SArrayT, DArrayT)) for element in node.elements: - self.visit(element) + validate_expected_type(element, typ.value_type) + self.visit(element, typ.value_type) - def visit_UnaryOp(self, node: vy_ast.UnaryOp) -> None: - self.visit(node.operand) # type: ignore[attr-defined] + def visit_Name(self, node: vy_ast.Name, typ: VyperType) -> None: + if self.func.mutability == StateMutability.PURE: + _validate_self_reference(node) + + if not isinstance(typ, TYPE_T): + validate_expected_type(node, typ) + + def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None: + if isinstance(typ, TYPE_T): + # don't recurse; can't annotate AST children of type definition + return + + if isinstance(node.value, vy_ast.List): + possible_base_types = get_possible_types_from_node(node.value) + + for possible_type in possible_base_types: + if typ.compare_type(possible_type.value_type): + base_type = possible_type + break + else: + # this should have been caught in + # `get_possible_types_from_node` but wasn't. + raise TypeCheckFailure(f"Expected {typ} but it is not a possible type", node) + + else: + base_type = get_exact_type_from_node(node.value) + + # get the correct type for the index, it might + # not be exactly base_type.key_type + # note: index_type is validated in types_from_Subscript + index_types = get_possible_types_from_node(node.slice.value) + index_type = index_types.pop() + + self.visit(node.slice, index_type) + self.visit(node.value, base_type) + + def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None: + if isinstance(typ, TYPE_T): + # don't recurse; can't annotate AST children of type definition + return + + assert isinstance(typ, TupleT) + for element, subtype in zip(node.elements, typ.member_types): + validate_expected_type(element, subtype) + self.visit(element, subtype) - def visit_IfExp(self, node: vy_ast.IfExp) -> None: - self.visit(node.test) - self.visit(node.body) - self.visit(node.orelse) + def visit_UnaryOp(self, node: vy_ast.UnaryOp, typ: VyperType) -> None: + validate_expected_type(node.operand, typ) + self.visit(node.operand, typ) + + def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: + validate_expected_type(node.test, BoolT()) + self.visit(node.test, BoolT()) + validate_expected_type(node.body, typ) + self.visit(node.body, typ) + validate_expected_type(node.orelse, typ) + self.visit(node.orelse, typ) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 4f911764e0..afa6b56838 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -312,10 +312,17 @@ def types_from_Constant(self, node): def types_from_List(self, node): # literal array if _is_empty_list(node): - # empty list literal `[]` ret = [] - # subtype can be anything - for t in types.PRIMITIVE_TYPES.values(): + + if len(node.elements) > 0: + # empty nested list literals `[[], []]` + subtypes = self.get_possible_types_from_node(node.elements[0]) + else: + # empty list literal `[]` + # subtype can be anything + subtypes = types.PRIMITIVE_TYPES.values() + + for t in subtypes: # 1 is minimum possible length for dynarray, # can be assigned to anything if isinstance(t, VyperType):