Skip to content

Commit

Permalink
minor clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Dec 29, 2023
1 parent b3d9f20 commit 6d2addc
Showing 1 changed file with 40 additions and 40 deletions.
80 changes: 40 additions & 40 deletions vyper/semantics/analysis/pre_typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,36 @@
from vyper.exceptions import UnfoldableNode


def get_constants(node: vy_ast.Module) -> dict:
def _prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]):
if isinstance(node, vy_ast.Name):
var_name = node.id
if var_name in constants:
node._metadata["folded_value"] = constants[var_name]
return

if isinstance(node, vy_ast.Call):
if isinstance(node.func, vy_ast.Name):
from vyper.builtins.functions import DISPATCH_TABLE

func_name = node.func.id

call_type = DISPATCH_TABLE.get(func_name)
if call_type and hasattr(call_type, "fold"):
try:
node._metadata["folded_value"] = call_type.fold(node)
return
except UnfoldableNode:
pass

# call `get_folded_value` for its side effects and allow all
# exceptions other than `UnfoldableNode` to raise
try:
node.get_folded_value()
except UnfoldableNode:
pass


def _get_constants(node: vy_ast.Module) -> dict:
constants: dict[str, vy_ast.VyperNode] = {}
module_nodes = node.body.copy()
const_var_decls = [
Expand All @@ -19,18 +48,18 @@ def get_constants(node: vy_ast.Module) -> dict:
continue

for n in c.value.get_descendants(include_self=True, reverse=True):
prefold(n, constants)
_prefold(n, constants)

try:
val = c.value.get_folded_value()

# note that if a constant is redefined, its value will be overwritten,
# but it is okay because the syntax error is handled downstream
constants[name] = val
n_processed += 1
const_var_decls.remove(c)
except UnfoldableNode:
pass
continue

# note that if a constant is redefined, its value will be overwritten,
# but it is okay because the syntax error is handled downstream
constants[name] = val
n_processed += 1
const_var_decls.remove(c)

if not n_processed:
break
Expand All @@ -39,39 +68,10 @@ def get_constants(node: vy_ast.Module) -> dict:


def pre_typecheck(node: vy_ast.Module) -> None:
constants = get_constants(node)
constants = _get_constants(node)

for n in node.get_descendants(reverse=True):
if isinstance(n, vy_ast.VariableDecl):
continue

prefold(n, constants)


def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]):
if isinstance(node, vy_ast.Name):
var_name = node.id
if var_name in constants:
node._metadata["folded_value"] = constants[var_name]
return

if isinstance(node, vy_ast.Call):
if isinstance(node.func, vy_ast.Name):
from vyper.builtins.functions import DISPATCH_TABLE

func_name = node.func.id

call_type = DISPATCH_TABLE.get(func_name)
if call_type and hasattr(call_type, "fold"):
try:
node._metadata["folded_value"] = call_type.fold(node)
return
except UnfoldableNode:
pass

# call `get_folded_value` for its side effects and allow all
# exceptions other than `UnfoldableNode` to raise
try:
node.get_folded_value()
except UnfoldableNode:
pass
_prefold(n, constants)

0 comments on commit 6d2addc

Please sign in to comment.