Skip to content

Commit

Permalink
polish some things, fix some potential spots for bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Oct 19, 2023
1 parent acfe4de commit dbb25cb
Showing 1 changed file with 61 additions and 48 deletions.
109 changes: 61 additions & 48 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,6 @@ def _check_iterator_modification(
return None


def _validate_revert_reason(msg_node: vy_ast.VyperNode) -> None:
if msg_node:
if isinstance(msg_node, vy_ast.Str):
if not msg_node.value.strip():
raise StructureException("Reason string cannot be empty", msg_node)
elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"):
try:
validate_expected_type(msg_node, StringT(1024))
except TypeMismatch as e:
raise InvalidType("revert reason must fit within String[1024]") from e


# helpers
def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> None:
if isinstance(value_type, AddressT) and node.attr == "code":
Expand Down Expand Up @@ -182,6 +170,7 @@ def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None:


def _validate_self_reference(node: vy_ast.Name) -> None:
# CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through
if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute):
raise StateAccessViolation("not allowed to query self in pure functions", node)

Expand Down Expand Up @@ -214,6 +203,7 @@ def __init__(
f"Missing or unmatched return statements in function '{fn_node.name}'", fn_node
)

# visit default args
assert self.func.n_keyword_args == len(fn_node.args.defaults)
for kwarg in self.func.keyword_args:
self.expr_visitor.visit(kwarg.default_value, kwarg.typ)
Expand Down Expand Up @@ -242,17 +232,31 @@ def visit_AnnAssign(self, node):
self.expr_visitor.visit(node.target, typ)
self.expr_visitor.visit(node.value, typ)

def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None:
if isinstance(msg_node, vy_ast.Str):
if not msg_node.value.strip():
raise StructureException("Reason string cannot be empty", msg_node)
self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node))
elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"):
try:
validate_expected_type(msg_node, StringT(1024))
except TypeMismatch as e:
raise InvalidType("revert reason must fit within String[1024]") from e
self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node))
# CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type

def visit_Assert(self, node):
if node.msg:
_validate_revert_reason(node.msg)
self._validate_revert_reason(node.msg)

try:
validate_expected_type(node.test, BoolT())
except InvalidType:
raise InvalidType("Assertion test value must be a boolean", node.test)
self.expr_visitor.visit(node.test, BoolT())

def visit_Assign(self, node):
# repeated code for Assign and AugAssign
def _assign_helper(self, node):
if isinstance(node.value, vy_ast.Tuple):
raise StructureException("Right-hand side of assignment cannot be a tuple", node.value)

Expand All @@ -268,24 +272,19 @@ def visit_Assign(self, node):
self.expr_visitor.visit(node.value, target.typ)
self.expr_visitor.visit(node.target, target.typ)

def visit_AugAssign(self, node):
if isinstance(node.value, vy_ast.Tuple):
raise StructureException("Right-hand side of assignment cannot be a tuple", node.value)

lhs_info = get_expr_info(node.target)

validate_expected_type(node.value, lhs_info.typ)
lhs_info.validate_modification(node, self.func.mutability)
def visit_Assign(self, node):
self._assign_helper(node)

self.expr_visitor.visit(node.value, lhs_info.typ)
self.expr_visitor.visit(node.target, lhs_info.typ)
def visit_AugAssign(self, node):
self._assign_helper(node)

def visit_Break(self, node):
for_node = node.get_ancestor(vy_ast.For)
if for_node is None:
raise StructureException("`break` must be enclosed in a `for` loop", node)

def visit_Continue(self, node):
# TODO: use context/state instead of ast search
for_node = node.get_ancestor(vy_ast.For)
if for_node is None:
raise StructureException("`continue` must be enclosed in a `for` loop", node)
Expand Down Expand Up @@ -469,12 +468,12 @@ def visit_For(self, node):

for_loop_exceptions = []
iter_name = node.target.id
for typ in type_list:
for possible_target_type in type_list:
# type check the for loop body using each possible type for iterator value

with self.namespace.enter_scope():
try:
self.namespace[iter_name] = VarInfo(typ, is_constant=True)
self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True)
except VyperException as exc:
raise exc.with_annotation(node) from None

Expand All @@ -485,20 +484,22 @@ def visit_For(self, node):
except (TypeMismatch, InvalidOperation) as exc:
for_loop_exceptions.append(exc)
else:
self.expr_visitor.visit(node.target, typ)
self.expr_visitor.visit(node.target, possible_target_type)

if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
typ = get_exact_type_from_node(node.iter)
self.expr_visitor.visit(node.iter, typ)
iter_type = get_exact_type_from_node(node.iter)
# note CMC 2023-10-23: slightly redundant with how type_list is computed
validate_expected_type(node.target, iter_type.value_type)
self.expr_visitor.visit(node.iter, iter_type)
if isinstance(node.iter, vy_ast.List):
len_ = len(node.iter.elements)
self.expr_visitor.visit(node.iter, SArrayT(typ, len_))
self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_))
if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range":
for a in node.iter.args:
self.expr_visitor.visit(a, typ)
self.expr_visitor.visit(a, possible_target_type)
for a in node.iter.keywords:
if a.arg == "bound":
self.expr_visitor.visit(a.value, typ)
self.expr_visitor.visit(a.value, possible_target_type)

# success -- do not enter error handling section
return
Expand Down Expand Up @@ -548,7 +549,7 @@ def visit_Log(self, node):

def visit_Raise(self, node):
if node.exc:
_validate_revert_reason(node.exc)
self._validate_revert_reason(node.exc)

def visit_Return(self, node):
values = node.value
Expand Down Expand Up @@ -599,7 +600,12 @@ def visit(self, node, typ):
def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None:
_validate_msg_data_attribute(node)

if self.func.mutability is not StateMutability.PAYABLE:
# CMC 2023-10-19 TODO generalize this to mutability check on every node.
# something like,
# if self.func.mutability < expr_info.mutability:
# raise ...

if self.func.mutability != StateMutability.PAYABLE:
_validate_msg_value_access(node)

if self.func.mutability == StateMutability.PURE:
Expand All @@ -617,8 +623,8 @@ def visit_BinOp(self, node: vy_ast.BinOp, typ: VyperType) -> None:
rtyp = typ
if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)):
rtyp = get_possible_types_from_node(node.right).pop()
else:
validate_expected_type(node.right, rtyp)

validate_expected_type(node.right, rtyp)

self.visit(node.right, rtyp)

Expand Down Expand Up @@ -674,27 +680,27 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None:
if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)):
# membership in list literal - `x in [a, b, c]`
# needle: ltyp, haystack: rtyp
if isinstance(node.right, vy_ast.List):
cmp_typ = get_common_types(node.left, *node.right.elements).pop()
ltyp = cmp_typ
ltyp = get_common_types(node.left, *node.right.elements).pop()

rlen = len(node.right.elements)
rtyp = SArrayT(cmp_typ, rlen)
rtyp = SArrayT(ltyp, rlen)
validate_expected_type(node.right, rtyp)
self.visit(node.right, rtyp)
else:
cmp_typ = get_exact_type_from_node(node.right)
self.visit(node.right, cmp_typ)
if isinstance(cmp_typ, EnumT):
rtyp = get_exact_type_from_node(node.right)
if isinstance(rtyp, EnumT):
# enum membership - `some_enum in other_enum`
ltyp = cmp_typ
ltyp = rtyp
else:
# array membership - `x in my_list_variable`
assert isinstance(cmp_typ, (SArrayT, DArrayT))
ltyp = cmp_typ.value_type
assert isinstance(rtyp, (SArrayT, DArrayT))
ltyp = rtyp.value_type

validate_expected_type(node.left, ltyp)

self.visit(node.left, ltyp)
self.visit(node.right, rtyp)

else:
# ex. a < b
Expand Down Expand Up @@ -755,9 +761,16 @@ def visit_Subscript(self, node: vy_ast.Subscript, typ: VyperType) -> None:
base_type = get_exact_type_from_node(node.value)

# get the correct type for the index, it might
# not be base_type.key_type
# not be exactly base_type.key_type
index_types = get_possible_types_from_node(node.slice.value)
index_type = index_types.pop()
for possible_index_type in index_types:
if base_type.key_type.compare_type(possible_index_type):
index_type = possible_index_type
break
else:
raise TypeCheckFailure(
f"Expected {base_type.key_type} but it is not a possible type", node
)

self.visit(node.slice, index_type)
self.visit(node.value, base_type)
Expand Down

0 comments on commit dbb25cb

Please sign in to comment.