Skip to content

Commit

Permalink
simpliy visit_For
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Jan 6, 2024
1 parent b913d45 commit 5bc54c0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def foo(x: int128):
"""
@external
def foo():
for i: uint256 in range(-3):
for i: int128 in range(-3):
pass
""",
StructureException,
Expand Down Expand Up @@ -776,7 +776,7 @@ def test_for() -> int128:
a = i
return a
""",
TypeMismatch,
InvalidType,
),
(
"""
Expand Down
92 changes: 26 additions & 66 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
ExceptionList,
FunctionDeclarationException,
ImmutableViolation,
InvalidOperation,
InvalidType,
IteratorException,
NonPayableViolation,
Expand Down Expand Up @@ -39,7 +38,6 @@
EventT,
FlagT,
HashMapT,
IntegerT,
SArrayT,
StringT,
StructT,
Expand Down Expand Up @@ -350,30 +348,25 @@ def visit_For(self, node):
raise StructureException("Cannot iterate over a nested list", node.iter)

iter_type = type_from_annotation(node.iter_type, DataLocation.MEMORY)
node.target._metadata["type"] = iter_type

if isinstance(node.iter, vy_ast.Call):
# iteration via range()
if node.iter.get("func.id") != "range":
raise IteratorException(
"Cannot iterate over the result of a function call", node.iter
)
type_list = _analyse_range_call(node.iter)
_analyse_range_call(node.iter, iter_type)

else:
# iteration over a variable or literal list
iter_val = node.iter.get_folded_value() if node.iter.has_folded_value else node.iter
if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0:
raise StructureException("For loop must have at least 1 iteration", node.iter)

type_list = [
i.value_type
for i in get_possible_types_from_node(node.iter)
if isinstance(i, (DArrayT, SArrayT))
]

if not type_list:
raise InvalidType("Not an iterable type", node.iter)
if not any(
isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter)
):
raise InvalidType("Not an iterable type", node.iter)

if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
# check for references to the iterated value within the body of the loop
Expand Down Expand Up @@ -420,58 +413,31 @@ def visit_For(self, node):
if not isinstance(node.target, vy_ast.Name):
raise StructureException("Invalid syntax for loop iterator", node.target)

for_loop_exceptions = []
iter_name = node.target.id
with self.namespace.enter_scope():
self.namespace[iter_name] = VarInfo(
iter_type, modifiability=Modifiability.RUNTIME_CONSTANT
)

try:
for stmt in node.body:
self.visit(stmt)

self.expr_visitor.visit(node.target, iter_type)

if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
iter_type = get_exact_type_from_node(node.iter)
# note CMC 2023-10-23: slightly redundant with how type_list is computed
validate_expected_type(node.target, iter_type.value_type)
self.expr_visitor.visit(node.iter, iter_type)
if isinstance(node.iter, vy_ast.List):
len_ = len(node.iter.elements)
self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_))
if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range":
for a in node.iter.args:
self.expr_visitor.visit(a, iter_type)
for a in node.iter.keywords:
if a.arg == "bound":
self.expr_visitor.visit(a.value, iter_type)

except (TypeMismatch, InvalidOperation) as exc:
for_loop_exceptions.append(exc)
else:
# success -- do not enter error handling section
return

# failed to find a good type. bail out
if len(set(str(i) for i in for_loop_exceptions)) == 1:
# if every attempt at type checking raised the same exception
raise for_loop_exceptions[0]

# return an aggregate TypeMismatch that shows all possible exceptions
# depending on which type is used
types_str = [str(i) for i in type_list]
given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}"
raise TypeMismatch(
f"Iterator value '{iter_name}' may be cast as {given_str}, "
"but type checking fails with all possible types:",
node,
*(
(f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0])
for typ, exc in zip(type_list, for_loop_exceptions)
),
)
for stmt in node.body:
self.visit(stmt)

self.expr_visitor.visit(node.target, iter_type)

if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
iter_type = get_exact_type_from_node(node.iter)
# note CMC 2023-10-23: slightly redundant with how type_list is computed
validate_expected_type(node.target, iter_type.value_type)
self.expr_visitor.visit(node.iter, iter_type)
if isinstance(node.iter, vy_ast.List):
len_ = len(node.iter.elements)
self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_))
if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range":
for a in node.iter.args:
self.expr_visitor.visit(a, iter_type)
for a in node.iter.keywords:
if a.arg == "bound":
self.expr_visitor.visit(a.value, iter_type)

def visit_If(self, node):
validate_expected_type(node.test, BoolT())
Expand Down Expand Up @@ -748,7 +714,7 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None:
self.visit(node.orelse, typ)


def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]:
def _analyse_range_call(node: vy_ast.Call, iter_type: VyperType) -> list[VyperType]:
"""
Check that the arguments to a range() call are valid.
:param node: call to range()
Expand All @@ -761,11 +727,7 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]:

all_args = (start, end, *kwargs.values())
for arg1 in all_args:
validate_expected_type(arg1, IntegerT.any())

type_list = get_common_types(*all_args)
if not type_list:
raise TypeMismatch("Iterator values are of different types", node)
validate_expected_type(arg1, iter_type)

if "bound" in kwargs:
bound = kwargs["bound"]
Expand All @@ -785,5 +747,3 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]:
raise StateAccessViolation(error, arg)
if end.value <= start.value:
raise StructureException("End must be greater than start", end)

return type_list

0 comments on commit 5bc54c0

Please sign in to comment.