From 3cf993b2097f6403cac95e2c5a27ddbfa3b29e4c Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 29 Dec 2023 11:32:46 +0800 Subject: [PATCH] remove maybe variant; add is_literal_value and has_folded_value properties --- vyper/ast/nodes.py | 78 ++++++++--------------- vyper/ast/nodes.pyi | 7 +- vyper/builtins/_signatures.py | 2 +- vyper/builtins/functions.py | 42 ++++++------ vyper/codegen/expr.py | 6 +- vyper/semantics/analysis/local.py | 16 +++-- vyper/semantics/analysis/pre_typecheck.py | 6 +- vyper/semantics/analysis/utils.py | 5 +- vyper/semantics/types/subscriptable.py | 5 +- vyper/semantics/types/utils.py | 2 +- 10 files changed, 80 insertions(+), 89 deletions(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 25da0714ee..12341aa076 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -375,25 +375,33 @@ def description(self): """ return getattr(self, "_description", type(self).__name__) - def get_folded_value_throwing(self) -> "VyperNode": + @property + def is_literal_value(self): """ - Attempt to get the folded value and cache it on `_metadata["folded_value"]`. - Raises UnfoldableNode if not. + Property method to check if the node is a literal value. """ - if "folded_value" not in self._metadata: - self._metadata["folded_value"] = self.fold() - return self._metadata["folded_value"] + return check_literal(self) - def get_folded_value_maybe(self) -> Optional["VyperNode"]: + @property + def has_folded_value(self): + """ + Property method to check if the node has a folded value. + """ + return "folded_value" in self._metadata + + def get_folded_value(self) -> "VyperNode": """ Attempt to get the folded value and cache it on `_metadata["folded_value"]`. - Returns None if not. + For constant nodes, the node is directly returned as the folded value without caching + to the metadata. + + Raises UnfoldableNode if not. """ + if check_literal(self): + return self + if "folded_value" not in self._metadata: - try: - self._metadata["folded_value"] = self.fold() - except (UnfoldableNode, VyperException): - return None + self._metadata["folded_value"] = self.fold() return self._metadata["folded_value"] def fold(self) -> "VyperNode": @@ -778,12 +786,6 @@ class Constant(ExprNode): def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): super().__init__(parent, **kwargs) - def get_folded_value_throwing(self) -> "VyperNode": - return self - - def get_folded_value_maybe(self) -> Optional["VyperNode"]: - return self - class Num(Constant): # inherited class for all numeric constant node types @@ -934,21 +936,9 @@ class List(ExprNode): _translated_fields = {"elts": "elements"} def fold(self) -> Optional[ExprNode]: - elements = [e.get_folded_value_throwing() for e in self.elements] + elements = [e.get_folded_value() for e in self.elements] return type(self).from_node(self, elements=elements) - def get_folded_value_throwing(self) -> "VyperNode": - if check_literal(self): - return self - - return super().get_folded_value_throwing() - - def get_folded_value_maybe(self) -> Optional["VyperNode"]: - if check_literal(self): - return self - - return super().get_folded_value_maybe() - class Tuple(ExprNode): __slots__ = ("elements",) @@ -960,21 +950,9 @@ def validate(self): raise InvalidLiteral("Cannot have an empty tuple", self) def fold(self) -> Optional[ExprNode]: - elements = [e.get_folded_value_throwing() for e in self.elements] + elements = [e.get_folded_value() for e in self.elements] return type(self).from_node(self, elements=elements) - def get_folded_value_throwing(self) -> "VyperNode": - if check_literal(self): - return self - - return super().get_folded_value_throwing() - - def get_folded_value_maybe(self) -> Optional["VyperNode"]: - if check_literal(self): - return self - - return super().get_folded_value_maybe() - class NameConstant(Constant): __slots__ = () @@ -1005,7 +983,7 @@ def fold(self) -> ExprNode: Int | Decimal Node representing the result of the evaluation. """ - operand = self.operand.get_folded_value_throwing() + operand = self.operand.get_folded_value() if isinstance(self.op, Not) and not isinstance(operand, NameConstant): raise UnfoldableNode("Node contains invalid field(s) for evaluation") @@ -1055,7 +1033,7 @@ def fold(self) -> ExprNode: Int | Decimal Node representing the result of the evaluation. """ - left, right = [i.get_folded_value_throwing() for i in (self.left, self.right)] + left, right = [i.get_folded_value() for i in (self.left, self.right)] if type(left) is not type(right): raise UnfoldableNode("Node contains invalid field(s) for evaluation") if not isinstance(left, (Int, Decimal)): @@ -1205,7 +1183,7 @@ def fold(self) -> ExprNode: NameConstant Node representing the result of the evaluation. """ - values = [i.get_folded_value_throwing() for i in self.values] + values = [i.get_folded_value() for i in self.values] if any(not isinstance(i, NameConstant) for i in values): raise UnfoldableNode("Node contains invalid field(s) for evaluation") @@ -1261,7 +1239,7 @@ def fold(self) -> ExprNode: NameConstant Node representing the result of the evaluation. """ - left, right = [i.get_folded_value_throwing() for i in (self.left, self.right)] + left, right = [i.get_folded_value() for i in (self.left, self.right)] if not isinstance(left, Constant): raise UnfoldableNode("Node contains invalid field(s) for evaluation") @@ -1367,8 +1345,8 @@ def fold(self) -> ExprNode: ExprNode Node representing the result of the evaluation. """ - slice_ = self.slice.value.get_folded_value_throwing() - value = self.value.get_folded_value_throwing() + slice_ = self.slice.value.get_folded_value() + value = self.value.get_folded_value() if not isinstance(value, List): raise UnfoldableNode("Subscript object is not a literal list") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 7531a6d02c..fc14fd810c 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -24,10 +24,13 @@ class VyperNode: def __eq__(self, other: Any) -> Any: ... @property def description(self): ... + @property + def is_literal_value(self): ... + @property + def has_folded_value(self): ... @classmethod def get_fields(cls: Any) -> set: ... - def get_folded_value_throwing(self) -> VyperNode: ... - def get_folded_value_maybe(self) -> Optional[VyperNode]: ... + def get_folded_value(self) -> VyperNode: ... def fold(self) -> VyperNode: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 0c65a0756f..3257911c36 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -34,7 +34,7 @@ def process_arg(arg, expected_arg_type, context): def process_kwarg(kwarg_node, kwarg_settings, expected_kwarg_type, context): if kwarg_settings.require_literal: - return kwarg_node.get_folded_value_throwing().value + return kwarg_node.get_folded_value().value return process_arg(kwarg_node, expected_kwarg_type, context) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 2843b567e1..0f4b4a5500 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -136,7 +136,7 @@ class Floor(BuiltinFunctionT): def fold(self, node): validate_call_args(node, 1) - value = node.args[0].get_folded_value_throwing() + value = node.args[0].get_folded_value() if not isinstance(value, vy_ast.Decimal): raise UnfoldableNode @@ -167,7 +167,7 @@ class Ceil(BuiltinFunctionT): def fold(self, node): validate_call_args(node, 1) - value = node.args[0].get_folded_value_throwing() + value = node.args[0].get_folded_value() if not isinstance(value, vy_ast.Decimal): raise UnfoldableNode @@ -461,7 +461,7 @@ class Len(BuiltinFunctionT): def fold(self, node): validate_call_args(node, 1) - arg = node.args[0].get_folded_value_throwing() + arg = node.args[0].get_folded_value() if isinstance(arg, (vy_ast.Str, vy_ast.Bytes)): length = len(arg.value) elif isinstance(arg, vy_ast.Hex): @@ -598,7 +598,7 @@ class Keccak256(BuiltinFunctionT): def fold(self, node): validate_call_args(node, 1) - value = node.args[0].get_folded_value_throwing() + value = node.args[0].get_folded_value() if isinstance(value, vy_ast.Bytes): value = value.value elif isinstance(value, vy_ast.Str): @@ -646,7 +646,7 @@ class Sha256(BuiltinFunctionT): def fold(self, node): validate_call_args(node, 1) - value = node.args[0].get_folded_value_throwing() + value = node.args[0].get_folded_value() if isinstance(value, vy_ast.Bytes): value = value.value elif isinstance(value, vy_ast.Str): @@ -720,7 +720,7 @@ class MethodID(FoldedFunctionT): def fold(self, node): validate_call_args(node, 1, ["output_type"]) - value = node.args[0].get_folded_value_throwing() + value = node.args[0].get_folded_value() if not isinstance(value, vy_ast.Str): raise InvalidType("method id must be given as a literal string", node.args[0]) if " " in value.value: @@ -980,7 +980,7 @@ class AsWeiValue(BuiltinFunctionT): } def get_denomination(self, node): - value = node.args[1].get_folded_value_throwing() + value = node.args[1].get_folded_value() if not isinstance(value, vy_ast.Str): raise ArgumentException( "Wei denomination must be given as a literal string", node.args[1] @@ -996,7 +996,7 @@ def fold(self, node): validate_call_args(node, 2) denom = self.get_denomination(node) - value = node.args[0].get_folded_value_throwing() + value = node.args[0].get_folded_value() if not isinstance(value, (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode value = value.value @@ -1082,10 +1082,10 @@ def fetch_call_return(self, node): outsize = kwargz.get("max_outsize") if outsize is not None: - outsize = outsize.get_folded_value_throwing() + outsize = outsize.get_folded_value() revert_on_failure = kwargz.get("revert_on_failure") if revert_on_failure is not None: - revert_on_failure = revert_on_failure.get_folded_value_throwing() + revert_on_failure = revert_on_failure.get_folded_value() revert_on_failure = revert_on_failure.value if revert_on_failure is not None else True if outsize is None or outsize.value == 0: @@ -1355,7 +1355,7 @@ def fold(self, node): self.__class__._warned = True validate_call_args(node, 2) - values = [i.get_folded_value_throwing() for i in node.args] + values = [i.get_folded_value() for i in node.args] for val in values: if not isinstance(val, vy_ast.Int): raise UnfoldableNode @@ -1380,7 +1380,7 @@ def fold(self, node): self.__class__._warned = True validate_call_args(node, 2) - values = [i.get_folded_value_throwing() for i in node.args] + values = [i.get_folded_value() for i in node.args] for val in values: if not isinstance(val, vy_ast.Int): raise UnfoldableNode @@ -1405,7 +1405,7 @@ def fold(self, node): self.__class__._warned = True validate_call_args(node, 2) - values = [i.get_folded_value_throwing() for i in node.args] + values = [i.get_folded_value() for i in node.args] for val in values: if not isinstance(val, vy_ast.Int): raise UnfoldableNode @@ -1430,7 +1430,7 @@ def fold(self, node): self.__class__._warned = True validate_call_args(node, 1) - value = node.args[0].get_folded_value_throwing() + value = node.args[0].get_folded_value() if not isinstance(value, vy_ast.Int): raise UnfoldableNode @@ -1456,7 +1456,7 @@ def fold(self, node): self.__class__._warned = True validate_call_args(node, 2) - args = [i.get_folded_value_throwing() for i in node.args] + args = [i.get_folded_value() for i in node.args] if any(not isinstance(i, vy_ast.Int) for i in args): raise UnfoldableNode value, shift = [i.value for i in args] @@ -1503,7 +1503,7 @@ class _AddMulMod(BuiltinFunctionT): def fold(self, node): validate_call_args(node, 3) - args = [i.get_folded_value_throwing() for i in node.args] + args = [i.get_folded_value() for i in node.args] if isinstance(args[2], vy_ast.Int) and args[2].value == 0: raise ZeroDivisionException("Modulo by 0", node.args[2]) for arg in args: @@ -1544,7 +1544,7 @@ class PowMod256(BuiltinFunctionT): def fold(self, node): validate_call_args(node, 2) - values = [i.get_folded_value_throwing() for i in node.args] + values = [i.get_folded_value() for i in node.args] if any(not isinstance(i, vy_ast.Int) for i in values): raise UnfoldableNode @@ -1565,7 +1565,7 @@ class Abs(BuiltinFunctionT): def fold(self, node): validate_call_args(node, 1) - value = node.args[0].get_folded_value_throwing() + value = node.args[0].get_folded_value() if not isinstance(value, vy_ast.Int): raise UnfoldableNode @@ -2005,8 +2005,8 @@ class _MinMax(BuiltinFunctionT): def fold(self, node): validate_call_args(node, 2) - left = node.args[0].get_folded_value_throwing() - right = node.args[1].get_folded_value_throwing() + left = node.args[0].get_folded_value() + right = node.args[1].get_folded_value() if not isinstance(left, type(right)): raise UnfoldableNode if not isinstance(left, (vy_ast.Decimal, vy_ast.Int)): @@ -2082,7 +2082,7 @@ def fetch_call_return(self, node): def fold(self, node): validate_call_args(node, 1) - value = node.args[0].get_folded_value_throwing() + value = node.args[0].get_folded_value() if not isinstance(value, vy_ast.Int): raise UnfoldableNode diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 27266577a0..be7a69de77 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -71,7 +71,7 @@ class Expr: def __init__(self, node, context): if isinstance(node, vy_ast.VyperNode): - node = node._metadata.get("folded_value", node) + node = node.get_folded_value() if node.has_folded_value else node self.expr = node self.context = context @@ -193,7 +193,9 @@ def parse_Name(self): # using the folded value metadata assert isinstance(varinfo.typ, StructT) value_node = varinfo.decl_node.value - value_node = value_node._metadata.get("folded_value", value_node) + value_node = ( + value_node.get_folded_value() if value_node.has_folded_value else value_node + ) return Expr.parse_value_expr(value_node, self.context) assert varinfo.modifiability == Modifiability.IMMUTABLE, "not an immutable!" diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 417e9e7018..f6508c3032 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -360,7 +360,7 @@ def visit_For(self, node): else: # iteration over a variable or literal list - iter_val = node.iter.get_folded_value_maybe() + iter_val = node.iter.get_folded_value() if node.iter.has_folded_value else node.iter if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) @@ -549,10 +549,10 @@ def visit(self, node, typ): node._metadata["type"] = typ # validate and annotate folded value - folded_value = node._metadata.get("folded_value") - if folded_value: - validate_expected_type(folded_value, typ) - folded_value._metadata["type"] = typ + if node.has_folded_value: + folded_node = node.get_folded_value() + validate_expected_type(folded_node, typ) + folded_node._metadata["type"] = typ def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) @@ -769,10 +769,9 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: if not type_list: raise TypeMismatch("Iterator values are of different types", node) - folded_start, folded_end = [i.get_folded_value_maybe() for i in (start, end)] if "bound" in kwargs: bound = kwargs["bound"] - folded_bound = bound.get_folded_value_maybe() + folded_bound = bound.get_folded_value() if bound.has_folded_value else bound if not isinstance(folded_bound, vy_ast.Num): raise StateAccessViolation("Bound must be a literal", bound) if folded_bound.value <= 0: @@ -781,6 +780,9 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: error = "Please remove the `bound=` kwarg when using range with constants" raise StructureException(error, bound) else: + folded_start, folded_end = [ + i.get_folded_value() if i.has_folded_value else i for i in (start, end) + ] for original_arg, folded_arg in zip([start, end], [folded_start, folded_end]): if not isinstance(folded_arg, vy_ast.Num): error = "Value must be a literal integer, unless a bound is specified" diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index b89c1c6759..3e1d514be4 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -22,7 +22,7 @@ def get_constants(node: vy_ast.Module) -> dict: prefold(n, constants) try: - val = c.value.get_folded_value_throwing() + val = c.value.get_folded_value() # note that if a constant is redefined, its value will be overwritten, # but it is okay because the syntax error is handled downstream @@ -70,9 +70,9 @@ def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]): pass if getattr(node, "_is_prefoldable", None): - # call `get_folded_value_throwing` for its side effects and allow all + # call `get_folded_value` for its side effects and allow all # exceptions other than `UnfoldableNode` to raise try: - node.get_folded_value_throwing() + node.get_folded_value() except UnfoldableNode: pass diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index ebc4f27e84..f20e6fa903 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -627,7 +627,10 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> """ Check if the given node is not more modifiable than the given modifiability. """ - if node.get_folded_value_maybe(): + if node.is_literal_value: + return True + + if node.has_folded_value: return True if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)): diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 5e1154416a..c3c86e12ad 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -287,7 +287,10 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT": ): raise StructureException(err_msg, node.slice) - length_node = node.slice.value.elements[1].get_folded_value_maybe() + length_node = node.slice.value.elements[1] + length_node = ( + length_node.get_folded_value() if length_node.has_folded_value else length_node + ) if not isinstance(length_node, vy_ast.Int): raise StructureException(err_msg, length_node) diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 137d8d56f5..a696682e0e 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -179,7 +179,7 @@ def get_index_value(node: vy_ast.Index) -> int: # TODO: revisit this! from vyper.semantics.analysis.utils import get_possible_types_from_node - value = node.value.get_folded_value_maybe() + value = node.value.get_folded_value() if node.value.has_folded_value else node.value if not isinstance(value, vy_ast.Int): if hasattr(node, "value"): # even though the subscript is an invalid type, first check if it's a valid _something_