diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index ffd442f89a..40535e4b90 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -1194,9 +1194,6 @@ def evaluate(self, left, right) -> ExprNode: if not isinstance(left, type(right)): raise UnfoldableNode("Cannot compare different literal types") - if not isinstance(self.op, (Eq, NotEq)) and not isinstance(left, (Int, Decimal)): - raise TypeMismatch(f"Invalid literal types for {self.op.description} comparison", self) - value = self.op._op(left.value, right.value) return NameConstant.from_node(self, value=value) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index b865c8461a..dac55824d1 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -727,6 +727,16 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: self.visit(node.right, rtyp) else: + # TODO: this is unreachable because this would throw when the parent node + # calls `validate_expected_type` + left = get_folded_value(node.left) + if not isinstance(node.op, (vy_ast.Eq, vy_ast.NotEq)) and not isinstance( + left, (vy_ast.Int, vy_ast.Decimal) + ): + raise TypeMismatch( + f"Invalid literal types for {node.op.description} comparison", node + ) + # ex. a < b cmp_typ = get_common_types(node.left, node.right).pop() if isinstance(cmp_typ, _BytestringT): diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index c0d1c90403..88aeca20bc 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -179,7 +179,10 @@ def visit_BoolOp(self, node): values = [get_folded_value(i) for i in node.values] if all(isinstance(v, vy_ast.NameConstant) for v in values): - node._metadata["folded_value"] = node.evaluate(values) + try: + node._metadata["folded_value"] = node.evaluate(values) + except UnfoldableNode: + pass def visit_Call(self, node): for arg in node.args: @@ -230,9 +233,11 @@ def visit_Compare(self, node): node._metadata["folded_value"] = node.evaluate(left, right) return - right = get_folded_value(node) - if isinstance(left, type(right)) and isinstance(left, (vy_ast.Int, vy_ast.Decimal)): + right = get_folded_value(node.right) + try: node._metadata["folded_value"] = node.evaluate(left, right) + except UnfoldableNode: + pass def visit_Constant(self, node): pass