From 7898efeda2d106dafa874eca6dd39b9d9a1c51d0 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 24 Dec 2023 11:13:19 +0800 Subject: [PATCH] remove prefold; add get_folded_value and get_folded_value_maybe --- vyper/ast/nodes.py | 117 +++++++++------------- vyper/ast/nodes.pyi | 3 +- vyper/builtins/_signatures.py | 14 +-- vyper/semantics/analysis/pre_typecheck.py | 23 +++-- vyper/semantics/types/utils.py | 2 +- 5 files changed, 67 insertions(+), 92 deletions(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 1393bc7d15..2eede3f24e 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -375,15 +375,26 @@ def description(self): """ return getattr(self, "_description", type(self).__name__) - def prefold(self) -> Optional["VyperNode"]: + def get_folded_value(self) -> "VyperNode": """ - Attempt to evaluate the content of a node and generate a new node from it, - allowing for values that may be out of bounds during semantics typechecking. + Attempt to get the folded value and cache it on `_metadata["folded_value"]`. + Raises UnfoldableNode if not. + """ + if "folded_value" not in self._metadata: + self._metadata["folded_value"] = self.fold() + return self._metadata["folded_value"] - If a node cannot be prefolded, it should return None. This base method acts - as a catch-call for all inherited classes that do not implement the method. + def get_folded_value_maybe(self) -> Optional["VyperNode"]: """ - return None + Attempt to get the folded value and cache it on `_metadata["folded_value"]`. + Returns None if not. + """ + if "folded_value" not in self._metadata: + try: + self._metadata["folded_value"] = self.fold() + except (UnfoldableNode, VyperException): + return None + return self._metadata["folded_value"] def fold(self) -> "VyperNode": """ @@ -905,8 +916,8 @@ class List(ExprNode): _is_prefoldable = True _translated_fields = {"elts": "elements"} - def prefold(self) -> Optional[ExprNode]: - elements = [e._metadata.get("folded_value") for e in self.elements] + def fold(self) -> Optional[ExprNode]: + elements = [e.get_folded_value_maybe() for e in self.elements] if None not in elements: return type(self).from_node(self, elements=elements) @@ -942,14 +953,6 @@ class UnaryOp(ExprNode): __slots__ = ("op", "operand") _is_prefoldable = True - def prefold(self) -> Optional[ExprNode]: - operand = self.operand._metadata.get("folded_value") - if operand is not None: - value = self.op._op(operand.value) - return type(operand).from_node(self, value=value) - - return None - def fold(self) -> ExprNode: """ Attempt to evaluate the unary operation. @@ -959,14 +962,16 @@ def fold(self) -> ExprNode: Int | Decimal Node representing the result of the evaluation. """ - if isinstance(self.op, Not) and not isinstance(self.operand, NameConstant): + operand = self.operand.get_folded_value_maybe() + + if isinstance(self.op, Not) and not isinstance(operand, NameConstant): raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if isinstance(self.op, USub) and not isinstance(self.operand, (Int, Decimal)): + if isinstance(self.op, USub) and not isinstance(operand, (Int, Decimal)): raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if isinstance(self.op, Invert) and not isinstance(self.operand, Int): + if isinstance(self.op, Invert) and not isinstance(operand, Int): raise UnfoldableNode("Node contains invalid field(s) for evaluation") - value = self.op._op(self.operand.value) + value = self.op._op(operand.value) return type(self.operand).from_node(self, value=value) @@ -998,22 +1003,6 @@ class BinOp(ExprNode): __slots__ = ("left", "op", "right") _is_prefoldable = True - def prefold(self) -> Optional[ExprNode]: - left = self.left._metadata.get("folded_value") - right = self.right._metadata.get("folded_value") - - if None in (left, right): - return None - - # this validation is performed to prevent the compiler from hanging - # on very large shifts and improve the error message for negative - # values. - if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256): - raise InvalidLiteral("Shift bits must be between 0 and 256", self.right) - - value = self.op._op(left.value, right.value) - return type(left).from_node(self, value=value) - def fold(self) -> ExprNode: """ Attempt to evaluate the arithmetic operation. @@ -1023,12 +1012,18 @@ def fold(self) -> ExprNode: Int | Decimal Node representing the result of the evaluation. """ - left, right = self.left, self.right + left, right = [i.get_folded_value_maybe() for i in (self.left, self.right)] if type(left) is not type(right): raise UnfoldableNode("Node contains invalid field(s) for evaluation") if not isinstance(left, (Int, Decimal)): raise UnfoldableNode("Node contains invalid field(s) for evaluation") + # this validation is performed to prevent the compiler from hanging + # on very large shifts and improve the error message for negative + # values. + if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256): + raise InvalidLiteral("Shift bits must be between 0 and 256", self.right) + value = self.op._op(left.value, right.value) return type(left).from_node(self, value=value) @@ -1158,14 +1153,6 @@ class BoolOp(ExprNode): __slots__ = ("op", "values") _is_prefoldable = True - def prefold(self) -> Optional[ExprNode]: - values = [i._metadata.get("folded_value") for i in self.values] - if None in values: - return None - - value = self.op._op(values) - return NameConstant.from_node(self, value=value) - def fold(self) -> ExprNode: """ Attempt to evaluate the boolean operation. @@ -1175,13 +1162,12 @@ def fold(self) -> ExprNode: NameConstant Node representing the result of the evaluation. """ - if next((i for i in self.values if not isinstance(i, NameConstant)), None): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") + values = [i.get_folded_value_maybe() for i in self.values] - values = [i.value for i in self.values] - if None in values: + if any(not isinstance(i, NameConstant) for i in values): raise UnfoldableNode("Node contains invalid field(s) for evaluation") + values = [i.value for i in values] value = self.op._op(values) return NameConstant.from_node(self, value=value) @@ -1223,16 +1209,6 @@ def __init__(self, *args, **kwargs): kwargs["right"] = kwargs.pop("comparators")[0] super().__init__(*args, **kwargs) - def prefold(self) -> Optional[ExprNode]: - left = self.left._metadata.get("folded_value") - right = self.right._metadata.get("folded_value") - - if None in (left, right): - return None - - value = self.op._op(left.value, right.value) - return NameConstant.from_node(self, value=value) - def fold(self) -> ExprNode: """ Attempt to evaluate the comparison. @@ -1242,7 +1218,7 @@ def fold(self) -> ExprNode: NameConstant Node representing the result of the evaluation. """ - left, right = self.left, self.right + left, right = [i.get_folded_value_maybe() for i in (self.left, self.right)] if not isinstance(left, Constant): raise UnfoldableNode("Node contains invalid field(s) for evaluation") @@ -1336,15 +1312,6 @@ class Subscript(ExprNode): __slots__ = ("slice", "value") _is_prefoldable = True - def prefold(self) -> Optional[ExprNode]: - slice_ = self.slice.value._metadata.get("folded_value") - value = self.value._metadata.get("folded_value") - - if None in (slice_, value): - return None - - return value.elements[slice_.value] - def fold(self) -> ExprNode: """ Attempt to evaluate the subscript. @@ -1357,12 +1324,18 @@ def fold(self) -> ExprNode: ExprNode Node representing the result of the evaluation. """ - if not isinstance(self.value, List): + slice_ = self.slice.value.get_folded_value_maybe() + value = self.value.get_folded_value_maybe() + + if not isinstance(value, List): raise UnfoldableNode("Subscript object is not a literal list") - elements = self.value.elements + elements = value.elements if len(set([type(i) for i in elements])) > 1: raise UnfoldableNode("List contains multiple node types") - idx = self.slice.get("value.value") + + if not isinstance(slice_, Int): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + idx = slice_.value if not isinstance(idx, int) or idx < 0 or idx >= len(elements): raise UnfoldableNode("Invalid index value") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index e5b62f17c1..a57db549a4 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -26,7 +26,8 @@ class VyperNode: def description(self): ... @classmethod def get_fields(cls: Any) -> set: ... - def prefold(self) -> Optional[VyperNode]: ... + def get_folded_value(self) -> VyperNode: ... + def get_folded_value_maybe(self) -> Optional[VyperNode]: ... def fold(self) -> VyperNode: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 616dd39a07..9629510c79 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -5,7 +5,7 @@ from vyper.ast.validation import validate_call_args from vyper.codegen.expr import Expr from vyper.codegen.ir_node import IRnode -from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode, VyperException +from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode from vyper.semantics.analysis.base import Modifiability from vyper.semantics.analysis.utils import ( check_variable_constancy, @@ -127,20 +127,14 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: # ensures the type can be inferred exactly. get_exact_type_from_node(arg) - def prefold(self, node): - if not hasattr(self, "fold"): - return None - - try: - return self.fold(node) - except (UnfoldableNode, VyperException): - return None - def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: self._validate_arg_types(node) return self._return_type + def fold(self, node: vy_ast.Call) -> vy_ast.VyperNode: + raise UnfoldableNode(f"{type(self)} cannot be folded") + def infer_arg_types(self, node: vy_ast.Call, expected_return_typ=None) -> list[VyperType]: self._validate_arg_types(node) ret = [expected for (_, expected) in self._inputs] diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index 797fe3bbd3..426b377a9d 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -1,4 +1,5 @@ from vyper import ast as vy_ast +from vyper.exceptions import UnfoldableNode, VyperException def get_constants(node: vy_ast.Module) -> dict: @@ -45,11 +46,7 @@ def pre_typecheck(node: vy_ast.Module) -> None: prefold(n, constants) -def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]) -> None: - if getattr(node, "_is_prefoldable", None): - node._metadata["folded_value"] = node.prefold() - return - +def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]): if isinstance(node, vy_ast.Name): var_name = node.id if var_name in constants: @@ -63,6 +60,16 @@ def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]) -> N func_name = node.func.id call_type = DISPATCH_TABLE.get(func_name) - if call_type: - node._metadata["folded_value"] = call_type.prefold(node) # type: ignore - return + if call_type and hasattr(call_type, "fold"): + try: + node._metadata["folded_value"] = call_type.fold(node) + return + except (UnfoldableNode, VyperException): + pass + + if getattr(node, "_is_prefoldable", None): + try: + # call `get_folded_value`` for its side effects + node.get_folded_value() + except (UnfoldableNode, VyperException): + pass diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 0981265dc2..137d8d56f5 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -179,7 +179,7 @@ def get_index_value(node: vy_ast.Index) -> int: # TODO: revisit this! from vyper.semantics.analysis.utils import get_possible_types_from_node - value = node.value._metadata.get("folded_value") + value = node.value.get_folded_value_maybe() if not isinstance(value, vy_ast.Int): if hasattr(node, "value"): # even though the subscript is an invalid type, first check if it's a valid _something_