Skip to content

Commit

Permalink
remove is_immutable; rename constancy to modifiability
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Dec 25, 2023
1 parent 4f4576c commit 6daf69f
Show file tree
Hide file tree
Showing 10 changed files with 38 additions and 42 deletions.
4 changes: 2 additions & 2 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.analysis.utils import (
check_variable_constancy,
check_modifiability,
get_exact_type_from_node,
validate_expected_type,
)
Expand Down Expand Up @@ -111,7 +111,7 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None:

for kwarg in node.keywords:
kwarg_settings = self._kwargs[kwarg.arg]
if kwarg_settings.require_literal and not check_variable_constancy(
if kwarg_settings.require_literal and not check_modifiability(
kwarg.value, Modifiability.IMMUTABLE
):
raise TypeMismatch("Value must be literal or environment variable", kwarg.value)
Expand Down
2 changes: 1 addition & 1 deletion vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class FoldedFunctionT(BuiltinFunctionT):
# Base class for nodes which should always be folded

# Since foldable builtin functions are not folded before semantics validation,
# this flag is used for `check_variable_constancy` in semantics validation.
# this flag is used for `check_modifiability` in semantics validation.
_kwargable = True


Expand Down
3 changes: 2 additions & 1 deletion vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
VyperException,
tag_exceptions,
)
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.types import (
AddressT,
BoolT,
Expand Down Expand Up @@ -186,7 +187,7 @@ def parse_Name(self):
# TODO: use self.expr._expr_info
elif self.expr.id in self.context.globals:
varinfo = self.context.globals[self.expr.id]
assert varinfo.is_immutable, "not an immutable!"
assert varinfo.modifiability == Modifiability.IMMUTABLE, "not an immutable!"

ofst = varinfo.position.offset

Expand Down
19 changes: 8 additions & 11 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class ImportInfo(AnalysisResult):
class VarInfo:
"""
VarInfo are objects that represent the type of a variable,
plus associated metadata like location and constancy attributes
plus associated metadata like location and modifiability attributes
Object Attributes
-----------------
Expand All @@ -200,9 +200,8 @@ class VarInfo:

typ: VyperType
location: DataLocation = DataLocation.UNSET
constancy: Modifiability = Modifiability.MODIFIABLE
modifiability: Modifiability = Modifiability.MODIFIABLE
is_public: bool = False
is_immutable: bool = False
is_transient: bool = False
is_local_var: bool = False
decl_node: Optional[vy_ast.VyperNode] = None
Expand Down Expand Up @@ -233,11 +232,10 @@ class ExprInfo:
typ: VyperType
var_info: Optional[VarInfo] = None
location: DataLocation = DataLocation.UNSET
constancy: Modifiability = Modifiability.MODIFIABLE
is_immutable: bool = False
modifiability: Modifiability = Modifiability.MODIFIABLE

def __post_init__(self):
should_match = ("typ", "location", "constancy", "is_immutable")
should_match = ("typ", "location", "modifiability")
if self.var_info is not None:
for attr in should_match:
if getattr(self.var_info, attr) != getattr(self, attr):
Expand All @@ -249,8 +247,7 @@ def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo":
var_info.typ,
var_info=var_info,
location=var_info.location,
constancy=var_info.constancy,
is_immutable=var_info.is_immutable,
modifiability=var_info.modifiability,
)

@classmethod
Expand All @@ -261,7 +258,7 @@ def copy_with_type(self, typ: VyperType) -> "ExprInfo":
"""
Return a copy of the ExprInfo but with the type set to something else
"""
to_copy = ("location", "constancy", "is_immutable")
to_copy = ("location", "modifiability")
fields = {k: getattr(self, k) for k in to_copy}
return self.__class__(typ=typ, **fields)

Expand All @@ -285,9 +282,9 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil

if self.location == DataLocation.CALLDATA:
raise ImmutableViolation("Cannot write to calldata", node)
if self.constancy == Modifiability.ALWAYS_CONSTANT:
if self.modifiability == Modifiability.ALWAYS_CONSTANT:
raise ImmutableViolation("Constant value cannot be written to", node)
if self.is_immutable:
if self.modifiability == Modifiability.IMMUTABLE:
if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__":
raise ImmutableViolation("Immutable value cannot be written to", node)
# TODO: we probably want to remove this restriction.
Expand Down
10 changes: 5 additions & 5 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,14 @@ def __init__(

def analyze(self):
# allow internal function params to be mutable
location, is_immutable, constancy = (
(DataLocation.MEMORY, False, Modifiability.MODIFIABLE)
location, modifiability = (
(DataLocation.MEMORY, Modifiability.MODIFIABLE)
if self.func.is_internal
else (DataLocation.CALLDATA, True, Modifiability.NOT_MODIFIABLE)
else (DataLocation.CALLDATA, Modifiability.NOT_MODIFIABLE)
)
for arg in self.func.arguments:
self.namespace[arg.name] = VarInfo(
arg.typ, location=location, is_immutable=is_immutable, constancy=constancy
arg.typ, location=location, modifiability=modifiability
)

for node in self.fn_node.body:
Expand Down Expand Up @@ -425,7 +425,7 @@ def visit_For(self, node):

with self.namespace.enter_scope():
self.namespace[iter_name] = VarInfo(
possible_target_type, constancy=Modifiability.ALWAYS_CONSTANT
possible_target_type, modifiability=Modifiability.ALWAYS_CONSTANT
)

try:
Expand Down
9 changes: 4 additions & 5 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from vyper.semantics.analysis.local import ExprVisitor, validate_functions
from vyper.semantics.analysis.pre_typecheck import pre_typecheck
from vyper.semantics.analysis.utils import (
check_variable_constancy,
check_modifiability,
get_exact_type_from_node,
validate_expected_type,
)
Expand Down Expand Up @@ -262,7 +262,7 @@ def visit_VariableDecl(self, node):
else DataLocation.STORAGE
)

constancy = (
modifiability = (
Modifiability.IMMUTABLE
if node.is_immutable
else Modifiability.ALWAYS_CONSTANT
Expand All @@ -279,9 +279,8 @@ def visit_VariableDecl(self, node):
type_,
decl_node=node,
location=data_loc,
constancy=constancy,
modifiability=modifiability,
is_public=node.is_public,
is_immutable=node.is_immutable,
is_transient=node.is_transient,
)
node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace
Expand Down Expand Up @@ -317,7 +316,7 @@ def _validate_self_namespace():

ExprVisitor().visit(node.value, type_)

if not check_variable_constancy(node.value, Modifiability.ALWAYS_CONSTANT):
if not check_modifiability(node.value, Modifiability.ALWAYS_CONSTANT):
raise StateAccessViolation("Value must be a literal", node.value)

validate_expected_type(node.value, type_)
Expand Down
23 changes: 11 additions & 12 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,9 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo:
# kludge! for validate_modification in local analysis of Assign
types = [self.get_expr_info(n) for n in node.elements]
location = sorted((i.location for i in types), key=lambda k: k.value)[-1]
constancy = sorted((i.constancy for i in types), key=lambda k: k.value)[-1]
is_immutable = any((getattr(i, "is_immutable", False) for i in types))
modifiability = sorted((i.modifiability for i in types), key=lambda k: k.value)[-1]

return ExprInfo(t, location=location, constancy=constancy, is_immutable=is_immutable)
return ExprInfo(t, location=location, modifiability=modifiability)

# If it's a Subscript, propagate the subscriptable varinfo
if isinstance(node, vy_ast.Subscript):
Expand Down Expand Up @@ -201,7 +200,7 @@ def _raise_invalid_reference(name, node):
if isinstance(s, (VyperType, TYPE_T)):
# ex. foo.bar(). bar() is a ContractFunctionT
return [s]
if is_self_reference and s.constancy >= Modifiability.IMMUTABLE:
if is_self_reference and s.modifiability >= Modifiability.IMMUTABLE:
_raise_invalid_reference(name, node)
# general case. s is a VarInfo, e.g. self.foo
return [s.typ]
Expand Down Expand Up @@ -639,33 +638,33 @@ def _check_literal(node: vy_ast.VyperNode) -> bool:
return False


def check_variable_constancy(node: vy_ast.VyperNode, constancy: Modifiability) -> bool:
def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> bool:
"""
Check if the given node is a literal or constant value.
Check if the given node is not more modifiable than the given modifiability.
"""
if _check_literal(node):
return True

if isinstance(node, (vy_ast.BinOp, vy_ast.Compare)):
return all(check_variable_constancy(i, constancy) for i in (node.left, node.right))
return all(check_modifiability(i, modifiability) for i in (node.left, node.right))

if isinstance(node, vy_ast.BoolOp):
return all(check_variable_constancy(i, constancy) for i in node.values)
return all(check_modifiability(i, modifiability) for i in node.values)

if isinstance(node, vy_ast.UnaryOp):
return check_variable_constancy(node.operand, constancy)
return check_modifiability(node.operand, modifiability)

if isinstance(node, (vy_ast.Tuple, vy_ast.List)):
return all(check_variable_constancy(item, constancy) for item in node.elements)
return all(check_modifiability(item, modifiability) for item in node.elements)

if isinstance(node, vy_ast.Call):
args = node.args
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
return all(check_variable_constancy(v, constancy) for v in args[0].values)
return all(check_modifiability(v, modifiability) for v in args[0].values)

call_type = get_exact_type_from_node(node.func)
if getattr(call_type, "_kwargable", False):
return True

value_type = get_expr_info(node)
return value_type.constancy >= constancy
return value_type.modifiability >= modifiability
2 changes: 1 addition & 1 deletion vyper/semantics/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_constant_vars() -> Dict:
"""
result = {}
for k, v in CONSTANT_ENVIRONMENT_VARS.items():
result[k] = VarInfo(v, constancy=Modifiability.CONSTANT_IN_CURRENT_TX)
result[k] = VarInfo(v, modifiability=Modifiability.CONSTANT_IN_CURRENT_TX)

return result

Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
StorageSlot,
)
from vyper.semantics.analysis.utils import (
check_variable_constancy,
check_modifiability,
get_exact_type_from_node,
validate_expected_type,
)
Expand Down Expand Up @@ -703,7 +703,7 @@ def _parse_args(
positional_args.append(PositionalArg(argname, type_, ast_source=arg))
else:
value = funcdef.args.defaults[i - n_positional_args]
if not check_variable_constancy(value, Modifiability.IMMUTABLE):
if not check_modifiability(value, Modifiability.IMMUTABLE):
raise StateAccessViolation("Value must be literal or environment variable", value)
validate_expected_type(value, type_)
keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg))
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vyper.abi_types import ABI_Address, ABIType
from vyper.ast.validation import validate_call_args
from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.base import Modifiability, VarInfo
from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
Expand Down Expand Up @@ -324,7 +324,7 @@ def variables(self):

@cached_property
def immutables(self):
return [t for t in self.variables.values() if t.is_immutable]
return [t for t in self.variables.values() if t.modifiability == Modifiability.IMMUTABLE]

@cached_property
def immutable_section_bytes(self):
Expand Down

0 comments on commit 6daf69f

Please sign in to comment.