From aa7f416a5712ea002f8531caf3b3e098b43581e6 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 24 Oct 2023 22:23:50 +0800 Subject: [PATCH] clean up prefold --- vyper/ast/pre_typecheck.py | 44 ++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/vyper/ast/pre_typecheck.py b/vyper/ast/pre_typecheck.py index 500c30da29..ca2becdc99 100644 --- a/vyper/ast/pre_typecheck.py +++ b/vyper/ast/pre_typecheck.py @@ -1,7 +1,7 @@ from decimal import Decimal from typing import Any -from vyper import ast as vy_ast +from vyper.ast import nodes as vy_ast from vyper.exceptions import UnfoldableNode, VyperException from vyper.semantics.namespace import get_namespace @@ -12,19 +12,19 @@ def prefold(node: vy_ast.VyperNode) -> Any: # constant struct members if isinstance(val, dict): return val[node.attr] - return None + elif isinstance(node, vy_ast.BinOp): assert isinstance(node, vy_ast.BinOp) left = prefold(node.left) right = prefold(node.right) - if not (isinstance(left, type(right)) and isinstance(left, (int, Decimal))): - return None - return node.op._op(left, right) + if isinstance(left, type(right)) and isinstance(left, (int, Decimal)): + return node.op._op(left, right) + elif isinstance(node, vy_ast.BoolOp): values = [prefold(i) for i in node.values] - if not all(isinstance(v, bool) for v in values): - return None - return node.op._op(values) + if all(isinstance(v, bool) for v in values): + return node.op._op(values) + elif isinstance(node, vy_ast.Call): # constant structs if len(node.args) == 1 and isinstance(node.args[0], vy_ast.Dict): @@ -40,6 +40,7 @@ def prefold(node: vy_ast.VyperNode) -> Any: return call_type.evaluate(node).value # type: ignore except (UnfoldableNode, VyperException): pass + elif isinstance(node, vy_ast.Compare): left = prefold(node.left) @@ -53,28 +54,29 @@ def prefold(node: vy_ast.VyperNode) -> Any: return node.op._op(left, right) right = prefold(node.right) - if not (isinstance(left, type(right)) and isinstance(left, (int, Decimal))): - return None - return node.op._op(left, right) + if isinstance(left, type(right)) and isinstance(left, (int, Decimal)): + return node.op._op(left, right) + elif isinstance(node, vy_ast.Constant): return node.value + elif isinstance(node, vy_ast.Dict): values = [prefold(v) for v in node.values] - if any(v is None for v in values): - return None - return {k.id: v for (k, v) in zip(node.keys, values)} + if not any(v is None for v in values): + return {k.id: v for (k, v) in zip(node.keys, values)} + elif isinstance(node, (vy_ast.List, vy_ast.Tuple)): val = [prefold(e) for e in node.elements] - if None in val: - return None - return val + if None not in val: + return val + elif isinstance(node, vy_ast.Name): ns = get_namespace() - return ns._constants.get(node.id, None) + return ns._constants.get(node.id) + elif isinstance(node, vy_ast.UnaryOp): operand = prefold(node.operand) - if not isinstance(operand, int): - return None - return node.op._op(operand) + if isinstance(operand, int): + return node.op._op(operand) return None