diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index e4e9f37866..25da0714ee 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -777,7 +777,12 @@ class Constant(ExprNode): def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): super().__init__(parent, **kwargs) - self._metadata["folded_value"] = self + + def get_folded_value_throwing(self) -> "VyperNode": + return self + + def get_folded_value_maybe(self) -> Optional["VyperNode"]: + return self class Num(Constant): @@ -911,6 +916,18 @@ def s(self): return self.value +def check_literal(node: VyperNode) -> bool: + """ + Check if the given node is a literal value. + """ + if isinstance(node, Constant): + return True + elif isinstance(node, (Tuple, List)): + return all(check_literal(item) for item in node.elements) + + return False + + class List(ExprNode): __slots__ = ("elements",) _is_prefoldable = True @@ -920,15 +937,44 @@ def fold(self) -> Optional[ExprNode]: elements = [e.get_folded_value_throwing() for e in self.elements] return type(self).from_node(self, elements=elements) + def get_folded_value_throwing(self) -> "VyperNode": + if check_literal(self): + return self + + return super().get_folded_value_throwing() + + def get_folded_value_maybe(self) -> Optional["VyperNode"]: + if check_literal(self): + return self + + return super().get_folded_value_maybe() + class Tuple(ExprNode): __slots__ = ("elements",) + _is_prefoldable = True _translated_fields = {"elts": "elements"} def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) + def fold(self) -> Optional[ExprNode]: + elements = [e.get_folded_value_throwing() for e in self.elements] + return type(self).from_node(self, elements=elements) + + def get_folded_value_throwing(self) -> "VyperNode": + if check_literal(self): + return self + + return super().get_folded_value_throwing() + + def get_folded_value_maybe(self) -> Optional["VyperNode"]: + if check_literal(self): + return self + + return super().get_folded_value_maybe() + class NameConstant(Constant): __slots__ = () diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 7c8d396f3e..f955296ee0 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -34,7 +34,7 @@ def process_arg(arg, expected_arg_type, context): def process_kwarg(kwarg_node, kwarg_settings, expected_kwarg_type, context): if kwarg_settings.require_literal: - return kwarg_node.value + return kwarg_node.get_folded_value_throwing().value return process_arg(kwarg_node, expected_kwarg_type, context) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 1b5934d41f..27266577a0 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -70,6 +70,9 @@ class Expr: # TODO: Once other refactors are made reevaluate all inline imports def __init__(self, node, context): + if isinstance(node, vy_ast.VyperNode): + node = node._metadata.get("folded_value", node) + self.expr = node self.context = context @@ -185,6 +188,14 @@ def parse_Name(self): # TODO: use self.expr._expr_info elif self.expr.id in self.context.globals: varinfo = self.context.globals[self.expr.id] + if varinfo.modifiability == Modifiability.ALWAYS_CONSTANT: + # non-struct constants should have been dispatched via the `Expr` ctor + # using the folded value metadata + assert isinstance(varinfo.typ, StructT) + value_node = varinfo.decl_node.value + value_node = value_node._metadata.get("folded_value", value_node) + return Expr.parse_value_expr(value_node, self.context) + assert varinfo.modifiability == Modifiability.IMMUTABLE, "not an immutable!" ofst = varinfo.position.offset diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index f46aa6b2e5..65276469e7 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -1,4 +1,3 @@ -from vyper import ast as vy_ast from vyper.codegen.abi_encoder import abi_encoding_matches_vyper from vyper.codegen.context import Context, VariableRecord from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp @@ -51,7 +50,7 @@ def _register_function_args(func_t: ContractFunctionT, context: Context) -> list def _generate_kwarg_handlers( - func_t: ContractFunctionT, context: Context, code: vy_ast.FunctionDef + func_t: ContractFunctionT, context: Context ) -> dict[str, tuple[int, IRnode]]: # generate kwarg handlers. # since they might come in thru calldata or be default, @@ -63,7 +62,7 @@ def _generate_kwarg_handlers( # write default args to memory # goto external_function_common_ir - def handler_for(calldata_kwargs, original_default_kwargs, folded_default_kwargs): + def handler_for(calldata_kwargs, default_kwargs): calldata_args = func_t.positional_args + calldata_kwargs # create a fake type so that get_element_ptr works calldata_args_t = TupleT(list(arg.typ for arg in calldata_args)) @@ -82,7 +81,7 @@ def handler_for(calldata_kwargs, original_default_kwargs, folded_default_kwargs) calldata_min_size = args_abi_t.min_size() + 4 # TODO optimize make_setter by using - # TupleT(list(arg.typ for arg in calldata_kwargs + folded_default_kwargs)) + # TupleT(list(arg.typ for arg in calldata_kwargs + default_kwargs)) # (must ensure memory area is contiguous) for i, arg_meta in enumerate(calldata_kwargs): @@ -98,15 +97,15 @@ def handler_for(calldata_kwargs, original_default_kwargs, folded_default_kwargs) copy_arg.source_pos = getpos(arg_meta.ast_source) ret.append(copy_arg) - for x, y in zip(original_default_kwargs, folded_default_kwargs): + for x in default_kwargs: dst = context.lookup_var(x.name).pos lhs = IRnode(dst, location=MEMORY, typ=x.typ) - lhs.source_pos = getpos(y) - kw_ast_val = y + lhs.source_pos = getpos(x.ast_source) + kw_ast_val = func_t.default_values[x.name] # e.g. `3` in x: int = 3 rhs = Expr(kw_ast_val, context).ir_node copy_arg = make_setter(lhs, rhs) - copy_arg.source_pos = getpos(y) + copy_arg.source_pos = getpos(x.ast_source) ret.append(copy_arg) ret.append(["goto", func_t._ir_info.external_function_base_entry_label]) @@ -117,7 +116,6 @@ def handler_for(calldata_kwargs, original_default_kwargs, folded_default_kwargs) ret = {} keyword_args = func_t.keyword_args - folded_keyword_args = code.args.defaults # allocate variable slots in memory for arg in keyword_args: @@ -125,17 +123,12 @@ def handler_for(calldata_kwargs, original_default_kwargs, folded_default_kwargs) for i, _ in enumerate(keyword_args): calldata_kwargs = keyword_args[:i] - # folded ast - original_default_kwargs = keyword_args[i:] - # unfolded ast - folded_default_kwargs = folded_keyword_args[i:] + default_kwargs = keyword_args[i:] - sig, calldata_min_size, ir_node = handler_for( - calldata_kwargs, original_default_kwargs, folded_default_kwargs - ) + sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs) ret[sig] = calldata_min_size, ir_node - sig, calldata_min_size, ir_node = handler_for(keyword_args, [], []) + sig, calldata_min_size, ir_node = handler_for(keyword_args, []) ret[sig] = calldata_min_size, ir_node @@ -160,7 +153,7 @@ def generate_ir_for_external_function(code, func_t, context): handle_base_args = _register_function_args(func_t, context) # generate handlers for kwargs and register the variable records - kwarg_handlers = _generate_kwarg_handlers(func_t, context, code) + kwarg_handlers = _generate_kwarg_handlers(func_t, context) body = ["seq"] # once optional args have been handled, diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 4982e84b68..7407c4f281 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -280,7 +280,7 @@ def generate_folded_ast( symbol_tables = set_data_positions(vyper_module, storage_layout_overrides) vyper_module_folded = copy.deepcopy(vyper_module) - vy_ast.folding.fold(vyper_module_folded) + # vy_ast.folding.fold(vyper_module_folded) return vyper_module_folded, symbol_tables diff --git a/vyper/exceptions.py b/vyper/exceptions.py index f216069eab..f625a7d3fb 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -108,7 +108,7 @@ def __str__(self): if isinstance(node, vy_ast.VyperNode): module_node = node.get_ancestor(vy_ast.Module) - if module_node.get("path") not in (None, ""): + if module_node and module_node.get("path") not in (None, ""): node_msg = f'{node_msg}contract "{module_node.path}:{node.lineno}", ' fn_node = node.get_ancestor(vy_ast.FunctionDef) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 0163547d55..417e9e7018 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -545,13 +545,15 @@ def visit(self, node, typ): # can happen. super().visit(node, typ) - folded_value = node.get_folded_value_maybe() - if isinstance(folded_value, vy_ast.Constant): - validate_expected_type(folded_value, typ) - # annotate node._metadata["type"] = typ + # validate and annotate folded value + folded_value = node._metadata.get("folded_value") + if folded_value: + validate_expected_type(folded_value, typ) + folded_value._metadata["type"] = typ + def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index ffbd2265db..9d828eaa2d 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -495,7 +495,7 @@ def _parse_and_fold_ast(file: FileInput) -> vy_ast.VyperNode: resolved_path=str(file.resolved_path), ) vy_ast.validation.validate_literal_nodes(ret) - vy_ast.folding.fold(ret) + # vy_ast.folding.fold(ret) return ret diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index 223407d839..b89c1c6759 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -1,5 +1,5 @@ from vyper import ast as vy_ast -from vyper.exceptions import UnfoldableNode, VyperException +from vyper.exceptions import UnfoldableNode def get_constants(node: vy_ast.Module) -> dict: @@ -66,7 +66,7 @@ def prefold(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]): try: node._metadata["folded_value"] = call_type.fold(node) return - except (UnfoldableNode, VyperException): + except UnfoldableNode: pass if getattr(node, "_is_prefoldable", None): diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 3f91d5f258..dcf81b4d6e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -624,25 +624,11 @@ def validate_unique_method_ids(functions: List) -> None: seen.add(method_id) -def _check_literal(node: vy_ast.VyperNode) -> bool: - """ - Check if the given node is a literal value. - """ - if isinstance(node, vy_ast.Constant): - return True - elif isinstance(node, (vy_ast.Tuple, vy_ast.List)): - return all(_check_literal(item) for item in node.elements) - - if node.get_folded_value_maybe(): - return True - return False - - def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool: """ Check if the given node is not more modifiable than the given modifiability. """ - if _check_literal(node): + if node.get_folded_value_maybe(): return True if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)):