Skip to content

Commit

Permalink
remove maybe variant; add is_literal_value and has_folded_value prope…
Browse files Browse the repository at this point in the history
…rties
  • Loading branch information
tserg committed Dec 29, 2023
1 parent d6f79e8 commit 3cf993b
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 89 deletions.
78 changes: 28 additions & 50 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,25 +375,33 @@ def description(self):
"""
return getattr(self, "_description", type(self).__name__)

def get_folded_value_throwing(self) -> "VyperNode":
@property
def is_literal_value(self):
"""
Attempt to get the folded value and cache it on `_metadata["folded_value"]`.
Raises UnfoldableNode if not.
Property method to check if the node is a literal value.
"""
if "folded_value" not in self._metadata:
self._metadata["folded_value"] = self.fold()
return self._metadata["folded_value"]
return check_literal(self)

def get_folded_value_maybe(self) -> Optional["VyperNode"]:
@property
def has_folded_value(self):
"""
Property method to check if the node has a folded value.
"""
return "folded_value" in self._metadata

def get_folded_value(self) -> "VyperNode":
"""
Attempt to get the folded value and cache it on `_metadata["folded_value"]`.
Returns None if not.
For constant nodes, the node is directly returned as the folded value without caching
to the metadata.
Raises UnfoldableNode if not.
"""
if check_literal(self):
return self

if "folded_value" not in self._metadata:
try:
self._metadata["folded_value"] = self.fold()
except (UnfoldableNode, VyperException):
return None
self._metadata["folded_value"] = self.fold()
return self._metadata["folded_value"]

def fold(self) -> "VyperNode":
Expand Down Expand Up @@ -778,12 +786,6 @@ class Constant(ExprNode):
def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):
super().__init__(parent, **kwargs)

def get_folded_value_throwing(self) -> "VyperNode":
return self

def get_folded_value_maybe(self) -> Optional["VyperNode"]:
return self


class Num(Constant):
# inherited class for all numeric constant node types
Expand Down Expand Up @@ -934,21 +936,9 @@ class List(ExprNode):
_translated_fields = {"elts": "elements"}

def fold(self) -> Optional[ExprNode]:
elements = [e.get_folded_value_throwing() for e in self.elements]
elements = [e.get_folded_value() for e in self.elements]
return type(self).from_node(self, elements=elements)

def get_folded_value_throwing(self) -> "VyperNode":
if check_literal(self):
return self

return super().get_folded_value_throwing()

def get_folded_value_maybe(self) -> Optional["VyperNode"]:
if check_literal(self):
return self

return super().get_folded_value_maybe()


class Tuple(ExprNode):
__slots__ = ("elements",)
Expand All @@ -960,21 +950,9 @@ def validate(self):
raise InvalidLiteral("Cannot have an empty tuple", self)

def fold(self) -> Optional[ExprNode]:
elements = [e.get_folded_value_throwing() for e in self.elements]
elements = [e.get_folded_value() for e in self.elements]
return type(self).from_node(self, elements=elements)

def get_folded_value_throwing(self) -> "VyperNode":
if check_literal(self):
return self

return super().get_folded_value_throwing()

def get_folded_value_maybe(self) -> Optional["VyperNode"]:
if check_literal(self):
return self

return super().get_folded_value_maybe()


class NameConstant(Constant):
__slots__ = ()
Expand Down Expand Up @@ -1005,7 +983,7 @@ def fold(self) -> ExprNode:
Int | Decimal
Node representing the result of the evaluation.
"""
operand = self.operand.get_folded_value_throwing()
operand = self.operand.get_folded_value()

if isinstance(self.op, Not) and not isinstance(operand, NameConstant):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
Expand Down Expand Up @@ -1055,7 +1033,7 @@ def fold(self) -> ExprNode:
Int | Decimal
Node representing the result of the evaluation.
"""
left, right = [i.get_folded_value_throwing() for i in (self.left, self.right)]
left, right = [i.get_folded_value() for i in (self.left, self.right)]
if type(left) is not type(right):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if not isinstance(left, (Int, Decimal)):
Expand Down Expand Up @@ -1205,7 +1183,7 @@ def fold(self) -> ExprNode:
NameConstant
Node representing the result of the evaluation.
"""
values = [i.get_folded_value_throwing() for i in self.values]
values = [i.get_folded_value() for i in self.values]

if any(not isinstance(i, NameConstant) for i in values):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
Expand Down Expand Up @@ -1261,7 +1239,7 @@ def fold(self) -> ExprNode:
NameConstant
Node representing the result of the evaluation.
"""
left, right = [i.get_folded_value_throwing() for i in (self.left, self.right)]
left, right = [i.get_folded_value() for i in (self.left, self.right)]
if not isinstance(left, Constant):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

Expand Down Expand Up @@ -1367,8 +1345,8 @@ def fold(self) -> ExprNode:
ExprNode
Node representing the result of the evaluation.
"""
slice_ = self.slice.value.get_folded_value_throwing()
value = self.value.get_folded_value_throwing()
slice_ = self.slice.value.get_folded_value()
value = self.value.get_folded_value()

if not isinstance(value, List):
raise UnfoldableNode("Subscript object is not a literal list")
Expand Down
7 changes: 5 additions & 2 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ class VyperNode:
def __eq__(self, other: Any) -> Any: ...
@property
def description(self): ...
@property
def is_literal_value(self): ...
@property
def has_folded_value(self): ...
@classmethod
def get_fields(cls: Any) -> set: ...
def get_folded_value_throwing(self) -> VyperNode: ...
def get_folded_value_maybe(self) -> Optional[VyperNode]: ...
def get_folded_value(self) -> VyperNode: ...
def fold(self) -> VyperNode: ...
@classmethod
def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ...
Expand Down
2 changes: 1 addition & 1 deletion vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def process_arg(arg, expected_arg_type, context):

def process_kwarg(kwarg_node, kwarg_settings, expected_kwarg_type, context):
if kwarg_settings.require_literal:
return kwarg_node.get_folded_value_throwing().value
return kwarg_node.get_folded_value().value

return process_arg(kwarg_node, expected_kwarg_type, context)

Expand Down
42 changes: 21 additions & 21 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class Floor(BuiltinFunctionT):

def fold(self, node):
validate_call_args(node, 1)
value = node.args[0].get_folded_value_throwing()
value = node.args[0].get_folded_value()
if not isinstance(value, vy_ast.Decimal):
raise UnfoldableNode

Expand Down Expand Up @@ -167,7 +167,7 @@ class Ceil(BuiltinFunctionT):

def fold(self, node):
validate_call_args(node, 1)
value = node.args[0].get_folded_value_throwing()
value = node.args[0].get_folded_value()
if not isinstance(value, vy_ast.Decimal):
raise UnfoldableNode

Expand Down Expand Up @@ -461,7 +461,7 @@ class Len(BuiltinFunctionT):

def fold(self, node):
validate_call_args(node, 1)
arg = node.args[0].get_folded_value_throwing()
arg = node.args[0].get_folded_value()
if isinstance(arg, (vy_ast.Str, vy_ast.Bytes)):
length = len(arg.value)
elif isinstance(arg, vy_ast.Hex):
Expand Down Expand Up @@ -598,7 +598,7 @@ class Keccak256(BuiltinFunctionT):

def fold(self, node):
validate_call_args(node, 1)
value = node.args[0].get_folded_value_throwing()
value = node.args[0].get_folded_value()
if isinstance(value, vy_ast.Bytes):
value = value.value
elif isinstance(value, vy_ast.Str):
Expand Down Expand Up @@ -646,7 +646,7 @@ class Sha256(BuiltinFunctionT):

def fold(self, node):
validate_call_args(node, 1)
value = node.args[0].get_folded_value_throwing()
value = node.args[0].get_folded_value()
if isinstance(value, vy_ast.Bytes):
value = value.value
elif isinstance(value, vy_ast.Str):
Expand Down Expand Up @@ -720,7 +720,7 @@ class MethodID(FoldedFunctionT):
def fold(self, node):
validate_call_args(node, 1, ["output_type"])

value = node.args[0].get_folded_value_throwing()
value = node.args[0].get_folded_value()
if not isinstance(value, vy_ast.Str):
raise InvalidType("method id must be given as a literal string", node.args[0])
if " " in value.value:
Expand Down Expand Up @@ -980,7 +980,7 @@ class AsWeiValue(BuiltinFunctionT):
}

def get_denomination(self, node):
value = node.args[1].get_folded_value_throwing()
value = node.args[1].get_folded_value()
if not isinstance(value, vy_ast.Str):
raise ArgumentException(
"Wei denomination must be given as a literal string", node.args[1]
Expand All @@ -996,7 +996,7 @@ def fold(self, node):
validate_call_args(node, 2)
denom = self.get_denomination(node)

value = node.args[0].get_folded_value_throwing()
value = node.args[0].get_folded_value()
if not isinstance(value, (vy_ast.Decimal, vy_ast.Int)):
raise UnfoldableNode
value = value.value
Expand Down Expand Up @@ -1082,10 +1082,10 @@ def fetch_call_return(self, node):

outsize = kwargz.get("max_outsize")
if outsize is not None:
outsize = outsize.get_folded_value_throwing()
outsize = outsize.get_folded_value()
revert_on_failure = kwargz.get("revert_on_failure")
if revert_on_failure is not None:
revert_on_failure = revert_on_failure.get_folded_value_throwing()
revert_on_failure = revert_on_failure.get_folded_value()
revert_on_failure = revert_on_failure.value if revert_on_failure is not None else True

if outsize is None or outsize.value == 0:
Expand Down Expand Up @@ -1355,7 +1355,7 @@ def fold(self, node):
self.__class__._warned = True

validate_call_args(node, 2)
values = [i.get_folded_value_throwing() for i in node.args]
values = [i.get_folded_value() for i in node.args]
for val in values:
if not isinstance(val, vy_ast.Int):
raise UnfoldableNode
Expand All @@ -1380,7 +1380,7 @@ def fold(self, node):
self.__class__._warned = True

validate_call_args(node, 2)
values = [i.get_folded_value_throwing() for i in node.args]
values = [i.get_folded_value() for i in node.args]
for val in values:
if not isinstance(val, vy_ast.Int):
raise UnfoldableNode
Expand All @@ -1405,7 +1405,7 @@ def fold(self, node):
self.__class__._warned = True

validate_call_args(node, 2)
values = [i.get_folded_value_throwing() for i in node.args]
values = [i.get_folded_value() for i in node.args]
for val in values:
if not isinstance(val, vy_ast.Int):
raise UnfoldableNode
Expand All @@ -1430,7 +1430,7 @@ def fold(self, node):
self.__class__._warned = True

validate_call_args(node, 1)
value = node.args[0].get_folded_value_throwing()
value = node.args[0].get_folded_value()
if not isinstance(value, vy_ast.Int):
raise UnfoldableNode

Expand All @@ -1456,7 +1456,7 @@ def fold(self, node):
self.__class__._warned = True

validate_call_args(node, 2)
args = [i.get_folded_value_throwing() for i in node.args]
args = [i.get_folded_value() for i in node.args]
if any(not isinstance(i, vy_ast.Int) for i in args):
raise UnfoldableNode
value, shift = [i.value for i in args]
Expand Down Expand Up @@ -1503,7 +1503,7 @@ class _AddMulMod(BuiltinFunctionT):

def fold(self, node):
validate_call_args(node, 3)
args = [i.get_folded_value_throwing() for i in node.args]
args = [i.get_folded_value() for i in node.args]
if isinstance(args[2], vy_ast.Int) and args[2].value == 0:
raise ZeroDivisionException("Modulo by 0", node.args[2])
for arg in args:
Expand Down Expand Up @@ -1544,7 +1544,7 @@ class PowMod256(BuiltinFunctionT):

def fold(self, node):
validate_call_args(node, 2)
values = [i.get_folded_value_throwing() for i in node.args]
values = [i.get_folded_value() for i in node.args]
if any(not isinstance(i, vy_ast.Int) for i in values):
raise UnfoldableNode

Expand All @@ -1565,7 +1565,7 @@ class Abs(BuiltinFunctionT):

def fold(self, node):
validate_call_args(node, 1)
value = node.args[0].get_folded_value_throwing()
value = node.args[0].get_folded_value()
if not isinstance(value, vy_ast.Int):
raise UnfoldableNode

Expand Down Expand Up @@ -2005,8 +2005,8 @@ class _MinMax(BuiltinFunctionT):
def fold(self, node):
validate_call_args(node, 2)

left = node.args[0].get_folded_value_throwing()
right = node.args[1].get_folded_value_throwing()
left = node.args[0].get_folded_value()
right = node.args[1].get_folded_value()
if not isinstance(left, type(right)):
raise UnfoldableNode
if not isinstance(left, (vy_ast.Decimal, vy_ast.Int)):
Expand Down Expand Up @@ -2082,7 +2082,7 @@ def fetch_call_return(self, node):

def fold(self, node):
validate_call_args(node, 1)
value = node.args[0].get_folded_value_throwing()
value = node.args[0].get_folded_value()
if not isinstance(value, vy_ast.Int):
raise UnfoldableNode

Expand Down
6 changes: 4 additions & 2 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class Expr:

def __init__(self, node, context):
if isinstance(node, vy_ast.VyperNode):
node = node._metadata.get("folded_value", node)
node = node.get_folded_value() if node.has_folded_value else node

self.expr = node
self.context = context
Expand Down Expand Up @@ -193,7 +193,9 @@ def parse_Name(self):
# using the folded value metadata
assert isinstance(varinfo.typ, StructT)
value_node = varinfo.decl_node.value
value_node = value_node._metadata.get("folded_value", value_node)
value_node = (
value_node.get_folded_value() if value_node.has_folded_value else value_node
)
return Expr.parse_value_expr(value_node, self.context)

assert varinfo.modifiability == Modifiability.IMMUTABLE, "not an immutable!"
Expand Down
Loading

0 comments on commit 3cf993b

Please sign in to comment.