diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 9ad556b470..fffd3ca7cd 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1546,7 +1546,7 @@ class IfExp(ExprNode): class For(Stmt): - __slots__ = ("iter", "iter_type", "target", "body") + __slots__ = ("target", "iter", "body") _only_empty_fields = ("orelse",) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 3bbc24c073..1e869dfb87 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -54,13 +54,9 @@ def parse_to_ast_with_settings( """ if "\x00" in source_code: raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code) + settings, class_types, for_loop_annotations, reformatted_code = pre_parse(source_code) try: py_ast = python_ast.parse(reformatted_code) - - for k, v in loop_var_annotations.items(): - parsed_v = python_ast.parse(v["source_code"]) - loop_var_annotations[k]["parsed_ast"] = parsed_v except SyntaxError as e: # TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e @@ -76,13 +72,16 @@ def parse_to_ast_with_settings( annotate_python_ast( py_ast, source_code, - loop_var_annotations, class_types, + for_loop_annotations, source_id, module_path=module_path, resolved_path=resolved_path, ) + # postcondition: consumed all the for loop annotations + assert len(for_loop_annotations) == 0 + # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint @@ -123,8 +122,8 @@ class AnnotatingVisitor(python_ast.NodeTransformer): def __init__( self, source_code: str, - loop_var_annotations: dict[int, dict[str, Any]], - modification_offsets: Optional[ModificationOffsets], + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, tokens: asttokens.ASTTokens, source_id: int, module_path: Optional[str] = None, @@ -134,12 +133,11 @@ def __init__( self._source_id = source_id self._module_path = module_path self._resolved_path = resolved_path - self._source_code: str = source_code + self._source_code = source_code + self._modification_offsets = modification_offsets + self._for_loop_annotations = for_loop_annotations + self.counter: int = 0 - self._modification_offsets = {} - self._loop_var_annotations = loop_var_annotations - if modification_offsets is not None: - self._modification_offsets = modification_offsets def generic_visit(self, node): """ @@ -221,6 +219,45 @@ def visit_ClassDef(self, node): node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)] return node + def visit_For(self, node): + """ + Visit a For node, splicing in the loop variable annotation provided by + the pre-parser + """ + raw_annotation = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + + if not raw_annotation: + # a common case for people migrating to 0.4.0, provide a more + # specific error message than "invalid type annotation" + raise SyntaxException( + "missing type annotation\n\n" + "(hint: did you mean something like " + f"`for {node.target.id}: uint256 in ...`?)\n", + self._source_code, + node.lineno, + node.col_offset, + ) + + try: + annotation = python_ast.parse(raw_annotation, mode="eval") + except SyntaxError as e: + raise SyntaxException( + "invalid type annotation", self._source_code, node.lineno, node.col_offset + ) from e + + assert isinstance(annotation, python_ast.Expression) + annotation = annotation.body + + node.target_annotation = annotation + + old_target = node.target + new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1) + node.target = new_target + + self.generic_visit(node) + + return node + def visit_Expr(self, node): """ Convert the `Yield` node into a Vyper-specific node type. @@ -240,28 +277,6 @@ def visit_Expr(self, node): return node - def visit_For(self, node): - """ - Annotate `For` nodes with the iterator's type annotation that was extracted - during pre-parsing. - """ - iter_type_info = self._loop_var_annotations.get(node.lineno) - if not iter_type_info: - raise SyntaxException( - "For loop iterator requires type annotation", - self._source_code, - node.iter.lineno, - node.iter.col_offset, - ) - - iter_type_ast = iter_type_info["parsed_ast"] - self.generic_visit(iter_type_ast) - self.generic_visit(node) - - node.iter_type = iter_type_ast.body[0].value - - return node - def visit_Subscript(self, node): """ Maintain consistency of `Subscript.slice` across python versions. @@ -322,13 +337,10 @@ def visit_Num(self, node): # modify vyper AST type according to the format of the literal value self.generic_visit(node) - # the type annotation of a for loop iterator is removed from the source - # code during pre-parsing, and therefore the `node_source_code` attribute - # of an integer in the type annotation would not be available e.g. DynArray[uint256, 3] - value = node.node_source_code if hasattr(node, "node_source_code") else None + value = node.node_source_code # deduce non base-10 types based on prefix - if value and value.lower()[:2] == "0x": + if value.lower()[:2] == "0x": if len(value) % 2: raise SyntaxException( "Hex notation requires an even number of digits", @@ -339,7 +351,7 @@ def visit_Num(self, node): node.ast_type = "Hex" node.n = value - elif value and value.lower()[:2] == "0b": + elif value.lower()[:2] == "0b": node.ast_type = "Bytes" mod = (len(value) - 2) % 8 if mod: @@ -389,8 +401,8 @@ def visit_UnaryOp(self, node): def annotate_python_ast( parsed_ast: python_ast.AST, source_code: str, - loop_var_annotations: dict[int, dict[str, Any]], - modification_offsets: Optional[ModificationOffsets] = None, + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, source_id: int = 0, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -418,8 +430,8 @@ def annotate_python_ast( tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast)) visitor = AnnotatingVisitor( source_code, - loop_var_annotations, modification_offsets, + for_loop_annotations, tokens, source_id, module_path=module_path, diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index d4115ab2b3..10f895d9b0 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -1,5 +1,5 @@ -import io import enum +import io import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize from typing import Any @@ -44,16 +44,19 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: start, ) + class ForParserState(enum.Enum): NOT_RUNNING = enum.auto() START_SOON = enum.auto() RUNNING = enum.auto() + # a simple state machine which allows us to handle loop variable annotations # (which are rejected by the python parser due to pep-526, so we scoop up the # tokens between `:` and `in` and parse them and add them back in later). class ForParser: - def __init__(self): + def __init__(self, code): + self._code = code self.annotations = {} self._current_annotation = None @@ -74,20 +77,27 @@ def consume(self, token): # state machine: start slurping tokens if token.type == OP and token.string == ":": self._state = ForParserState.RUNNING - assert self._current_annotation is None, (self._current_for_loop, self._current_annotation) - self._current_annotation = [] - return False - if self._state != ForParserState.RUNNING: - return False + # sanity check -- this should never really happen, but if it does, + # try to raise an exception which pinpoints the source. + if self._current_annotation is not None: + raise SyntaxException( + "for loop parse error", self._code, token.start[0], token.start[1] + ) + + self._current_annotation = [] + return True # do not add ":" to tokens. # state machine: end slurping tokens if token.type == NAME and token.string == "in": self._state = ForParserState.NOT_RUNNING - self.annotations[self._current_for_loop] = self._current_annotation + self.annotations[self._current_for_loop] = self._current_annotation or [] self._current_annotation = None return False + if self._state != ForParserState.RUNNING: + return False + # slurp the token self._current_annotation.append(token) return True @@ -101,7 +111,7 @@ def consume(self, token): VYPER_EXPRESSION_TYPES = {"log"} -def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], ModificationOffsets, str]: +def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: """ Re-formats a vyper source string into a python source string and performs some validation. More specifically, @@ -127,15 +137,15 @@ def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], Modificat Compilation settings based on the directives in the source code ModificationOffsets A mapping of class names to their original class types. - dict[int, dict[str, Any]] - A mapping of line numbers of `For` nodes to the type annotation of the iterator + dict[tuple[int, int], str] + A mapping of line/column offsets of `For` nodes to the annotation of the for loop target str Reformatted python source string. """ result = [] modification_offsets: ModificationOffsets = {} settings = Settings() - for_parser = ForParser() + for_parser = ForParser(code) try: code_bytes = code.encode("utf-8") @@ -211,15 +221,7 @@ def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], Modificat for_loop_annotations = {} for k, v in for_parser.annotations.items(): - updated_v = untokenize(v) - # print("untokenized v: ", updated_v) - # updated_v = updated_v.replace("\\", "") - # updated_v = updated_v.replace("\n", "") - # import textwrap - - # print("updated v: ", textwrap.dedent(updated_v)) - for_loop_annotations[k] = updated_v + v_source = untokenize(v).replace("\\", "").strip() + for_loop_annotations[k] = v_source - # print("untokenized result: ", type(untokenize(result))) - # print("untokenized result decoded: ", untokenize(result).decode("utf-8")) return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index bc29a79734..5487421a06 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -297,11 +297,11 @@ def _parse_For_list(self): with self.context.range_scope(): iter_list = Expr(self.stmt.iter, self.context).ir_node - target_type = self.stmt.target._metadata["type"] + target_type = self.stmt.target.target._metadata["type"] assert target_type == iter_list.typ.value_type # user-supplied name for loop variable - varname = self.stmt.target.id + varname = self.stmt.target.target.id loop_var = IRnode.from_list( self.context.new_variable(varname, target_type), typ=target_type, location=MEMORY ) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 76b139b055..effc545a0c 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -347,7 +347,10 @@ def visit_For(self, node): if isinstance(node.iter, vy_ast.Subscript): raise StructureException("Cannot iterate over a nested list", node.iter) - iter_type = type_from_annotation(node.iter_type, DataLocation.MEMORY) + if not isinstance(node.target, vy_ast.AnnAssign): + raise StructureException("Invalid syntax for loop iterator", node.target) + + iter_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -410,10 +413,7 @@ def visit_For(self, node): call_node, ) - if not isinstance(node.target, vy_ast.Name): - raise StructureException("Invalid syntax for loop iterator", node.target) - - iter_name = node.target.id + iter_name = node.target.target.id with self.namespace.enter_scope(): self.namespace[iter_name] = VarInfo( iter_type, modifiability=Modifiability.RUNTIME_CONSTANT @@ -422,7 +422,7 @@ def visit_For(self, node): for stmt in node.body: self.visit(stmt) - self.expr_visitor.visit(node.target, iter_type) + self.expr_visitor.visit(node.target.target, iter_type) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): iter_type = get_exact_type_from_node(node.iter) @@ -714,7 +714,7 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType) -> list[VyperType]: +def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType): """ Check that the arguments to a range() call are valid. :param node: call to range()