Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Jul 25, 2023
1 parent bc9f993 commit 3cd97a9
Showing 1 changed file with 36 additions and 19 deletions.
55 changes: 36 additions & 19 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,19 @@ def visit_AnnAssign(self, node):

def visit_Assert(self, node):
if node.msg:
_validate_revert_reason(node.msg)
if isinstance(node.msg, vy_ast.Str):
if not node.msg.value.strip():
raise StructureException("Reason string cannot be empty", node.msg)
elif not (isinstance(node.msg, vy_ast.Name) and node.msg.id == "UNREACHABLE"):
try:
self.visit(node.msg, StringT(1024))
except TypeMismatch as e:
raise InvalidType("revert reason must fit within String[1024]") from e

try:
validate_expected_type(node.test, BoolT())
self.expr_visitor.visit(node.test, BoolT())
except InvalidType:
raise InvalidType("Assertion test value must be a boolean", node.test)
self.expr_visitor.visit(node.test, BoolT())

def visit_Assign(self, node):
if isinstance(node.value, vy_ast.Tuple):
Expand Down Expand Up @@ -307,7 +313,7 @@ def visit_AugAssign(self, node):

lhs_info = get_expr_info(node.target)

validate_expected_type(node.value, lhs_info.typ)
#validate_expected_type(node.value, lhs_info.typ)
lhs_info.validate_modification(node, self.func.mutability)

self.expr_visitor.visit(node.value, lhs_info.typ)
Expand Down Expand Up @@ -388,10 +394,10 @@ def visit_For(self, node):
raise StateAccessViolation("Value must be a literal", node)
if args[0].value <= 0:
raise StructureException("For loop must have at least 1 iteration", args[0])
validate_expected_type(args[0], IntegerT.any())
self.expr_visitor.visit(args[0])
type_list = get_possible_types_from_node(args[0])
else:
validate_expected_type(args[0], IntegerT.any())
self.expr_visitor.visit(args[0])
type_list = get_common_types(*args)
if not isinstance(args[0], vy_ast.Constant):
# range(x, x + CONSTANT)
Expand All @@ -417,7 +423,7 @@ def visit_For(self, node):
# range(CONSTANT, CONSTANT)
if not isinstance(args[1], vy_ast.Int):
raise InvalidType("Value must be a literal integer", args[1])
validate_expected_type(args[1], IntegerT.any())
self.expr_visitor.visit(args[1])
if args[0].value >= args[1].value:
raise StructureException("Second value must be > first value", args[1])

Expand Down Expand Up @@ -532,7 +538,6 @@ def visit_For(self, node):
)

def visit_If(self, node):
validate_expected_type(node.test, BoolT())
self.expr_visitor.visit(node.test, BoolT())
with self.namespace.enter_scope():
for n in node.body:
Expand Down Expand Up @@ -697,7 +702,6 @@ def visit_BinOp(self, node: vy_ast.BinOp, typ: Optional[VyperType] = None) -> No
def visit_BoolOp(self, node: vy_ast.BoolOp, typ: Optional[VyperType] = None) -> None:
assert typ == BoolT() # sanity check
for value in node.values:
validate_expected_type(value, BoolT())
self.visit(value, BoolT())

def visit_Call(self, node: vy_ast.Call, typ: Optional[VyperType] = None) -> None:
Expand Down Expand Up @@ -765,7 +769,6 @@ def visit_Compare(self, node: vy_ast.Compare, typ: Optional[VyperType] = None) -

rlen = len(node.right.elements)
rtyp = SArrayT(cmp_typ, rlen)
validate_expected_type(node.right, rtyp)
self.visit(node.right, rtyp)
else:
cmp_typ = get_exact_type_from_node(node.right)
Expand All @@ -778,7 +781,6 @@ def visit_Compare(self, node: vy_ast.Compare, typ: Optional[VyperType] = None) -
assert isinstance(cmp_typ, (SArrayT, DArrayT))
ltyp = cmp_typ.value_type

validate_expected_type(node.left, ltyp)
self.visit(node.left, ltyp)

else:
Expand All @@ -793,8 +795,6 @@ def visit_Compare(self, node: vy_ast.Compare, typ: Optional[VyperType] = None) -
rtyp = get_exact_type_from_node(node.right)
else:
ltyp = rtyp = cmp_typ
validate_expected_type(node.left, ltyp)
validate_expected_type(node.right, rtyp)

self.visit(node.left, ltyp)
self.visit(node.right, rtyp)
Expand Down Expand Up @@ -839,7 +839,6 @@ def visit_Constant(self, node: vy_ast.Constant, typ: Optional[VyperType] = None)
raise InvalidLiteral(f"Could not determine type for literal value '{node.value}'", node)

def visit_Index(self, node: vy_ast.Index, typ: Optional[VyperType] = None) -> None:
validate_expected_type(node.value, typ)
self.visit(node.value, typ)

def visit_List(self, node: vy_ast.List, typ: Optional[VyperType] = None) -> None:
Expand Down Expand Up @@ -967,17 +966,35 @@ def visit_Tuple(self, node: vy_ast.Tuple, typ: Optional[VyperType] = None) -> No

assert isinstance(typ, TupleT)
for element, subtype in zip(node.elements, typ.member_types):
validate_expected_type(element, subtype)
self.visit(element, subtype)

def visit_UnaryOp(self, node: vy_ast.UnaryOp, typ: Optional[VyperType] = None) -> None:
validate_expected_type(node.operand, typ)
self.visit(node.operand, typ)

types_list = get_possible_types_from_node(node.operand)
_validate_op(node, types_list, "validate_numeric_op")

if typ:
for t in types_list:
if typ.compare_type(t):
break
else:
raise TypeMismatch(f"{typ} is not a possible type", node)

def visit_IfExp(self, node: vy_ast.IfExp, typ: Optional[VyperType] = None) -> None:
validate_expected_type(node.test, BoolT())
self.visit(node.test, BoolT())
validate_expected_type(node.body, typ)

types_list = get_common_types(node.body, node.orelse)

if not types_list:
a = get_possible_types_from_node(node.body)[0]
b = get_possible_types_from_node(node.orelse)[0]
raise TypeMismatch(f"Dislike types: {a} and {b}", node)

for t in types_list:
if t.compare_type(typ):
break
else:
raise TypeMismatch(f"{typ} is not a possible type", node)
self.visit(node.body, typ)
validate_expected_type(node.orelse, typ)
self.visit(node.orelse, typ)

0 comments on commit 3cd97a9

Please sign in to comment.