Skip to content

Commit

Permalink
handle const, list and tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Dec 28, 2023
1 parent 394c9fb commit 17f01fa
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 19 deletions.
47 changes: 46 additions & 1 deletion vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,12 @@ class Constant(ExprNode):

def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):
super().__init__(parent, **kwargs)
self._metadata["folded_value"] = self

def get_folded_value_throwing(self) -> "VyperNode":
return self

def get_folded_value_maybe(self) -> Optional["VyperNode"]:
return self


class Num(Constant):
Expand Down Expand Up @@ -911,6 +916,18 @@ def s(self):
return self.value


def check_literal(node: VyperNode) -> bool:
"""
Check if the given node is a literal value.
"""
if isinstance(node, Constant):
return True
elif isinstance(node, (Tuple, List)):
return all(check_literal(item) for item in node.elements)

return False


class List(ExprNode):
__slots__ = ("elements",)
_is_prefoldable = True
Expand All @@ -920,6 +937,18 @@ def fold(self) -> Optional[ExprNode]:
elements = [e.get_folded_value_throwing() for e in self.elements]
return type(self).from_node(self, elements=elements)

def get_folded_value_throwing(self) -> "VyperNode":
if check_literal(self):
return self

return super().get_folded_value_throwing()

def get_folded_value_maybe(self) -> Optional["VyperNode"]:
if check_literal(self):
return self

return super().get_folded_value_maybe()


class Tuple(ExprNode):
__slots__ = ("elements",)
Expand All @@ -929,6 +958,22 @@ def validate(self):
if not self.elements:
raise InvalidLiteral("Cannot have an empty tuple", self)

def fold(self) -> Optional[ExprNode]:
elements = [e.get_folded_value_throwing() for e in self.elements]
return type(self).from_node(self, elements=elements)

def get_folded_value_throwing(self) -> "VyperNode":
if check_literal(self):
return self

return super().get_folded_value_throwing()

def get_folded_value_maybe(self) -> Optional["VyperNode"]:
if check_literal(self):
return self

return super().get_folded_value_maybe()


class NameConstant(Constant):
__slots__ = ()
Expand Down
3 changes: 0 additions & 3 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,9 +720,6 @@ def parse_List(self):
if len(self.expr.elements) == 0:
return IRnode.from_list("~empty", typ=typ)

for e in self.expr.elements:
if "type" not in e._metadata:
e._metadata["type"] = typ.subtype
multi_ir = [Expr(x, self.context).ir_node for x in self.expr.elements]

return IRnode.from_list(["multi"] + multi_ir, typ=typ)
Expand Down
16 changes: 1 addition & 15 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,25 +624,11 @@ def validate_unique_method_ids(functions: List) -> None:
seen.add(method_id)


def _check_literal(node: vy_ast.VyperNode) -> bool:
"""
Check if the given node is a literal value.
"""
if isinstance(node, vy_ast.Constant):
return True
elif isinstance(node, (vy_ast.Tuple, vy_ast.List)):
return all(_check_literal(item) for item in node.elements)

if node.get_folded_value_maybe():
return True
return False


def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool:
"""
Check if the given node is not more modifiable than the given modifiability.
"""
if _check_literal(node):
if node.get_folded_value_maybe():
return True

if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)):
Expand Down

0 comments on commit 17f01fa

Please sign in to comment.