Skip to content

Commit

Permalink
rewrite visit_For, use AnnAssign for the target
Browse files Browse the repository at this point in the history
add some more error messages
  • Loading branch information
charles-cooper committed Jan 6, 2024
1 parent b951b47 commit 3c5c0cb
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 76 deletions.
2 changes: 1 addition & 1 deletion vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,7 +1546,7 @@ class IfExp(ExprNode):


class For(Stmt):
__slots__ = ("iter", "iter_type", "target", "body")
__slots__ = ("target", "iter", "body")
_only_empty_fields = ("orelse",)


Expand Down
100 changes: 56 additions & 44 deletions vyper/ast/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@ def parse_to_ast_with_settings(
"""
if "\x00" in source_code:
raise ParserException("No null bytes (\\x00) allowed in the source code.")
settings, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code)
settings, class_types, for_loop_annotations, reformatted_code = pre_parse(source_code)
try:
py_ast = python_ast.parse(reformatted_code)

for k, v in loop_var_annotations.items():
parsed_v = python_ast.parse(v["source_code"])
loop_var_annotations[k]["parsed_ast"] = parsed_v
except SyntaxError as e:
# TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors
raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e
Expand All @@ -76,13 +72,16 @@ def parse_to_ast_with_settings(
annotate_python_ast(
py_ast,
source_code,
loop_var_annotations,
class_types,
for_loop_annotations,
source_id,
module_path=module_path,
resolved_path=resolved_path,
)

# postcondition: consumed all the for loop annotations
assert len(for_loop_annotations) == 0

# Convert to Vyper AST.
module = vy_ast.get_node(py_ast)
assert isinstance(module, vy_ast.Module) # mypy hint
Expand Down Expand Up @@ -123,8 +122,8 @@ class AnnotatingVisitor(python_ast.NodeTransformer):
def __init__(
self,
source_code: str,
loop_var_annotations: dict[int, dict[str, Any]],
modification_offsets: Optional[ModificationOffsets],
modification_offsets: ModificationOffsets,
for_loop_annotations: dict,
tokens: asttokens.ASTTokens,
source_id: int,
module_path: Optional[str] = None,
Expand All @@ -134,12 +133,11 @@ def __init__(
self._source_id = source_id
self._module_path = module_path
self._resolved_path = resolved_path
self._source_code: str = source_code
self._source_code = source_code
self._modification_offsets = modification_offsets
self._for_loop_annotations = for_loop_annotations

self.counter: int = 0
self._modification_offsets = {}
self._loop_var_annotations = loop_var_annotations
if modification_offsets is not None:
self._modification_offsets = modification_offsets

def generic_visit(self, node):
"""
Expand Down Expand Up @@ -221,6 +219,45 @@ def visit_ClassDef(self, node):
node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)]
return node

def visit_For(self, node):
    """
    Visit a For node, splicing in the loop variable annotation provided by
    the pre-parser.

    The annotation source is looked up (and removed) from the pre-parser's
    mapping under the key `(node.lineno, node.col_offset)`, parsed on its
    own as a python expression, and attached to the loop target by wrapping
    the target in an `AnnAssign` node.

    Parameters
    ----------
    node : python_ast.For
        The `For` node to rewrite. It is modified in place.

    Returns
    -------
    python_ast.For
        The same node, with `target` replaced by an `AnnAssign`.

    Raises
    ------
    KeyError
        If the pre-parser recorded no entry for this node's position.
    SyntaxException
        If the recorded annotation is empty, or is not a valid python
        expression.
    """
    raw_annotation = self._for_loop_annotations.pop((node.lineno, node.col_offset))

    if not raw_annotation:
        # a common case for people migrating to 0.4.0, provide a more
        # specific error message than "invalid type annotation"
        #
        # the hint can only name the loop variable when the target is a
        # plain name; for any other target (e.g. a tuple) fall back to a
        # placeholder so that we still raise SyntaxException instead of
        # crashing with AttributeError while formatting the message.
        target_name = getattr(node.target, "id", "x")
        raise SyntaxException(
            "missing type annotation\n\n"
            "(hint: did you mean something like "
            f"`for {target_name}: uint256 in ...`?)\n",
            self._source_code,
            node.lineno,
            node.col_offset,
        )

    try:
        # mode="eval" parses a single expression and yields an
        # `Expression` wrapper whose `body` is the expression itself.
        # NOTE(review): location attributes on the parsed nodes are
        # relative to the standalone annotation string, not to the
        # original source -- confirm whether later passes rely on them.
        annotation = python_ast.parse(raw_annotation, mode="eval")
    except SyntaxError as e:
        raise SyntaxException(
            "invalid type annotation", self._source_code, node.lineno, node.col_offset
        ) from e

    assert isinstance(annotation, python_ast.Expression)
    annotation = annotation.body

    node.target_annotation = annotation

    # replace `for <target> in ...` by the equivalent of
    # `for <target>: <annotation> in ...`
    old_target = node.target
    new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1)
    node.target = new_target

    self.generic_visit(node)

    return node

def visit_Expr(self, node):
"""
Convert the `Yield` node into a Vyper-specific node type.
Expand All @@ -240,28 +277,6 @@ def visit_Expr(self, node):

return node

def visit_For(self, node):
"""
Annotate `For` nodes with the iterator's type annotation that was extracted
during pre-parsing.
"""
iter_type_info = self._loop_var_annotations.get(node.lineno)
if not iter_type_info:
raise SyntaxException(
"For loop iterator requires type annotation",
self._source_code,
node.iter.lineno,
node.iter.col_offset,
)

iter_type_ast = iter_type_info["parsed_ast"]
self.generic_visit(iter_type_ast)
self.generic_visit(node)

node.iter_type = iter_type_ast.body[0].value

return node

def visit_Subscript(self, node):
"""
Maintain consistency of `Subscript.slice` across python versions.
Expand Down Expand Up @@ -322,13 +337,10 @@ def visit_Num(self, node):
# modify vyper AST type according to the format of the literal value
self.generic_visit(node)

# the type annotation of a for loop iterator is removed from the source
# code during pre-parsing, and therefore the `node_source_code` attribute
# of an integer in the type annotation would not be available e.g. DynArray[uint256, 3]
value = node.node_source_code if hasattr(node, "node_source_code") else None
value = node.node_source_code

# deduce non base-10 types based on prefix
if value and value.lower()[:2] == "0x":
if value.lower()[:2] == "0x":
if len(value) % 2:
raise SyntaxException(
"Hex notation requires an even number of digits",
Expand All @@ -339,7 +351,7 @@ def visit_Num(self, node):
node.ast_type = "Hex"
node.n = value

elif value and value.lower()[:2] == "0b":
elif value.lower()[:2] == "0b":
node.ast_type = "Bytes"
mod = (len(value) - 2) % 8
if mod:
Expand Down Expand Up @@ -389,8 +401,8 @@ def visit_UnaryOp(self, node):
def annotate_python_ast(
parsed_ast: python_ast.AST,
source_code: str,
loop_var_annotations: dict[int, dict[str, Any]],
modification_offsets: Optional[ModificationOffsets] = None,
modification_offsets: ModificationOffsets,
for_loop_annotations: dict,
source_id: int = 0,
module_path: Optional[str] = None,
resolved_path: Optional[str] = None,
Expand Down Expand Up @@ -418,8 +430,8 @@ def annotate_python_ast(
tokens = asttokens.ASTTokens(source_code, tree=cast(Optional[python_ast.Module], parsed_ast))
visitor = AnnotatingVisitor(
source_code,
loop_var_annotations,
modification_offsets,
for_loop_annotations,
tokens,
source_id,
module_path=module_path,
Expand Down
46 changes: 24 additions & 22 deletions vyper/ast/pre_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import io
import enum
import io
import re
from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize
from typing import Any
Expand Down Expand Up @@ -44,16 +44,19 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None:
start,
)


# States of the ForParser state machine, which collects the tokens that make
# up a for-loop target annotation (the tokens between `:` and `in`).
class ForParserState(enum.Enum):
# not collecting: tokens are not slurped into an annotation
NOT_RUNNING = enum.auto()
# NOTE(review): presumably entered once a `for` keyword has been seen, i.e.
# the `:` that opens the annotation is expected next -- the code which sets
# this state is not visible here, confirm against ForParser.consume
START_SOON = enum.auto()
# collecting: entered at `:`, left at `in`; tokens seen in between are
# appended to the current annotation
RUNNING = enum.auto()


# a simple state machine which allows us to handle loop variable annotations
# (which are rejected by the python parser due to pep-526, so we scoop up the
# tokens between `:` and `in` and parse them and add them back in later).
class ForParser:
def __init__(self):
def __init__(self, code):
self._code = code
self.annotations = {}
self._current_annotation = None

Expand All @@ -74,20 +77,27 @@ def consume(self, token):
# state machine: start slurping tokens
if token.type == OP and token.string == ":":
self._state = ForParserState.RUNNING
assert self._current_annotation is None, (self._current_for_loop, self._current_annotation)
self._current_annotation = []
return False

if self._state != ForParserState.RUNNING:
return False
# sanity check -- this should never really happen, but if it does,
# try to raise an exception which pinpoints the source.
if self._current_annotation is not None:
raise SyntaxException(
"for loop parse error", self._code, token.start[0], token.start[1]
)

self._current_annotation = []
return True # do not add ":" to tokens.

# state machine: end slurping tokens
if token.type == NAME and token.string == "in":
self._state = ForParserState.NOT_RUNNING
self.annotations[self._current_for_loop] = self._current_annotation
self.annotations[self._current_for_loop] = self._current_annotation or []
self._current_annotation = None
return False

if self._state != ForParserState.RUNNING:
return False

# slurp the token
self._current_annotation.append(token)
return True
Expand All @@ -101,7 +111,7 @@ def consume(self, token):
VYPER_EXPRESSION_TYPES = {"log"}


def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], ModificationOffsets, str]:
def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]:
"""
Re-formats a vyper source string into a python source string and performs
some validation. More specifically,
Expand All @@ -127,15 +137,15 @@ def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], Modificat
Compilation settings based on the directives in the source code
ModificationOffsets
A mapping of class names to their original class types.
dict[int, dict[str, Any]]
A mapping of line numbers of `For` nodes to the type annotation of the iterator
dict[tuple[int, int], str]
A mapping of line/column offsets of `For` nodes to the annotation of the for loop target
str
Reformatted python source string.
"""
result = []
modification_offsets: ModificationOffsets = {}
settings = Settings()
for_parser = ForParser()
for_parser = ForParser(code)

try:
code_bytes = code.encode("utf-8")
Expand Down Expand Up @@ -211,15 +221,7 @@ def pre_parse(code: str) -> tuple[Settings, dict[int, dict[str, Any]], Modificat

for_loop_annotations = {}
for k, v in for_parser.annotations.items():
updated_v = untokenize(v)
# print("untokenized v: ", updated_v)
# updated_v = updated_v.replace("\\", "")
# updated_v = updated_v.replace("\n", "")
# import textwrap

# print("updated v: ", textwrap.dedent(updated_v))
for_loop_annotations[k] = updated_v
v_source = untokenize(v).replace("\\", "").strip()
for_loop_annotations[k] = v_source

# print("untokenized result: ", type(untokenize(result)))
# print("untokenized result decoded: ", untokenize(result).decode("utf-8"))
return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8")
4 changes: 2 additions & 2 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,11 @@ def _parse_For_list(self):
with self.context.range_scope():
iter_list = Expr(self.stmt.iter, self.context).ir_node

target_type = self.stmt.target._metadata["type"]
target_type = self.stmt.target.target._metadata["type"]
assert target_type == iter_list.typ.value_type

# user-supplied name for loop variable
varname = self.stmt.target.id
varname = self.stmt.target.target.id
loop_var = IRnode.from_list(
self.context.new_variable(varname, target_type), typ=target_type, location=MEMORY
)
Expand Down
14 changes: 7 additions & 7 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,10 @@ def visit_For(self, node):
if isinstance(node.iter, vy_ast.Subscript):
raise StructureException("Cannot iterate over a nested list", node.iter)

iter_type = type_from_annotation(node.iter_type, DataLocation.MEMORY)
if not isinstance(node.target, vy_ast.AnnAssign):
raise StructureException("Invalid syntax for loop iterator", node.target)

iter_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY)

if isinstance(node.iter, vy_ast.Call):
# iteration via range()
Expand Down Expand Up @@ -410,10 +413,7 @@ def visit_For(self, node):
call_node,
)

if not isinstance(node.target, vy_ast.Name):
raise StructureException("Invalid syntax for loop iterator", node.target)

iter_name = node.target.id
iter_name = node.target.target.id
with self.namespace.enter_scope():
self.namespace[iter_name] = VarInfo(
iter_type, modifiability=Modifiability.RUNTIME_CONSTANT
Expand All @@ -422,7 +422,7 @@ def visit_For(self, node):
for stmt in node.body:
self.visit(stmt)

self.expr_visitor.visit(node.target, iter_type)
self.expr_visitor.visit(node.target.target, iter_type)

if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
iter_type = get_exact_type_from_node(node.iter)
Expand Down Expand Up @@ -714,7 +714,7 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None:
self.visit(node.orelse, typ)


def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType) -> list[VyperType]:
def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType):
"""
Check that the arguments to a range() call are valid.
:param node: call to range()
Expand Down

0 comments on commit 3c5c0cb

Please sign in to comment.