Skip to content

Commit

Permalink
remove prefold; add get_folded_value and get_folded_value_maybe
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Dec 24, 2023
1 parent be4be41 commit 7898efe
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 92 deletions.
117 changes: 45 additions & 72 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,15 +375,26 @@ def description(self):
"""
return getattr(self, "_description", type(self).__name__)

def prefold(self) -> Optional["VyperNode"]:
def get_folded_value(self) -> "VyperNode":
"""
Attempt to evaluate the content of a node and generate a new node from it,
allowing for values that may be out of bounds during semantics typechecking.
Attempt to get the folded value and cache it on `_metadata["folded_value"]`.
Raises UnfoldableNode if not.
"""
if "folded_value" not in self._metadata:
self._metadata["folded_value"] = self.fold()
return self._metadata["folded_value"]

If a node cannot be prefolded, it should return None. This base method acts
as a catch-call for all inherited classes that do not implement the method.
def get_folded_value_maybe(self) -> Optional["VyperNode"]:
"""
return None
Attempt to get the folded value and cache it on `_metadata["folded_value"]`.
Returns None if not.
"""
if "folded_value" not in self._metadata:
try:
self._metadata["folded_value"] = self.fold()
except (UnfoldableNode, VyperException):
return None
return self._metadata["folded_value"]

def fold(self) -> "VyperNode":
"""
Expand Down Expand Up @@ -905,8 +916,8 @@ class List(ExprNode):
_is_prefoldable = True
_translated_fields = {"elts": "elements"}

def prefold(self) -> Optional[ExprNode]:
elements = [e._metadata.get("folded_value") for e in self.elements]
def fold(self) -> Optional[ExprNode]:
elements = [e.get_folded_value_maybe() for e in self.elements]
if None not in elements:
return type(self).from_node(self, elements=elements)

Expand Down Expand Up @@ -942,14 +953,6 @@ class UnaryOp(ExprNode):
__slots__ = ("op", "operand")
_is_prefoldable = True

def prefold(self) -> Optional[ExprNode]:
operand = self.operand._metadata.get("folded_value")
if operand is not None:
value = self.op._op(operand.value)
return type(operand).from_node(self, value=value)

return None

def fold(self) -> ExprNode:
"""
Attempt to evaluate the unary operation.
Expand All @@ -959,14 +962,16 @@ def fold(self) -> ExprNode:
Int | Decimal
Node representing the result of the evaluation.
"""
if isinstance(self.op, Not) and not isinstance(self.operand, NameConstant):
operand = self.operand.get_folded_value_maybe()

if isinstance(self.op, Not) and not isinstance(operand, NameConstant):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if isinstance(self.op, USub) and not isinstance(self.operand, (Int, Decimal)):
if isinstance(self.op, USub) and not isinstance(operand, (Int, Decimal)):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if isinstance(self.op, Invert) and not isinstance(self.operand, Int):
if isinstance(self.op, Invert) and not isinstance(operand, Int):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

value = self.op._op(self.operand.value)
value = self.op._op(operand.value)
return type(self.operand).from_node(self, value=value)


Expand Down Expand Up @@ -998,22 +1003,6 @@ class BinOp(ExprNode):
__slots__ = ("left", "op", "right")
_is_prefoldable = True

def prefold(self) -> Optional[ExprNode]:
left = self.left._metadata.get("folded_value")
right = self.right._metadata.get("folded_value")

if None in (left, right):
return None

# this validation is performed to prevent the compiler from hanging
# on very large shifts and improve the error message for negative
# values.
if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256):
raise InvalidLiteral("Shift bits must be between 0 and 256", self.right)

value = self.op._op(left.value, right.value)
return type(left).from_node(self, value=value)

def fold(self) -> ExprNode:
"""
Attempt to evaluate the arithmetic operation.
Expand All @@ -1023,12 +1012,18 @@ def fold(self) -> ExprNode:
Int | Decimal
Node representing the result of the evaluation.
"""
left, right = self.left, self.right
left, right = [i.get_folded_value_maybe() for i in (self.left, self.right)]
if type(left) is not type(right):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if not isinstance(left, (Int, Decimal)):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

# this validation is performed to prevent the compiler from hanging
# on very large shifts and improve the error message for negative
# values.
if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256):
raise InvalidLiteral("Shift bits must be between 0 and 256", self.right)

value = self.op._op(left.value, right.value)
return type(left).from_node(self, value=value)

Expand Down Expand Up @@ -1158,14 +1153,6 @@ class BoolOp(ExprNode):
__slots__ = ("op", "values")
_is_prefoldable = True

def prefold(self) -> Optional[ExprNode]:
values = [i._metadata.get("folded_value") for i in self.values]
if None in values:
return None

value = self.op._op(values)
return NameConstant.from_node(self, value=value)

def fold(self) -> ExprNode:
"""
Attempt to evaluate the boolean operation.
Expand All @@ -1175,13 +1162,12 @@ def fold(self) -> ExprNode:
NameConstant
Node representing the result of the evaluation.
"""
if next((i for i in self.values if not isinstance(i, NameConstant)), None):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
values = [i.get_folded_value_maybe() for i in self.values]

values = [i.value for i in self.values]
if None in values:
if any(not isinstance(i, NameConstant) for i in values):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

values = [i.value for i in values]
value = self.op._op(values)
return NameConstant.from_node(self, value=value)

Expand Down Expand Up @@ -1223,16 +1209,6 @@ def __init__(self, *args, **kwargs):
kwargs["right"] = kwargs.pop("comparators")[0]
super().__init__(*args, **kwargs)

def prefold(self) -> Optional[ExprNode]:
left = self.left._metadata.get("folded_value")
right = self.right._metadata.get("folded_value")

if None in (left, right):
return None

value = self.op._op(left.value, right.value)
return NameConstant.from_node(self, value=value)

def fold(self) -> ExprNode:
"""
Attempt to evaluate the comparison.
Expand All @@ -1242,7 +1218,7 @@ def fold(self) -> ExprNode:
NameConstant
Node representing the result of the evaluation.
"""
left, right = self.left, self.right
left, right = [i.get_folded_value_maybe() for i in (self.left, self.right)]
if not isinstance(left, Constant):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

Expand Down Expand Up @@ -1336,15 +1312,6 @@ class Subscript(ExprNode):
__slots__ = ("slice", "value")
_is_prefoldable = True

def prefold(self) -> Optional[ExprNode]:
slice_ = self.slice.value._metadata.get("folded_value")
value = self.value._metadata.get("folded_value")

if None in (slice_, value):
return None

return value.elements[slice_.value]

def fold(self) -> ExprNode:
"""
Attempt to evaluate the subscript.
Expand All @@ -1357,12 +1324,18 @@ def fold(self) -> ExprNode:
ExprNode
Node representing the result of the evaluation.
"""
if not isinstance(self.value, List):
slice_ = self.slice.value.get_folded_value_maybe()
value = self.value.get_folded_value_maybe()

if not isinstance(value, List):
raise UnfoldableNode("Subscript object is not a literal list")
elements = self.value.elements
elements = value.elements
if len(set([type(i) for i in elements])) > 1:
raise UnfoldableNode("List contains multiple node types")
idx = self.slice.get("value.value")

if not isinstance(slice_, Int):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
idx = slice_.value
if not isinstance(idx, int) or idx < 0 or idx >= len(elements):
raise UnfoldableNode("Invalid index value")

Expand Down
3 changes: 2 additions & 1 deletion vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class VyperNode:
def description(self): ...
@classmethod
def get_fields(cls: Any) -> set: ...
def prefold(self) -> Optional[VyperNode]: ...
def get_folded_value(self) -> VyperNode: ...
def get_folded_value_maybe(self) -> Optional[VyperNode]: ...
def fold(self) -> VyperNode: ...
@classmethod
def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ...
Expand Down
14 changes: 4 additions & 10 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vyper.ast.validation import validate_call_args
from vyper.codegen.expr import Expr
from vyper.codegen.ir_node import IRnode
from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode, VyperException
from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.analysis.utils import (
check_variable_constancy,
Expand Down Expand Up @@ -127,20 +127,14 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None:
# ensures the type can be inferred exactly.
get_exact_type_from_node(arg)

def prefold(self, node):
if not hasattr(self, "fold"):
return None

try:
return self.fold(node)
except (UnfoldableNode, VyperException):
return None

def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
self._validate_arg_types(node)

return self._return_type

def fold(self, node: vy_ast.Call) -> vy_ast.VyperNode:
raise UnfoldableNode(f"{type(self)} cannot be folded")

def infer_arg_types(self, node: vy_ast.Call, expected_return_typ=None) -> list[VyperType]:
self._validate_arg_types(node)
ret = [expected for (_, expected) in self._inputs]
Expand Down
23 changes: 15 additions & 8 deletions vyper/semantics/analysis/pre_typecheck.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from vyper import ast as vy_ast
from vyper.exceptions import UnfoldableNode, VyperException


def get_constants(node: vy_ast.Module) -> dict:
Expand Down Expand Up @@ -45,11 +46,7 @@ def pre_typecheck(node: vy_ast.Module) -> None:
prefold(n, constants)


def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]) -> None:
if getattr(node, "_is_prefoldable", None):
node._metadata["folded_value"] = node.prefold()
return

def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]):
if isinstance(node, vy_ast.Name):
var_name = node.id
if var_name in constants:
Expand All @@ -63,6 +60,16 @@ def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]) -> N
func_name = node.func.id

call_type = DISPATCH_TABLE.get(func_name)
if call_type:
node._metadata["folded_value"] = call_type.prefold(node) # type: ignore
return
if call_type and hasattr(call_type, "fold"):
try:
node._metadata["folded_value"] = call_type.fold(node)
return
except (UnfoldableNode, VyperException):
pass

if getattr(node, "_is_prefoldable", None):
try:
# call `get_folded_value`` for its side effects
node.get_folded_value()
except (UnfoldableNode, VyperException):
pass
2 changes: 1 addition & 1 deletion vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def get_index_value(node: vy_ast.Index) -> int:
# TODO: revisit this!
from vyper.semantics.analysis.utils import get_possible_types_from_node

value = node.value._metadata.get("folded_value")
value = node.value.get_folded_value_maybe()
if not isinstance(value, vy_ast.Int):
if hasattr(node, "value"):
# even though the subscript is an invalid type, first check if it's a valid _something_
Expand Down

0 comments on commit 7898efe

Please sign in to comment.