Skip to content

Commit

Permalink
add another example
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Nov 3, 2023
1 parent 3482846 commit 064cd90
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
3 changes: 0 additions & 3 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,9 +1194,6 @@ def evaluate(self, left, right) -> ExprNode:
if not isinstance(left, type(right)):
raise UnfoldableNode("Cannot compare different literal types")

if not isinstance(self.op, (Eq, NotEq)) and not isinstance(left, (Int, Decimal)):
raise TypeMismatch(f"Invalid literal types for {self.op.description} comparison", self)

value = self.op._op(left.value, right.value)
return NameConstant.from_node(self, value=value)

Expand Down
10 changes: 10 additions & 0 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,16 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None:
self.visit(node.right, rtyp)

else:
# TODO: this is unreachable because this would throw when the parent node
# calls `validate_expected_type`
left = get_folded_value(node.left)
if not isinstance(node.op, (vy_ast.Eq, vy_ast.NotEq)) and not isinstance(
left, (vy_ast.Int, vy_ast.Decimal)
):
raise TypeMismatch(
f"Invalid literal types for {node.op.description} comparison", node
)

# ex. a < b
cmp_typ = get_common_types(node.left, node.right).pop()
if isinstance(cmp_typ, _BytestringT):
Expand Down
11 changes: 8 additions & 3 deletions vyper/semantics/analysis/pre_typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ def visit_BoolOp(self, node):

values = [get_folded_value(i) for i in node.values]
if all(isinstance(v, vy_ast.NameConstant) for v in values):
node._metadata["folded_value"] = node.evaluate(values)
try:
node._metadata["folded_value"] = node.evaluate(values)
except UnfoldableNode:
pass

def visit_Call(self, node):
for arg in node.args:
Expand Down Expand Up @@ -230,9 +233,11 @@ def visit_Compare(self, node):
node._metadata["folded_value"] = node.evaluate(left, right)
return

right = get_folded_value(node)
if isinstance(left, type(right)) and isinstance(left, (vy_ast.Int, vy_ast.Decimal)):
right = get_folded_value(node.right)
try:
node._metadata["folded_value"] = node.evaluate(left, right)
except UnfoldableNode:
pass

def visit_Constant(self, node):
pass
Expand Down

0 comments on commit 064cd90

Please sign in to comment.