diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index 1121903109..17ae0a9da7 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -116,6 +116,7 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: continue try: new_node = func.evaluate(node) # type: ignore + new_node._metadata["type"] = node._metadata["type"] except UnfoldableNode: continue diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 332094cc0f..82c43e705b 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -1,12 +1,13 @@ import functools from typing import Dict + from vyper.ast.validation import validate_call_args from vyper.codegen.expr import Expr from vyper.codegen.ir_node import IRnode from vyper.exceptions import CompilerPanic, TypeMismatch from vyper.semantics.analysis.utils import ( - check_constant, + check_kwargable, get_exact_type_from_node, validate_expected_type, ) @@ -107,7 +108,7 @@ def _validate_arg_types(self, node): for kwarg in node.keywords: kwarg_settings = self._kwargs[kwarg.arg] - if kwarg_settings.require_literal and not check_constant(kwarg.value): + if kwarg_settings.require_literal and not check_kwargable(kwarg.value): raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) self._validate_single(kwarg.value, kwarg_settings.typ) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 449e6ca338..041575ee63 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -159,7 +159,8 @@ class VarInfo: typ: VyperType location: DataLocation = DataLocation.UNSET - is_constant: bool = False + is_compile_time_constant: bool = False + is_runtime_constant: bool = False is_public: bool = False is_immutable: bool = False is_transient: bool = False @@ -192,11 +193,12 @@ class ExprInfo: typ: VyperType var_info: Optional[VarInfo] = None location: DataLocation = DataLocation.UNSET - is_constant: bool = False + is_compile_time_constant: bool = False + is_runtime_constant: bool = False is_immutable: bool = False def __post_init__(self): - should_match = ("typ", "location", "is_constant", "is_immutable") + should_match = ("typ", "location", "is_compile_time_constant", "is_runtime_constant", "is_immutable") if self.var_info is not None: for attr in should_match: if getattr(self.var_info, attr) != getattr(self, attr): @@ -208,15 +210,16 @@ def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": var_info.typ, var_info=var_info, location=var_info.location, - is_constant=var_info.is_constant, - is_immutable=var_info.is_immutable, + is_compile_time_constant=var_info.is_compile_time_constant, + is_runtime_constant=var_info.is_runtime_constant, + is_immutable=var_info.is_immutable ) def copy_with_type(self, typ: VyperType) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else """ - to_copy = ("location", "is_constant", "is_immutable") + to_copy = ("location", "is_compile_time_constant", "is_runtime_constant", "is_immutable") fields = {k: getattr(self, k) for k in to_copy} return self.__class__(typ=typ, **fields) @@ -240,7 +243,7 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil if self.location == DataLocation.CALLDATA: raise ImmutableViolation("Cannot write to calldata", node) - if self.is_constant: + if self.is_compile_time_constant: raise ImmutableViolation("Constant value cannot be written to", node) if self.is_immutable: if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 57de3cbe1a..53dbb40178 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -193,7 +193,7 @@ def __init__( (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True) ) for arg in self.func.arguments: - namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable) + namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable, is_runtime_constant=is_immutable) for node in fn_node.body: self.visit(node) @@ -473,7 +473,7 @@ def visit_For(self, node): with self.namespace.enter_scope(): try: - self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True) + self.namespace[iter_name] = VarInfo(possible_target_type, is_compile_time_constant=True) except VyperException as exc: raise exc.with_annotation(node) from None diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 2cd9841e16..ecfff33a46 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -195,9 +195,10 @@ def visit_VariableDecl(self, node): type_, decl_node=node, location=data_loc, - is_constant=node.is_constant, + is_compile_time_constant=node.is_constant, is_public=node.is_public, is_immutable=node.is_immutable, + is_runtime_constant=node.is_immutable, is_transient=node.is_transient, ) node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 7d71c6c948..237792ef03 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -195,7 +195,7 @@ def _raise_invalid_reference(name, node): if isinstance(s, VyperType): # ex. foo.bar(). bar() is a ContractFunctionT return [s] - if is_self_reference and (s.is_constant or s.is_immutable): + if is_self_reference and (s.is_compile_time_constant or s.is_immutable): _raise_invalid_reference(name, node) # general case. s is a VarInfo, e.g. self.foo return [s.typ] @@ -622,10 +622,16 @@ def check_kwargable(node: vy_ast.VyperNode) -> bool: """ if _check_literal(node): return True - if isinstance(node, vy_ast.Attribute): - return check_kwargable(node.value) + + if isinstance(node, vy_ast.BinOp): + return all(check_kwargable(i) for i in (node.left, node.right)) + + if isinstance(node, vy_ast.BoolOp): + return all(check_kwargable(i) for i in node.values) + if isinstance(node, (vy_ast.Tuple, vy_ast.List)): return all(check_kwargable(item) for item in node.elements) + if isinstance(node, vy_ast.Call): args = node.args if len(args) == 1 and isinstance(args[0], vy_ast.Dict): @@ -636,8 +642,7 @@ def check_kwargable(node: vy_ast.VyperNode) -> bool: return True value_type = get_expr_info(node) - # is_constant here actually means not_assignable, and is to be renamed - return value_type.is_constant + return value_type.is_runtime_constant def _check_literal(node: vy_ast.VyperNode) -> bool: @@ -662,10 +667,10 @@ def check_constant(node: vy_ast.VyperNode) -> bool: return True if isinstance(node, vy_ast.BinOp): - return all(check_kwargable(i) for i in (node.left, node.right)) + return all(check_constant(i) for i in (node.left, node.right)) if isinstance(node, vy_ast.BoolOp): - return all(check_kwargable(i) for i in node.values) + return all(check_constant(i) for i in node.values) if isinstance(node, (vy_ast.Tuple, vy_ast.List)): return all(check_constant(item) for item in node.elements) @@ -680,5 +685,4 @@ def check_constant(node: vy_ast.VyperNode) -> bool: return True value_type = get_expr_info(node) - # is_constant here actually means not_assignable, and is to be renamed - return value_type.is_constant + return value_type.is_compile_time_constant diff --git a/vyper/semantics/environment.py b/vyper/semantics/environment.py index ad68f1103e..0eeb803eb7 100644 --- a/vyper/semantics/environment.py +++ b/vyper/semantics/environment.py @@ -52,7 +52,7 @@ def get_constant_vars() -> Dict: """ result = {} for k, v in CONSTANT_ENVIRONMENT_VARS.items(): - result[k] = VarInfo(v, is_constant=True) + result[k] = VarInfo(v, is_runtime_constant=True) return result diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index afbba95658..ee43939320 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -17,7 +17,7 @@ StructureException, ) from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot -from vyper.semantics.analysis.utils import check_constant, validate_expected_type +from vyper.semantics.analysis.utils import check_kwargable, validate_expected_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType from vyper.semantics.types.primitives import BoolT @@ -320,7 +320,7 @@ def from_FunctionDef( positional_args.append(PositionalArg(argname, type_, ast_source=arg)) else: value = node.args.defaults[i - n_positional_args] - if not check_constant(value): + if not check_kwargable(value): raise StateAccessViolation( "Value must be literal or environment variable", value )