diff --git a/tests/ast/test_folding.py b/tests/ast/test_folding.py index eed00215a8..4b2f32e112 100644 --- a/tests/ast/test_folding.py +++ b/tests/ast/test_folding.py @@ -200,7 +200,7 @@ def test_replace_constant(source): unmodified_ast = vy_ast.parse_to_ast(source) folded_ast = vy_ast.parse_to_ast(source) - folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, True) + folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, 31337, True) assert not vy_ast.compare_nodes(unmodified_ast, folded_ast) @@ -223,7 +223,7 @@ def test_replace_constant_no(source): unmodified_ast = vy_ast.parse_to_ast(source) folded_ast = vy_ast.parse_to_ast(source) - folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, True) + folding.replace_constant(folded_ast, "FOO", vy_ast.Int(value=31337), UINT256_T, 31337, True) assert vy_ast.compare_nodes(unmodified_ast, folded_ast) diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index 530cb820b5..e02af55ad8 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Any, Union from vyper.ast import nodes as vy_ast from vyper.exceptions import UnfoldableNode @@ -47,6 +47,7 @@ def replace_literal_ops(vyper_module: vy_ast.Module) -> int: except UnfoldableNode: continue + new_node._metadata["folded_value"] = new_node.value typ = node._metadata.get("type") # type metadata may not be present @@ -87,6 +88,7 @@ def replace_subscripts(vyper_module: vy_ast.Module) -> int: except UnfoldableNode: continue + new_node._metadata["folded_value"] = node._metadata["folded_value"] new_node._metadata["type"] = node._metadata["type"] changed_nodes += 1 @@ -125,6 +127,7 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: except UnfoldableNode: continue + new_node._metadata["folded_value"] = new_node.value new_node._metadata["type"] = node._metadata["type"] changed_nodes += 1 @@ -156,7 +159,10 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: continue type_ = node._metadata["type"] - changed_nodes += replace_constant(vyper_module, node.target.id, node.value, type_, False) + folded_value = node.value._metadata["folded_value"] + changed_nodes += replace_constant( + vyper_module, node.target.id, node.value, type_, folded_value, False + ) return changed_nodes @@ -169,7 +175,7 @@ def _replace(old_node, new_node, type_): new_node = new_node.from_node(old_node, value=new_node.value) elif isinstance(new_node, vy_ast.List): base_type = type_.value_type if type_ else None - list_values = [_replace(old_node, i, type_=base_type) for i in new_node.elements] + list_values = [_replace(old_node, i, base_type) for i in new_node.elements] new_node = new_node.from_node(old_node, elements=list_values) elif isinstance(new_node, vy_ast.Call): # Replace `Name` node with `Call` node @@ -193,6 +199,7 @@ def replace_constant( id_: str, replacement_node: Union[vy_ast.Constant, vy_ast.List, vy_ast.Call], type_: VyperType, + folded_value: Any, raise_on_error: bool, ) -> int: """ @@ -209,6 +216,8 @@ def replace_constant( `Call` nodes are for struct constants. type_ : VyperType Type definition to be propagated to type checker. + folded_value: Any + Folded value of the constant raise_on_error: bool Boolean indicating if `UnfoldableNode` exception should be raised or ignored. @@ -247,6 +256,7 @@ def replace_constant( try: # note: _replace creates a copy of the replacement_node new_node = _replace(node, replacement_node, type_) + new_node._metadata["folded_value"] = folded_value except UnfoldableNode: if raise_on_error: raise diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 2497928035..b0cbedfaa7 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -411,6 +411,10 @@ def to_dict(self) -> dict: if "type" in self._metadata: ast_dict["type"] = str(self._metadata["type"]) + folded_value = self._metadata.get("folded_value") + if folded_value is not None: + ast_dict["folded_value"] = str(self._metadata["folded_value"]) + return ast_dict def get_ancestor(self, node_type: Union["VyperNode", tuple, None] = None) -> "VyperNode": diff --git a/vyper/ast/pre_typecheck.py b/vyper/ast/pre_typecheck.py deleted file mode 100644 index 685bedb52f..0000000000 --- a/vyper/ast/pre_typecheck.py +++ /dev/null @@ -1,81 +0,0 @@ -from decimal import Decimal -from typing import Any - -from vyper.ast import nodes as vy_ast -from vyper.exceptions import UnfoldableNode, VyperException -from vyper.semantics.namespace import get_namespace - - -def prefold(node: vy_ast.VyperNode) -> Any: - if isinstance(node, vy_ast.Attribute): - val = prefold(node.value) - # constant struct members - if isinstance(val, dict): - return val[node.attr] - - elif isinstance(node, vy_ast.BinOp): - left = prefold(node.left) - right = prefold(node.right) - if isinstance(left, type(right)) and isinstance(left, (int, Decimal)): - return node.op._op(left, right) - - elif isinstance(node, vy_ast.BoolOp): - values = [prefold(i) for i in node.values] - if all(isinstance(v, bool) for v in values): - return node.op._op(values) - - elif isinstance(node, vy_ast.Call): - # constant structs - if len(node.args) == 1 and isinstance(node.args[0], vy_ast.Dict): - return prefold(node.args[0]) - - from vyper.builtins.functions import DISPATCH_TABLE - - # builtins - if isinstance(node.func, vy_ast.Name): - call_type = DISPATCH_TABLE.get(node.func.id) - if call_type and hasattr(call_type, "evaluate"): - try: - return call_type.evaluate(node).value # type: ignore - except (UnfoldableNode, VyperException): - pass - - elif isinstance(node, vy_ast.Compare): - left = prefold(node.left) - - if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): - if not isinstance(node.right, (vy_ast.List, vy_ast.Tuple)): - return None - - right = [prefold(i) for i in node.right.elements] - if left is None or len(set([type(i) for i in right])) > 1: - return None - return node.op._op(left, right) - - right = prefold(node.right) - if isinstance(left, type(right)) and isinstance(left, (int, Decimal)): - return node.op._op(left, right) - - elif isinstance(node, vy_ast.Constant): - return node.value - - elif isinstance(node, vy_ast.Dict): - values = [prefold(v) for v in node.values] - if not any(v is None for v in values): - return {k.id: v for (k, v) in zip(node.keys, values)} - - elif isinstance(node, (vy_ast.List, vy_ast.Tuple)): - val = [prefold(e) for e in node.elements] - if None not in val: - return val - - elif isinstance(node, vy_ast.Name): - ns = get_namespace() - return ns._constants.get(node.id) - - elif isinstance(node, vy_ast.UnaryOp): - operand = prefold(node.operand) - if isinstance(operand, int): - return node.op._op(operand) - - return None diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 4f92b8efb2..418e62ad53 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -1,7 +1,6 @@ import functools from typing import Dict -from vyper.ast.pre_typecheck import prefold from vyper.ast.validation import validate_call_args from vyper.codegen.expr import Expr from vyper.codegen.ir_node import IRnode @@ -104,7 +103,7 @@ def _validate_arg_types(self, node): for kwarg in node.keywords: kwarg_settings = self._kwargs[kwarg.arg] - is_literal_value = prefold(kwarg.value) is not None + is_literal_value = kwarg.value._metadata.get("folded_value") is not None if kwarg_settings.require_literal and not is_literal_value: raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 8cecbe8a2a..1f49936f91 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -5,7 +5,6 @@ from vyper import ast as vy_ast from vyper.abi_types import ABI_Tuple -from vyper.ast.pre_typecheck import prefold from vyper.ast.validation import validate_call_args from vyper.codegen.abi_encoder import abi_encode from vyper.codegen.context import Context, VariableRecord @@ -142,7 +141,7 @@ class Floor(BuiltinFunction): def evaluate(self, node): validate_call_args(node, 1) - input_val = prefold(node.args[0]) + input_val = node.args[0]._metadata.get("folded_value") if not isinstance(input_val, Decimal): raise UnfoldableNode @@ -173,7 +172,7 @@ class Ceil(BuiltinFunction): def evaluate(self, node): validate_call_args(node, 1) - input_val = prefold(node.args[0]) + input_val = node.args[0]._metadata.get("folded_value") if not isinstance(input_val, Decimal): raise UnfoldableNode @@ -468,7 +467,7 @@ class Len(BuiltinFunction): def evaluate(self, node): validate_call_args(node, 1) - arg = prefold(node.args[0]) + arg = node.args[0]._metadata.get("folded_value") if isinstance(arg, (str, bytes)): length = len(arg) else: @@ -602,7 +601,7 @@ class Keccak256(BuiltinFunction): def evaluate(self, node): validate_call_args(node, 1) - value = prefold(node.args[0]) + value = node.args[0]._metadata.get("folded_value") if not isinstance(value, (bytes, str)): raise UnfoldableNode @@ -653,7 +652,7 @@ class Sha256(BuiltinFunction): def evaluate(self, node): validate_call_args(node, 1) - value = prefold(node.args[0]) + value = node.args[0]._metadata.get("folded_value") if not isinstance(value, (bytes, str)): raise UnfoldableNode @@ -984,7 +983,7 @@ class AsWeiValue(BuiltinFunction): } def get_denomination(self, node): - value = prefold(node.args[1]) + value = node.args[1]._metadata.get("folded_value") if not isinstance(value, str): raise ArgumentException( "Wei denomination must be given as a literal string", node.args[1] @@ -1000,7 +999,7 @@ def evaluate(self, node): validate_call_args(node, 2) denom = self.get_denomination(node) - value = prefold(node.args[0]) + value = node.args[0]._metadata.get("folded_value") if not isinstance(value, (Decimal, int)): raise UnfoldableNode @@ -1083,8 +1082,13 @@ def fetch_call_return(self, node): kwargz = {i.arg: i.value for i in node.keywords} - outsize = prefold(kwargz.get("max_outsize")) - revert_on_failure = prefold(kwargz.get("revert_on_failure")) + outsize = kwargz.get("max_outsize") + if outsize is not None: + outsize = outsize._metadata.get("folded_value") + revert_on_failure = kwargz.get("revert_on_failure") + if revert_on_failure is not None: + revert_on_failure = revert_on_failure._metadata.get("folded_value") + revert_on_failure = revert_on_failure if revert_on_failure is not None else True if outsize is None or outsize == 0: @@ -1352,7 +1356,7 @@ def evaluate(self, node): self.__class__._warned = True validate_call_args(node, 2) - values = [prefold(i) for i in node.args] + values = [i._metadata.get("folded_value") for i in node.args] for v, arg in zip(values, node.args): if not isinstance(v, int): raise UnfoldableNode @@ -1379,7 +1383,7 @@ def evaluate(self, node): self.__class__._warned = True validate_call_args(node, 2) - values = [prefold(i) for i in node.args] + values = [i._metadata.get("folded_value") for i in node.args] for v, arg in zip(values, node.args): if not isinstance(arg, int): raise UnfoldableNode @@ -1406,7 +1410,7 @@ def evaluate(self, node): self.__class__._warned = True validate_call_args(node, 2) - values = [prefold(i) for i in node.args] + values = [i._metadata.get("folded_value") for i in node.args] for v, arg in zip(values, node.args): if not isinstance(arg, int): raise UnfoldableNode @@ -1433,7 +1437,7 @@ def evaluate(self, node): self.__class__._warned = True validate_call_args(node, 1) - value = prefold(node.args[0]) + value = node.args[0]._metadata.get("folded_value") if not isinstance(value, int): raise UnfoldableNode @@ -1460,7 +1464,7 @@ def evaluate(self, node): self.__class__._warned = True validate_call_args(node, 2) - value, shift = [prefold(i) for i in node.args] + value, shift = [i._metadata.get("folded_value") for i in node.args] if any(not isinstance(i, int) for i in [value, shift]): raise UnfoldableNode if value < 0 or value >= 2**256: @@ -1508,7 +1512,7 @@ class _AddMulMod(BuiltinFunction): def evaluate(self, node): validate_call_args(node, 3) - values = [prefold(i) for i in node.args] + values = [i._metadata.get("folded_value") for i in node.args] if isinstance(values[2], int) and values[2] == 0: raise ZeroDivisionException("Modulo by 0", node.args[2]) for v, arg in zip(values, node.args): @@ -1551,7 +1555,7 @@ class PowMod256(BuiltinFunction): def evaluate(self, node): validate_call_args(node, 2) - values = [prefold(i) for i in node.args] + values = [i._metadata.get("folded_value") for i in node.args] if any(not isinstance(i, int) for i in values): raise UnfoldableNode @@ -1575,7 +1579,7 @@ class Abs(BuiltinFunction): def evaluate(self, node): validate_call_args(node, 1) - value = prefold(node.args[0]) + value = node.args[0]._metadata.get("folded_value") if not isinstance(value, int): raise UnfoldableNode @@ -2019,7 +2023,7 @@ class _MinMax(BuiltinFunction): def evaluate(self, node): validate_call_args(node, 2) - values = [prefold(i) for i in node.args] + values = [i._metadata.get("folded_value") for i in node.args] if not isinstance(values[0], type(values[1])): raise UnfoldableNode if not isinstance(values[0], (Decimal, int)): @@ -2111,7 +2115,7 @@ def fetch_call_return(self, node): def evaluate(self, node): validate_call_args(node, 1) - value = prefold(node.args[0]) + value = node.args[0]._metadata.get("folded_value") if not isinstance(value, int): raise UnfoldableNode diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py index 0e30c3da1a..2358cd52a6 100644 --- a/vyper/semantics/analysis/__init__.py +++ b/vyper/semantics/analysis/__init__.py @@ -4,6 +4,7 @@ from ..namespace import get_namespace from .local import validate_functions from .module import add_module_namespace +from .pre_typecheck import pre_typecheck from .utils import _ExprAnalyser @@ -12,6 +13,7 @@ def validate_semantics(vyper_ast, interface_codes): namespace = get_namespace() with namespace.enter_scope(): + pre_typecheck(vyper_ast) add_module_namespace(vyper_ast, interface_codes) vy_ast.expansion.generate_public_variable_getters(vyper_ast) validate_functions(vyper_ast) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 2bc3a9bcd6..95cfd3f505 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -2,7 +2,6 @@ from vyper import ast as vy_ast from vyper.ast.metadata import NodeMetadata -from vyper.ast.pre_typecheck import prefold from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ExceptionList, @@ -358,7 +357,7 @@ def visit_For(self, node): validate_expected_type(n, IntegerT.any()) if bound is None: - n_val = prefold(n) + n_val = n._metadata.get("folded_value") if not isinstance(n_val, int): raise StateAccessViolation("Value must be a literal integer", n) if n_val <= 0: @@ -366,7 +365,7 @@ def visit_For(self, node): type_list = get_possible_types_from_node(n) else: - bound_val = prefold(bound) + bound_val = bound._metadata.get("folded_value") if bound_val is None: raise StateAccessViolation("bound must be a literal", bound) if bound_val <= 0: @@ -383,7 +382,7 @@ def visit_For(self, node): validate_expected_type(args[0], IntegerT.any()) type_list = get_common_types(*args) - arg0_val = prefold(args[0]) + arg0_val = args[0]._metadata.get("folded_value") if not isinstance(arg0_val, int): # range(x, x + CONSTANT) if not isinstance(args[1], vy_ast.BinOp) or not isinstance( @@ -397,7 +396,7 @@ def visit_For(self, node): "First and second variable must be the same", args[1].left ) - right_val = prefold(args[1].right) + right_val = args[1].right._metadata.get("folded_value") if not isinstance(right_val, int): raise InvalidLiteral("Literal must be an integer", args[1].right) if right_val < 1: @@ -408,7 +407,7 @@ def visit_For(self, node): ) else: # range(CONSTANT, CONSTANT) - arg1_val = prefold(args[1]) + arg1_val = args[1]._metadata.get("folded_value") if not isinstance(arg1_val, int): raise InvalidType("Value must be a literal integer", args[1]) validate_expected_type(args[1], IntegerT.any()) @@ -420,7 +419,7 @@ def visit_For(self, node): else: # iteration over a variable or literal list - iter_ = prefold(node.iter) + iter_ = node.iter._metadata.get("folded_value") if isinstance(iter_, list) and len(iter_) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index ddf0e958c3..f9581b1458 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -4,7 +4,6 @@ import vyper.builtins.interfaces from vyper import ast as vy_ast -from vyper.ast.pre_typecheck import prefold from vyper.evm.opcodes import version_check from vyper.exceptions import ( CallViolation, @@ -56,35 +55,6 @@ def _find_cyclic_call(fn_names: list, self_members: dict) -> Optional[list]: return None -def _add_constants_to_namespace(module_nodes: list[vy_ast.VyperNode], ns: Namespace) -> None: - const_var_decls = [ - n for n in module_nodes if isinstance(n, vy_ast.VariableDecl) and n.is_constant - ] - - while const_var_decls: - derived_nodes = 0 - - for c in const_var_decls: - name = c.get("target.id") - # Handle syntax errors downstream - if c.value is None: - continue - - val = prefold(c.value) - - # note that if a constant is redefined, its value will be overwritten, - # but it is okay because the syntax error is handled downstream - if val is not None: - ns.add_constant(name, val) - derived_nodes += 1 - const_var_decls.remove(c) - - if not derived_nodes: - break - - return - - class ModuleAnalyzer(VyperNodeVisitorBase): scope_name = "module" @@ -97,9 +67,6 @@ def __init__( # TODO: Move computation out of constructor module_nodes = module_node.body.copy() - - _add_constants_to_namespace(module_nodes, namespace) - while module_nodes: count = len(module_nodes) err_list = ExceptionList() @@ -130,7 +97,6 @@ def __init__( # note that we don't just copy the namespace because # there are constructor issues. _ns.update({k: namespace[k] for k in namespace._scopes[-1]}) # type: ignore - _ns._constants = self.namespace._constants # type: ignore module_node._metadata["namespace"] = _ns self_members = namespace["self"].typ.members diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py new file mode 100644 index 0000000000..c2e8d64d6b --- /dev/null +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -0,0 +1,271 @@ +from decimal import Decimal + +from vyper import ast as vy_ast +from vyper.exceptions import UnfoldableNode, VyperException +from vyper.semantics.analysis.common import VyperNodeVisitorBase + + +def pre_typecheck(node: vy_ast.VyperNode) -> None: + PreTypecheckVisitor(node) + + +class PreTypecheckVisitor(VyperNodeVisitorBase): + ignored_types = ( + vy_ast.Pass, + vy_ast.ImplementsDecl, + vy_ast.EnumDef, + vy_ast.Import, + vy_ast.ImportFrom, + vy_ast.Break, + vy_ast.Continue, + ) + scope_name = "module" + + def __init__(self, node: vy_ast.VyperNode) -> None: + self.constants = {} + + if isinstance(node, vy_ast.Module): + module_nodes = node.body.copy() + const_var_decls = [ + n for n in module_nodes if isinstance(n, vy_ast.VariableDecl) and n.is_constant + ] + + while const_var_decls: + derived_nodes = 0 + + for c in const_var_decls: + name = c.get("target.id") + # Handle syntax errors downstream + if c.value is None: + continue + + self.visit(c.value) + + val = c.value._metadata.get("folded_value") + + # note that if a constant is redefined, its value will be overwritten, + # but it is okay because the syntax error is handled downstream + if val is not None: + self.constants[name] = val + derived_nodes += 1 + const_var_decls.remove(c) + + if not derived_nodes: + break + + self.visit(node) + + def visit(self, node): + super().visit(node) + + # Module-level declarations + + def visit_EventDef(self, node): + for n in node.body: + self.visit(n.annotation) + + def visit_FunctionDef(self, node): + # visit type annotations of arguments + # e.g. def foo(a: DynArray[uint256, 2 ** 8]): ... + for arg in node.args.args: + self.visit(arg.annotation) + + for kwarg in node.args.defaults: + self.visit(kwarg) + + if node.returns: + self.visit(node.returns) + + for n in node.body: + self.visit(n) + + def visit_InterfaceDef(self, node): + for n in node.body: + self.visit(n) + + def visit_Module(self, node): + for n in node.body: + self.visit(n) + + def visit_StructDef(self, node): + for n in node.body: + self.visit(n.annotation) + + def visit_VariableDecl(self, node): + self.visit(node.annotation) + if node.is_constant: + self.visit(node.value) + + # Stmts + + def visit_AnnAssign(self, node): + self.visit(node.target) + self.visit(node.value) + self.visit(node.annotation) + + def visit_Assert(self, node): + self.visit(node.test) + if node.msg: + self.visit(node.msg) + + def _assign_helper(self, node): + self.visit(node.target) + self.visit(node.value) + + def visit_Assign(self, node): + self._assign_helper(node) + + def visit_AugAssign(self, node): + self._assign_helper(node) + + def visit_Expr(self, node): + self.visit(node.value) + + def visit_For(self, node): + for n in node.body: + self.visit(n) + + self.visit(node.iter) + self.visit(node.target) + + def visit_If(self, node): + for n in node.body: + self.visit(n) + for n in node.orelse: + self.visit(n) + + def visit_Log(self, node): + self.visit(node.value) + + def visit_Raise(self, node): + if node.exc: + self.visit(node.exc) + + def visit_Return(self, node): + if node.value: + self.visit(node.value) + + # Expr + + def visit_Attribute(self, node): + self.visit(node.value) + value_node_val = node.value._metadata.get("folded_value") + if isinstance(value_node_val, dict): + node._metadata["folded_value"] = value_node_val[node.attr] + + def visit_BinOp(self, node): + self.visit(node.left) + self.visit(node.right) + + left = node.left._metadata.get("folded_value") + right = node.right._metadata.get("folded_value") + if isinstance(left, type(right)) and isinstance(left, (int, Decimal)): + node._metadata["folded_value"] = node.op._op(left, right) + + def visit_BoolOp(self, node): + for i in node.values: + self.visit(i) + + values = [i._metadata.get("folded_value") for i in node.values] + if all(isinstance(v, bool) for v in values): + node._metadata["folded_value"] = node.op._op(values) + + def visit_Call(self, node): + for arg in node.args: + self.visit(arg) + for kwarg in node.keywords: + self.visit(kwarg.value) + + # constant structs + if len(node.args) == 1 and isinstance(node.args[0], vy_ast.Dict): + self.visit(node.args[0]) + node._metadata["folded_value"] = node.args[0]._metadata.get("folded_value") + + from vyper.builtins.functions import DISPATCH_TABLE + + # builtins + if isinstance(node.func, vy_ast.Name): + func_name = node.func.id + + call_type = DISPATCH_TABLE.get(func_name) + if call_type and hasattr(call_type, "evaluate"): + try: + node._metadata["folded_value"] = call_type.evaluate(node).value # type: ignore + return + except (UnfoldableNode, VyperException): + pass + + def visit_Compare(self, node): + self.visit(node.left) + self.visit(node.right) + + left = node.left._metadata.get("folded_value") + + if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): + if not isinstance(node.right, (vy_ast.List, vy_ast.Tuple)): + return + + right = [i._metadata.get("folded_value") for i in node.right.elements] + if left is None or len(set([type(i) for i in right])) > 1: + return + node._metadata["folded_value"] = node.op._op(left, right) + + right = node.right._metadata.get("folded_value") + if isinstance(left, type(right)) and isinstance(left, (int, Decimal)): + node._metadata["folded_value"] = node.op._op(left, right) + + def visit_Constant(self, node): + node._metadata["folded_value"] = node.value + + def visit_Dict(self, node): + for v in node.values: + self.visit(v) + + values = [v._metadata.get("folded_value") for v in node.values] + if not any(v is None for v in values): + node._metadata["folded_value"] = {k.id: v for (k, v) in zip(node.keys, values)} + + def visit_Index(self, node): + self.visit(node.value) + index_val = node.value._metadata.get("folded_value") + if index_val is not None: + node._metadata["folded_value"] = index_val + + # repeated code for List and Tuple + def _subscriptable_helper(self, node): + for e in node.elements: + self.visit(e) + + values = [e._metadata.get("folded_value") for e in node.elements] + if None not in values: + node._metadata["folded_value"] = values + + def visit_List(self, node): + self._subscriptable_helper(node) + + def visit_Name(self, node): + if node.id in self.constants: + node._metadata["folded_value"] = self.constants.get(node.id) + + def visit_Subscript(self, node): + self.visit(node.slice) + self.visit(node.value) + + slice_val = node.slice._metadata.get("folded_value") + sliced_val = node.value._metadata.get("folded_value") + if None not in (slice_val, sliced_val): + node._metadata["folded_value"] = sliced_val[slice_val] + + def visit_Tuple(self, node): + self._subscriptable_helper(node) + + def visit_UnaryOp(self, node): + self.visit(node.operand) + val = node.operand._metadata.get("folded_value") + if isinstance(val, int): + node._metadata["folded_value"] = node.op._op(val) + + def visit_IfExp(self, node): + self.visit(node.test) + self.visit(node.body) + self.visit(node.orelse) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 5b395f6fac..bfe79b5d61 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -2,7 +2,6 @@ from typing import Callable, List from vyper import ast as vy_ast -from vyper.ast.pre_typecheck import prefold from vyper.exceptions import ( CompilerPanic, InvalidLiteral, @@ -644,7 +643,7 @@ def check_constant(node: vy_ast.VyperNode) -> bool: """ Check if the given node is a literal or constant value. """ - if prefold(node) is not None: + if node._metadata.get("folded_value") is not None: return True if isinstance(node, vy_ast.Call): call_type = get_exact_type_from_node(node.func) diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py index 8683e61e01..416c17bd8d 100644 --- a/vyper/semantics/namespace.py +++ b/vyper/semantics/namespace.py @@ -13,14 +13,11 @@ class Namespace(dict): ---------- _scopes : list[set] List of sets containing the key names for each scope - _constants: dict - Set containing user-defined constants and their values """ def __new__(cls, *args, **kwargs): self = super().__new__(cls, *args, **kwargs) self._scopes = [] - self._constants = {} return self def __init__(self): @@ -77,8 +74,6 @@ def enter_scope(self): if len(self._scopes) == 1: # add mutable vars (`self`) to the initial scope self.update(environment.get_mutable_vars()) - # reset constants - self._constants = {} return self @@ -88,7 +83,6 @@ def update(self, other): def clear(self): super().clear() - self._constants.clear() self.__init__() def validate_assignment(self, attr): @@ -98,9 +92,6 @@ def validate_assignment(self, attr): obj = super().__getitem__(attr) raise NamespaceCollision(f"'{attr}' has already been declared as a {obj}") - def add_constant(self, name, value): - self._constants[name] = value - def get_namespace(): """ diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 967991ce76..6b802f541e 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -3,7 +3,6 @@ from vyper import ast as vy_ast from vyper.abi_types import ABI_DynamicArray, ABI_StaticArray, ABI_Tuple, ABIType -from vyper.ast.pre_typecheck import prefold from vyper.exceptions import ArrayIndexException, InvalidType, StructureException from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType @@ -129,7 +128,7 @@ def validate_index_type(self, node): # TODO break this cycle from vyper.semantics.analysis.utils import validate_expected_type - index_val = prefold(node) + index_val = node._metadata.get("folded_value") if isinstance(index_val, int): if index_val < 0: raise ArrayIndexException("Vyper does not support negative indexing", node) @@ -287,7 +286,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": node, ) - max_length = prefold(node.slice.value.elements[1]) + max_length = node.slice.value.elements[1]._metadata.get("folded_value") if not isinstance(max_length, int): raise StructureException( "DynArray must have a max length of integer type, e.g. DynArray[bool, 5]", node diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index ce82731c34..ffb3994203 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -16,6 +16,7 @@ ) from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions +from vyper.semantics.analysis.pre_typecheck import pre_typecheck from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace @@ -428,6 +429,7 @@ def from_ast(cls, node: Union[vy_ast.InterfaceDef, vy_ast.Module]) -> "Interface InterfaceT primitive interface type """ + pre_typecheck(node) if isinstance(node, vy_ast.Module): members, events = _get_module_definitions(node) elif isinstance(node, vy_ast.InterfaceDef): diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 5cde0530a7..38f160bf7c 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -1,5 +1,4 @@ from vyper import ast as vy_ast -from vyper.ast.pre_typecheck import prefold from vyper.exceptions import ArrayIndexException, InstantiationException, InvalidType, UnknownType from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.data_locations import DataLocation @@ -140,7 +139,7 @@ def get_index_value(node: vy_ast.Index) -> int: int Literal integer value. """ - val = prefold(node.value) + val = node.value._metadata.get("folded_value") if not isinstance(val, int): raise InvalidType("Subscript must be a literal integer", node)