From 5bc54c0c834c3a63dcdb86911aad99527dbf1954 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 7 Jan 2024 01:06:38 +0800 Subject: [PATCH] simpliy visit_For --- .../features/iteration/test_for_in_list.py | 4 +- vyper/semantics/analysis/local.py | 92 ++++++------------- 2 files changed, 28 insertions(+), 68 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 33ad59370e..22fd7ccb43 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -648,7 +648,7 @@ def foo(x: int128): """ @external def foo(): - for i: uint256 in range(-3): + for i: int128 in range(-3): pass """, StructureException, @@ -776,7 +776,7 @@ def test_for() -> int128: a = i return a """, - TypeMismatch, + InvalidType, ), ( """ diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 0f6af0c01e..76b139b055 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -6,7 +6,6 @@ ExceptionList, FunctionDeclarationException, ImmutableViolation, - InvalidOperation, InvalidType, IteratorException, NonPayableViolation, @@ -39,7 +38,6 @@ EventT, FlagT, HashMapT, - IntegerT, SArrayT, StringT, StructT, @@ -350,7 +348,6 @@ def visit_For(self, node): raise StructureException("Cannot iterate over a nested list", node.iter) iter_type = type_from_annotation(node.iter_type, DataLocation.MEMORY) - node.target._metadata["type"] = iter_type if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -358,7 +355,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - type_list = _analyse_range_call(node.iter) + _analyse_range_call(node.iter, iter_type) else: # iteration over a variable or literal list @@ -366,14 +363,10 @@ def visit_For(self, node): if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) - type_list = [ - i.value_type - for i in get_possible_types_from_node(node.iter) - if isinstance(i, (DArrayT, SArrayT)) - ] - - if not type_list: - raise InvalidType("Not an iterable type", node.iter) + if not any( + isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) + ): + raise InvalidType("Not an iterable type", node.iter) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): # check for references to the iterated value within the body of the loop @@ -420,58 +413,31 @@ def visit_For(self, node): if not isinstance(node.target, vy_ast.Name): raise StructureException("Invalid syntax for loop iterator", node.target) - for_loop_exceptions = [] iter_name = node.target.id with self.namespace.enter_scope(): self.namespace[iter_name] = VarInfo( iter_type, modifiability=Modifiability.RUNTIME_CONSTANT ) - try: - for stmt in node.body: - self.visit(stmt) - - self.expr_visitor.visit(node.target, iter_type) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target, iter_type.value_type) - self.expr_visitor.visit(node.iter, iter_type) - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, iter_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, iter_type) - - except (TypeMismatch, InvalidOperation) as exc: - for_loop_exceptions.append(exc) - else: - # success -- do not enter error handling section - return - - # failed to find a good type. bail out - if len(set(str(i) for i in for_loop_exceptions)) == 1: - # if every attempt at type checking raised the same exception - raise for_loop_exceptions[0] - - # return an aggregate TypeMismatch that shows all possible exceptions - # depending on which type is used - types_str = [str(i) for i in type_list] - given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" - raise TypeMismatch( - f"Iterator value '{iter_name}' may be cast as {given_str}, " - "but type checking fails with all possible types:", - node, - *( - (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) - for typ, exc in zip(type_list, for_loop_exceptions) - ), - ) + for stmt in node.body: + self.visit(stmt) + + self.expr_visitor.visit(node.target, iter_type) + + if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): + iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_)) + if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + for a in node.iter.args: + self.expr_visitor.visit(a, iter_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, iter_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) @@ -748,7 +714,7 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: +def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType) -> list[VyperType]: """ Check that the arguments to a range() call are valid. :param node: call to range() @@ -761,11 +727,7 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: all_args = (start, end, *kwargs.values()) for arg1 in all_args: - validate_expected_type(arg1, IntegerT.any()) - - type_list = get_common_types(*all_args) - if not type_list: - raise TypeMismatch("Iterator values are of different types", node) + validate_expected_type(arg1, iter_type) if "bound" in kwargs: bound = kwargs["bound"] @@ -785,5 +747,3 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: raise StateAccessViolation(error, arg) if end.value <= start.value: raise StructureException("End must be greater than start", end) - - return type_list