From 17f01fafcb0089acae2325569d73692a265c1cd8 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Thu, 28 Dec 2023 10:45:30 +0800 Subject: [PATCH] handle const, list and tuples --- vyper/ast/nodes.py | 47 ++++++++++++++++++++++++++++++- vyper/codegen/expr.py | 3 -- vyper/semantics/analysis/utils.py | 16 +---------- 3 files changed, 47 insertions(+), 19 deletions(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index e4e9f37866..cc4929286e 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -777,7 +777,12 @@ class Constant(ExprNode): def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): super().__init__(parent, **kwargs) - self._metadata["folded_value"] = self + + def get_folded_value_throwing(self) -> "VyperNode": + return self + + def get_folded_value_maybe(self) -> Optional["VyperNode"]: + return self class Num(Constant): @@ -911,6 +916,18 @@ def s(self): return self.value +def check_literal(node: VyperNode) -> bool: + """ + Check if the given node is a literal value. + """ + if isinstance(node, Constant): + return True + elif isinstance(node, (Tuple, List)): + return all(check_literal(item) for item in node.elements) + + return False + + class List(ExprNode): __slots__ = ("elements",) _is_prefoldable = True @@ -920,6 +937,18 @@ def fold(self) -> Optional[ExprNode]: elements = [e.get_folded_value_throwing() for e in self.elements] return type(self).from_node(self, elements=elements) + def get_folded_value_throwing(self) -> "VyperNode": + if check_literal(self): + return self + + return super().get_folded_value_throwing() + + def get_folded_value_maybe(self) -> Optional["VyperNode"]: + if check_literal(self): + return self + + return super().get_folded_value_maybe() + class Tuple(ExprNode): __slots__ = ("elements",) @@ -929,6 +958,22 @@ def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) + def fold(self) -> Optional[ExprNode]: + elements = [e.get_folded_value_throwing() for e in self.elements] + return type(self).from_node(self, elements=elements) + + def get_folded_value_throwing(self) -> "VyperNode": + if check_literal(self): + return self + + return super().get_folded_value_throwing() + + def get_folded_value_maybe(self) -> Optional["VyperNode"]: + if check_literal(self): + return self + + return super().get_folded_value_maybe() + class NameConstant(Constant): __slots__ = () diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 28de2a18be..25927f1ab0 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -720,9 +720,6 @@ def parse_List(self): if len(self.expr.elements) == 0: return IRnode.from_list("~empty", typ=typ) - for e in self.expr.elements: - if "type" not in e._metadata: - e._metadata["type"] = typ.subtype multi_ir = [Expr(x, self.context).ir_node for x in self.expr.elements] return IRnode.from_list(["multi"] + multi_ir, typ=typ) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 3f91d5f258..dcf81b4d6e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -624,25 +624,11 @@ def validate_unique_method_ids(functions: List) -> None: seen.add(method_id) -def _check_literal(node: vy_ast.VyperNode) -> bool: - """ - Check if the given node is a literal value. - """ - if isinstance(node, vy_ast.Constant): - return True - elif isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(_check_literal(item) for item in node.elements) - - if node.get_folded_value_maybe(): - return True - return False - - def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool: """ Check if the given node is not more modifiable than the given modifiability. """ - if _check_literal(node): + if node.get_folded_value_maybe(): return True if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)):