Skip to content

Commit

Permalink
clean up prefold
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Oct 24, 2023
1 parent 8c49b24 commit aa7f416
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions vyper/ast/pre_typecheck.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from decimal import Decimal
from typing import Any

from vyper import ast as vy_ast
from vyper.ast import nodes as vy_ast
from vyper.exceptions import UnfoldableNode, VyperException
from vyper.semantics.namespace import get_namespace

Expand All @@ -12,19 +12,19 @@ def prefold(node: vy_ast.VyperNode) -> Any:
# constant struct members
if isinstance(val, dict):
return val[node.attr]
return None

elif isinstance(node, vy_ast.BinOp):
assert isinstance(node, vy_ast.BinOp)
left = prefold(node.left)
right = prefold(node.right)
if not (isinstance(left, type(right)) and isinstance(left, (int, Decimal))):
return None
return node.op._op(left, right)
if isinstance(left, type(right)) and isinstance(left, (int, Decimal)):
return node.op._op(left, right)

elif isinstance(node, vy_ast.BoolOp):
values = [prefold(i) for i in node.values]
if not all(isinstance(v, bool) for v in values):
return None
return node.op._op(values)
if all(isinstance(v, bool) for v in values):
return node.op._op(values)

elif isinstance(node, vy_ast.Call):
# constant structs
if len(node.args) == 1 and isinstance(node.args[0], vy_ast.Dict):
Expand All @@ -40,6 +40,7 @@ def prefold(node: vy_ast.VyperNode) -> Any:
return call_type.evaluate(node).value # type: ignore
except (UnfoldableNode, VyperException):
pass

elif isinstance(node, vy_ast.Compare):
left = prefold(node.left)

Expand All @@ -53,28 +54,29 @@ def prefold(node: vy_ast.VyperNode) -> Any:
return node.op._op(left, right)

right = prefold(node.right)
if not (isinstance(left, type(right)) and isinstance(left, (int, Decimal))):
return None
return node.op._op(left, right)
if isinstance(left, type(right)) and isinstance(left, (int, Decimal)):
return node.op._op(left, right)

elif isinstance(node, vy_ast.Constant):
return node.value

elif isinstance(node, vy_ast.Dict):
values = [prefold(v) for v in node.values]
if any(v is None for v in values):
return None
return {k.id: v for (k, v) in zip(node.keys, values)}
if not any(v is None for v in values):
return {k.id: v for (k, v) in zip(node.keys, values)}

elif isinstance(node, (vy_ast.List, vy_ast.Tuple)):
val = [prefold(e) for e in node.elements]
if None in val:
return None
return val
if None not in val:
return val

elif isinstance(node, vy_ast.Name):
ns = get_namespace()
return ns._constants.get(node.id, None)
return ns._constants.get(node.id)

elif isinstance(node, vy_ast.UnaryOp):
operand = prefold(node.operand)
if not isinstance(operand, int):
return None
return node.op._op(operand)
if isinstance(operand, int):
return node.op._op(operand)

return None

0 comments on commit aa7f416

Please sign in to comment.