From 6eca837db02d4b5713e93e4c7514b74c1d2a8fb0 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 29 Dec 2023 16:35:13 +0800 Subject: [PATCH] fix range --- .../features/iteration/test_for_in_list.py | 2 +- vyper/semantics/analysis/local.py | 54 ++++++++++--------- vyper/semantics/analysis/pre_typecheck.py | 2 +- 3 files changed, 31 insertions(+), 27 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 5544b896a2..bc1a12ae9e 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -776,7 +776,7 @@ def test_for() -> int128: a = i return a """, - InvalidType, + TypeMismatch, ), ( """ diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index f6508c3032..5683883da6 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -432,26 +432,27 @@ def visit_For(self, node): with NodeMetadata.enter_typechecker_speculation(): for stmt in node.body: self.visit(stmt) + + self.expr_visitor.visit(node.target, possible_target_type) + + if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): + iter_type = get_exact_type_from_node(node.iter) + # note CMC 2023-10-23: slightly redundant with how type_list is computed + validate_expected_type(node.target, iter_type.value_type) + self.expr_visitor.visit(node.iter, iter_type) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) + if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + for a in node.iter.args: + self.expr_visitor.visit(a, possible_target_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, possible_target_type) + except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) else: - self.expr_visitor.visit(node.target, possible_target_type) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target, iter_type.value_type) - self.expr_visitor.visit(node.iter, iter_type) - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, possible_target_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, possible_target_type) - # success -- do not enter error handling section return @@ -757,8 +758,14 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: """ validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} + arg0 = node.args[0].get_folded_value if node.args[0].has_folded_value else node.args[0] start, end = ( - (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else [i for i in node.args] + (vy_ast.Int(value=0), arg0) + if len(node.args) == 1 + else ( + arg0, + node.args[1].get_folded_value() if node.args[1].has_folded_value else node.args[1], + ) ) all_args = (start, end, *kwargs.values()) @@ -780,14 +787,11 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: error = "Please remove the `bound=` kwarg when using range with constants" raise StructureException(error, bound) else: - folded_start, folded_end = [ - i.get_folded_value() if i.has_folded_value else i for i in (start, end) - ] - for original_arg, folded_arg in zip([start, end], [folded_start, folded_end]): - if not isinstance(folded_arg, vy_ast.Num): + for arg in (start, end): + if not isinstance(arg, vy_ast.Num): error = "Value must be a literal integer, unless a bound is specified" - raise StateAccessViolation(error, original_arg) - if folded_end.value <= folded_start.value: + raise StateAccessViolation(error, arg) + if end.value <= start.value: raise StructureException("End must be greater than start", end) return type_list diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index cb73271aa6..59afbfa94f 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -54,7 +54,7 @@ def _get_constants(node: vy_ast.Module) -> dict: val = c.value.get_folded_value() except UnfoldableNode: continue - + # note that if a constant is redefined, its value will be overwritten, # but it is okay because the syntax error is handled downstream constants[name] = val