Skip to content

Commit

Permalink
convert prefold to semantics pass
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Oct 31, 2023
1 parent 06f978b commit 0f8b234
Show file tree
Hide file tree
Showing 15 changed files with 329 additions and 165 deletions.
4 changes: 2 additions & 2 deletions tests/ast/test_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_replace_constant(source):
unmodified_ast = vy_ast.parse_to_ast(source)
folded_ast = vy_ast.parse_to_ast(source)

folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, True)
folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, 31337, True)

assert not vy_ast.compare_nodes(unmodified_ast, folded_ast)

Expand All @@ -223,7 +223,7 @@ def test_replace_constant_no(source):
unmodified_ast = vy_ast.parse_to_ast(source)
folded_ast = vy_ast.parse_to_ast(source)

folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, True)
folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, 31337, True)

assert vy_ast.compare_nodes(unmodified_ast, folded_ast)

Expand Down
16 changes: 13 additions & 3 deletions vyper/ast/folding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Any, Union

from vyper.ast import nodes as vy_ast
from vyper.exceptions import UnfoldableNode
Expand Down Expand Up @@ -47,6 +47,7 @@ def replace_literal_ops(vyper_module: vy_ast.Module) -> int:
except UnfoldableNode:
continue

new_node._metadata["folded_value"] = new_node.value
typ = node._metadata.get("type")

# type metadata may not be present
Expand Down Expand Up @@ -87,6 +88,7 @@ def replace_subscripts(vyper_module: vy_ast.Module) -> int:
except UnfoldableNode:
continue

new_node._metadata["folded_value"] = node._metadata["folded_value"]
new_node._metadata["type"] = node._metadata["type"]

changed_nodes += 1
Expand Down Expand Up @@ -125,6 +127,7 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int:
except UnfoldableNode:
continue

new_node._metadata["folded_value"] = new_node.value
new_node._metadata["type"] = node._metadata["type"]

changed_nodes += 1
Expand Down Expand Up @@ -156,7 +159,10 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int:
continue

type_ = node._metadata["type"]
changed_nodes += replace_constant(vyper_module, node.target.id, node.value, type_, False)
folded_value = node.value._metadata["folded_value"]
changed_nodes += replace_constant(
vyper_module, node.target.id, node.value, type_, folded_value, False
)

return changed_nodes

Expand All @@ -169,7 +175,7 @@ def _replace(old_node, new_node, type_):
new_node = new_node.from_node(old_node, value=new_node.value)
elif isinstance(new_node, vy_ast.List):
base_type = type_.value_type if type_ else None
list_values = [_replace(old_node, i, type_=base_type) for i in new_node.elements]
list_values = [_replace(old_node, i, base_type) for i in new_node.elements]
new_node = new_node.from_node(old_node, elements=list_values)
elif isinstance(new_node, vy_ast.Call):
# Replace `Name` node with `Call` node
Expand All @@ -193,6 +199,7 @@ def replace_constant(
id_: str,
replacement_node: Union[vy_ast.Constant, vy_ast.List, vy_ast.Call],
type_: VyperType,
folded_value: Any,
raise_on_error: bool,
) -> int:
"""
Expand All @@ -209,6 +216,8 @@ def replace_constant(
`Call` nodes are for struct constants.
type_ : VyperType
Type definition to be propagated to type checker.
folded_value: Any
Folded value of the constant
raise_on_error: bool
Boolean indicating if `UnfoldableNode` exception should be raised or ignored.
Expand Down Expand Up @@ -247,6 +256,7 @@ def replace_constant(
try:
# note: _replace creates a copy of the replacement_node
new_node = _replace(node, replacement_node, type_)
new_node._metadata["folded_value"] = folded_value
except UnfoldableNode:
if raise_on_error:
raise
Expand Down
4 changes: 4 additions & 0 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,10 @@ def to_dict(self) -> dict:
if "type" in self._metadata:
ast_dict["type"] = str(self._metadata["type"])

folded_value = self._metadata.get("folded_value")
if folded_value is not None:
ast_dict["folded_value"] = str(self._metadata["folded_value"])

return ast_dict

def get_ancestor(self, node_type: Union["VyperNode", tuple, None] = None) -> "VyperNode":
Expand Down
81 changes: 0 additions & 81 deletions vyper/ast/pre_typecheck.py

This file was deleted.

3 changes: 1 addition & 2 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import functools
from typing import Dict

from vyper.ast.pre_typecheck import prefold
from vyper.ast.validation import validate_call_args
from vyper.codegen.expr import Expr
from vyper.codegen.ir_node import IRnode
Expand Down Expand Up @@ -104,7 +103,7 @@ def _validate_arg_types(self, node):

for kwarg in node.keywords:
kwarg_settings = self._kwargs[kwarg.arg]
is_literal_value = prefold(kwarg.value) is not None
is_literal_value = kwarg.value._metadata.get("folded_value") is not None

if kwarg_settings.require_literal and not is_literal_value:
raise TypeMismatch("Value for kwarg must be a literal", kwarg.value)
Expand Down
44 changes: 24 additions & 20 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from vyper import ast as vy_ast
from vyper.abi_types import ABI_Tuple
from vyper.ast.pre_typecheck import prefold
from vyper.ast.validation import validate_call_args
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.context import Context, VariableRecord
Expand Down Expand Up @@ -142,7 +141,7 @@ class Floor(BuiltinFunction):

def evaluate(self, node):
validate_call_args(node, 1)
input_val = prefold(node.args[0])
input_val = node.args[0]._metadata.get("folded_value")
if not isinstance(input_val, Decimal):
raise UnfoldableNode

Expand Down Expand Up @@ -173,7 +172,7 @@ class Ceil(BuiltinFunction):

def evaluate(self, node):
validate_call_args(node, 1)
input_val = prefold(node.args[0])
input_val = node.args[0]._metadata.get("folded_value")
if not isinstance(input_val, Decimal):
raise UnfoldableNode

Expand Down Expand Up @@ -468,7 +467,7 @@ class Len(BuiltinFunction):

def evaluate(self, node):
validate_call_args(node, 1)
arg = prefold(node.args[0])
arg = node.args[0]._metadata.get("folded_value")
if isinstance(arg, (str, bytes)):
length = len(arg)
else:
Expand Down Expand Up @@ -602,7 +601,7 @@ class Keccak256(BuiltinFunction):

def evaluate(self, node):
validate_call_args(node, 1)
value = prefold(node.args[0])
value = node.args[0]._metadata.get("folded_value")
if not isinstance(value, (bytes, str)):
raise UnfoldableNode

Expand Down Expand Up @@ -653,7 +652,7 @@ class Sha256(BuiltinFunction):

def evaluate(self, node):
validate_call_args(node, 1)
value = prefold(node.args[0])
value = node.args[0]._metadata.get("folded_value")
if not isinstance(value, (bytes, str)):
raise UnfoldableNode

Expand Down Expand Up @@ -984,7 +983,7 @@ class AsWeiValue(BuiltinFunction):
}

def get_denomination(self, node):
value = prefold(node.args[1])
value = node.args[1]._metadata.get("folded_value")
if not isinstance(value, str):
raise ArgumentException(
"Wei denomination must be given as a literal string", node.args[1]
Expand All @@ -1000,7 +999,7 @@ def evaluate(self, node):
validate_call_args(node, 2)
denom = self.get_denomination(node)

value = prefold(node.args[0])
value = node.args[0]._metadata.get("folded_value")
if not isinstance(value, (Decimal, int)):
raise UnfoldableNode

Expand Down Expand Up @@ -1083,8 +1082,13 @@ def fetch_call_return(self, node):

kwargz = {i.arg: i.value for i in node.keywords}

outsize = prefold(kwargz.get("max_outsize"))
revert_on_failure = prefold(kwargz.get("revert_on_failure"))
outsize = kwargz.get("max_outsize")
if outsize is not None:
outsize = outsize._metadata.get("folded_value")
revert_on_failure = kwargz.get("revert_on_failure")
if revert_on_failure is not None:
revert_on_failure = revert_on_failure._metadata.get("folded_value")

revert_on_failure = revert_on_failure if revert_on_failure is not None else True

if outsize is None or outsize == 0:
Expand Down Expand Up @@ -1352,7 +1356,7 @@ def evaluate(self, node):
self.__class__._warned = True

validate_call_args(node, 2)
values = [prefold(i) for i in node.args]
values = [i._metadata.get("folded_value") for i in node.args]
for v, arg in zip(values, node.args):
if not isinstance(v, int):
raise UnfoldableNode
Expand All @@ -1379,7 +1383,7 @@ def evaluate(self, node):
self.__class__._warned = True

validate_call_args(node, 2)
values = [prefold(i) for i in node.args]
values = [i._metadata.get("folded_value") for i in node.args]
for v, arg in zip(values, node.args):
if not isinstance(arg, int):
raise UnfoldableNode
Expand All @@ -1406,7 +1410,7 @@ def evaluate(self, node):
self.__class__._warned = True

validate_call_args(node, 2)
values = [prefold(i) for i in node.args]
values = [i._metadata.get("folded_value") for i in node.args]
for v, arg in zip(values, node.args):
if not isinstance(arg, int):
raise UnfoldableNode
Expand All @@ -1433,7 +1437,7 @@ def evaluate(self, node):
self.__class__._warned = True

validate_call_args(node, 1)
value = prefold(node.args[0])
value = node.args[0]._metadata.get("folded_value")
if not isinstance(value, int):
raise UnfoldableNode

Expand All @@ -1460,7 +1464,7 @@ def evaluate(self, node):
self.__class__._warned = True

validate_call_args(node, 2)
value, shift = [prefold(i) for i in node.args]
value, shift = [i._metadata.get("folded_value") for i in node.args]
if any(not isinstance(i, int) for i in [value, shift]):
raise UnfoldableNode
if value < 0 or value >= 2**256:
Expand Down Expand Up @@ -1508,7 +1512,7 @@ class _AddMulMod(BuiltinFunction):

def evaluate(self, node):
validate_call_args(node, 3)
values = [prefold(i) for i in node.args]
values = [i._metadata.get("folded_value") for i in node.args]
if isinstance(values[2], int) and values[2] == 0:
raise ZeroDivisionException("Modulo by 0", node.args[2])
for v, arg in zip(values, node.args):
Expand Down Expand Up @@ -1551,7 +1555,7 @@ class PowMod256(BuiltinFunction):

def evaluate(self, node):
validate_call_args(node, 2)
values = [prefold(i) for i in node.args]
values = [i._metadata.get("folded_value") for i in node.args]
if any(not isinstance(i, int) for i in values):
raise UnfoldableNode

Expand All @@ -1575,7 +1579,7 @@ class Abs(BuiltinFunction):

def evaluate(self, node):
validate_call_args(node, 1)
value = prefold(node.args[0])
value = node.args[0]._metadata.get("folded_value")
if not isinstance(value, int):
raise UnfoldableNode

Expand Down Expand Up @@ -2019,7 +2023,7 @@ class _MinMax(BuiltinFunction):

def evaluate(self, node):
validate_call_args(node, 2)
values = [prefold(i) for i in node.args]
values = [i._metadata.get("folded_value") for i in node.args]
if not isinstance(values[0], type(values[1])):
raise UnfoldableNode
if not isinstance(values[0], (Decimal, int)):
Expand Down Expand Up @@ -2111,7 +2115,7 @@ def fetch_call_return(self, node):

def evaluate(self, node):
validate_call_args(node, 1)
value = prefold(node.args[0])
value = node.args[0]._metadata.get("folded_value")
if not isinstance(value, int):
raise UnfoldableNode

Expand Down
2 changes: 2 additions & 0 deletions vyper/semantics/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ..namespace import get_namespace
from .local import validate_functions
from .module import add_module_namespace
from .pre_typecheck import pre_typecheck
from .utils import _ExprAnalyser


Expand All @@ -12,6 +13,7 @@ def validate_semantics(vyper_ast, interface_codes):
namespace = get_namespace()

with namespace.enter_scope():
pre_typecheck(vyper_ast)
add_module_namespace(vyper_ast, interface_codes)
vy_ast.expansion.generate_public_variable_getters(vyper_ast)
validate_functions(vyper_ast)
Loading

0 comments on commit 0f8b234

Please sign in to comment.