diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 418e62ad53..dc544c716c 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -8,6 +8,7 @@ from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type from vyper.semantics.types import TYPE_T, KwargSettings, VyperType from vyper.semantics.types.utils import type_from_annotation +from vyper.semantics.utils import get_folded_value def process_arg(arg, expected_arg_type, context): @@ -103,7 +104,7 @@ def _validate_arg_types(self, node): for kwarg in node.keywords: kwarg_settings = self._kwargs[kwarg.arg] - is_literal_value = kwarg.value._metadata.get("folded_value") is not None + is_literal_value = get_folded_value(kwarg.value) is not None if kwarg_settings.require_literal and not is_literal_value: raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 8047374cea..6b068dbe87 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -52,6 +52,7 @@ ZeroDivisionException, ) from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.pre_typecheck import get_folded_value from vyper.semantics.analysis.utils import ( get_common_types, get_exact_type_from_node, @@ -141,7 +142,7 @@ class Floor(BuiltinFunction): def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - arg = node.args[0]._metadata.get("folded_value") + arg = get_folded_value(node.args[0]) if not isinstance(arg, vy_ast.Decimal): raise UnfoldableNode @@ -172,7 +173,7 @@ class Ceil(BuiltinFunction): def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - arg = node.args[0]._metadata.get("folded_value") + arg = get_folded_value(node.args[0]) if not isinstance(arg, vy_ast.Decimal): raise UnfoldableNode @@ -467,9 +468,9 @@ class Len(BuiltinFunction): def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - arg = node.args[0]._metadata.get("folded_value") + arg = get_folded_value(node.args[0]) if isinstance(arg, (vy_ast.Str, vy_ast.Bytes)): - length = len(arg) + length = len(arg.value) elif isinstance(arg, vy_ast.Hex): # 2 characters represent 1 byte and we subtract 1 to ignore the leading `0x` length = len(arg.value) // 2 - 1 @@ -604,7 +605,7 @@ class Keccak256(BuiltinFunction): def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - arg = node.args[0]._metadata.get("folded_value") + arg = get_folded_value(node.args[0]) if isinstance(arg, vy_ast.Bytes): value = arg.value elif isinstance(arg, vy_ast.Str): @@ -652,7 +653,7 @@ class Sha256(BuiltinFunction): def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - arg = node.args[0]._metadata.get("folded_value") + arg = get_folded_value(node.args[0]) if isinstance(arg, vy_ast.Bytes): value = arg.value elif isinstance(arg, vy_ast.Str): @@ -980,7 +981,7 @@ class AsWeiValue(BuiltinFunction): } def get_denomination(self, node): - arg = node.args[1]._metadata.get("folded_value") + arg = get_folded_value(node.args[1]) if not isinstance(arg, vy_ast.Str): raise ArgumentException( "Wei denomination must be given as a literal string", node.args[1] @@ -996,7 +997,7 @@ def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 2) denom = self.get_denomination(node) - arg = node.args[0]._metadata.get("folded_value") + arg = get_folded_value(node.args[0]) if not isinstance(arg, (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode @@ -1082,10 +1083,10 @@ def fetch_call_return(self, node): outsize = kwargz.get("max_outsize") if outsize is not None: - outsize = outsize._metadata.get("folded_value") + outsize = get_folded_value(outsize) revert_on_failure = kwargz.get("revert_on_failure") if revert_on_failure is not None: - revert_on_failure = revert_on_failure._metadata.get("folded_value") + revert_on_failure = get_folded_value(revert_on_failure) revert_on_failure = revert_on_failure if revert_on_failure is not None else True @@ -1358,7 +1359,7 @@ def evaluate(self, node, skip_typecheck=False): self.__class__._warned = True validate_call_args(node, 2) - args = [i._metadata.get("folded_value") for i in node.args] + args = [get_folded_value(i) for i in node.args] for v, arg in zip(args, node.args): if not isinstance(v, vy_ast.Int): raise UnfoldableNode @@ -1385,7 +1386,7 @@ def evaluate(self, node, skip_typecheck=False): self.__class__._warned = True validate_call_args(node, 2) - args = [i._metadata.get("folded_value") for i in node.args] + args = [get_folded_value(i) for i in node.args] for v, arg in zip(args, node.args): if not isinstance(arg, vy_ast.Int): raise UnfoldableNode @@ -1412,7 +1413,7 @@ def evaluate(self, node, skip_typecheck=False): self.__class__._warned = True validate_call_args(node, 2) - args = [i._metadata.get("folded_value") for i in node.args] + args = [get_folded_value(i) for i in node.args] for v, arg in zip(args, node.args): if not isinstance(arg, vy_ast.Int): raise UnfoldableNode @@ -1439,7 +1440,7 @@ def evaluate(self, node, skip_typecheck=False): self.__class__._warned = True validate_call_args(node, 1) - arg = node.args[0]._metadata.get("folded_value") + arg = get_folded_value(node.args[0]) if not isinstance(arg, vy_ast.Int): raise UnfoldableNode @@ -1466,7 +1467,7 @@ def evaluate(self, node, skip_typecheck=False): self.__class__._warned = True validate_call_args(node, 2) - value, shift = [i._metadata.get("folded_value") for i in node.args] + value, shift = [get_folded_value(i) for i in node.args] if any(not isinstance(i, int) for i in [value, shift]): raise UnfoldableNode if value < 0 or value >= 2**256: @@ -1514,11 +1515,11 @@ class _AddMulMod(BuiltinFunction): def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 3) - args = [i._metadata.get("folded_value") for i in node.args] + args = [get_folded_value(i) for i in node.args] if isinstance(args[2], vy_ast.Int) and args[2] == 0: raise ZeroDivisionException("Modulo by 0", node.args[2]) for v, arg in zip(args, node.args): - if not isinstance(v, int): + if not isinstance(v, vy_ast.Int): raise UnfoldableNode if v.value < 0 or v.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", arg) @@ -1557,7 +1558,7 @@ class PowMod256(BuiltinFunction): def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 2) - args = [i._metadata.get("folded_value") for i in node.args] + args = [get_folded_value(i) for i in node.args] if any(not isinstance(i, vy_ast.Int) for i in args): raise UnfoldableNode @@ -1581,7 +1582,7 @@ class Abs(BuiltinFunction): def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - arg = node.args[0]._metadata.get("folded_value") + arg = get_folded_value(node.args[0]) if not isinstance(arg, vy_ast.Int): raise UnfoldableNode @@ -2025,7 +2026,7 @@ class _MinMax(BuiltinFunction): def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 2) - args = [i._metadata.get("folded_value") for i in node.args] + args = [get_folded_value(i) for i in node.args] if not isinstance(args[0], type(args[1])): raise UnfoldableNode if not isinstance(args[0], (vy_ast.Decimal, vy_ast.Int)): @@ -2047,12 +2048,7 @@ def evaluate(self, node, skip_typecheck=False): raise TypeMismatch("Cannot perform action between dislike numeric types", node) value = self._eval_fn(left.value, right.value) - - if isinstance(left, Decimal): - node = vy_ast.Decimal.from_node(node, value=value) - elif isinstance(left, int): - node = vy_ast.Int.from_node(node, value=value) - return node + return type(left).from_node(node, value=value) def fetch_call_return(self, node): self._validate_arg_types(node) @@ -2119,7 +2115,7 @@ def fetch_call_return(self, node): def evaluate(self, node, skip_typecheck=False): validate_call_args(node, 1) - arg = node.args[0]._metadata.get("folded_value") + arg = get_folded_value(node.args[0]) if not isinstance(arg, vy_ast.Int): raise UnfoldableNode diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 553c2b18df..3073e1c40e 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -52,6 +52,7 @@ ) from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability from vyper.semantics.types.utils import type_from_annotation +from vyper.semantics.utils import get_folded_value def validate_functions(vy_module: vy_ast.Module) -> None: @@ -353,7 +354,7 @@ def visit_For(self, node): if len(args) == 1: # range(CONSTANT) n = args[0] - folded_n = n._metadata.get("folded_value") + folded_n = get_folded_value(n) bound = kwargs.pop("bound", None) validate_expected_type(n, IntegerT.any()) @@ -366,7 +367,7 @@ def visit_For(self, node): type_list = get_possible_types_from_node(n) else: - folded_bound = bound._metadata.get("folded_value") + folded_bound = get_folded_value(bound) if folded_bound is None: raise StateAccessViolation("bound must be a literal", bound) if folded_bound.value <= 0: @@ -383,7 +384,7 @@ def visit_For(self, node): validate_expected_type(args[0], IntegerT.any()) type_list = get_common_types(*args) - folded_arg0 = args[0]._metadata.get("folded_value") + folded_arg0 = get_folded_value(args[0]) if not isinstance(folded_arg0, vy_ast.Constant): # range(x, x + CONSTANT) if not isinstance(args[1], vy_ast.BinOp) or not isinstance( @@ -397,7 +398,7 @@ def visit_For(self, node): "First and second variable must be the same", args[1].left ) - folded_right = args[1].right._metadata.get("folded_value") + folded_right = get_folded_value(args[1].right) if not isinstance(folded_right, vy_ast.Int): raise InvalidLiteral("Literal must be an integer", args[1].right) if folded_right.value < 1: @@ -408,7 +409,7 @@ def visit_For(self, node): ) else: # range(CONSTANT, CONSTANT) - folded_arg1 = args[1]._metadata.get("folded_value") + folded_arg1 = get_folded_value(args[1]) if not isinstance(folded_arg1, vy_ast.Int): raise InvalidType("Value must be a literal integer", args[1]) validate_expected_type(folded_arg1, IntegerT.any()) @@ -420,7 +421,7 @@ def visit_For(self, node): else: # iteration over a variable or literal list - folded_iter = node.iter._metadata.get("folded_value") + folded_iter = get_folded_value(node.iter) if isinstance(folded_iter, vy_ast.List) and len(folded_iter.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index e828d1e527..89ab20c5ff 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -1,8 +1,7 @@ -from typing import Optional - from vyper import ast as vy_ast from vyper.exceptions import UnfoldableNode, VyperException from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.utils import get_folded_value def pre_typecheck(node: vy_ast.VyperNode) -> None: @@ -231,7 +230,7 @@ def visit_Compare(self, node): node._metadata["folded_value"] = vy_ast.NameConstant.from_node(value=value) def visit_Constant(self, node): - node._metadata["folded_value"] = node + pass def visit_Dict(self, node): for v in node.values: @@ -288,10 +287,3 @@ def visit_IfExp(self, node): self.visit(node.test) self.visit(node.body) self.visit(node.orelse) - - -def get_folded_value(node: vy_ast.VyperNode) -> Optional[vy_ast.VyperNode]: - if isinstance(node, vy_ast.Constant): - return node - - return node._metadata.get("folded_value") diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index bfe79b5d61..5b3433cbf0 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -24,6 +24,7 @@ from vyper.semantics.types.bytestrings import BytesT, StringT from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT +from vyper.semantics.utils import get_folded_value from vyper.utils import checksum_encode, int_to_fourbytes @@ -643,7 +644,7 @@ def check_constant(node: vy_ast.VyperNode) -> bool: """ Check if the given node is a literal or constant value. """ - if node._metadata.get("folded_value") is not None: + if get_folded_value(node) is not None: return True if isinstance(node, vy_ast.Call): call_type = get_exact_type_from_node(node.func) diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 1752f971f1..3809ea9362 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -9,6 +9,7 @@ from vyper.semantics.types.primitives import IntegerT from vyper.semantics.types.shortcuts import UINT256_T from vyper.semantics.types.utils import get_index_value, type_from_annotation +from vyper.semantics.utils import get_folded_value class _SubscriptableT(VyperType): @@ -128,7 +129,7 @@ def validate_index_type(self, node): # TODO break this cycle from vyper.semantics.analysis.utils import validate_expected_type - index = node._metadata.get("folded_value") + index = get_folded_value(node) if isinstance(index, vy_ast.Int): value = index.value if value < 0: @@ -287,7 +288,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": node, ) - folded_max_length = node.slice.value.elements[1]._metadata.get("folded_value") + folded_max_length = get_folded_value(node.slice.value.elements[1]) if not isinstance(folded_max_length, vy_ast.Int): raise StructureException( "DynArray must have a max length of integer type, e.g. DynArray[bool, 5]", node diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index b140451cfa..302019ca13 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -4,6 +4,7 @@ from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import VyperType +from vyper.semantics.utils import get_folded_value # TODO maybe this should be merged with .types/base.py @@ -139,7 +140,7 @@ def get_index_value(node: vy_ast.Index) -> int: int Literal integer value. """ - folded_node = node.value._metadata.get("folded_value") + folded_node = get_folded_value(node.value) if not isinstance(folded_node, vy_ast.Int): raise InvalidType("Subscript must be a literal integer", node) diff --git a/vyper/semantics/utils.py b/vyper/semantics/utils.py new file mode 100644 index 0000000000..0ffec4bd10 --- /dev/null +++ b/vyper/semantics/utils.py @@ -0,0 +1,10 @@ +from typing import Optional + +from vyper import ast as vy_ast + + +def get_folded_value(node: vy_ast.VyperNode) -> Optional[vy_ast.VyperNode]: + if isinstance(node, vy_ast.Constant): + return node + + return node._metadata.get("folded_value")