Skip to content

Commit

Permalink
add immutable enum to modifiability
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Dec 22, 2023
1 parent ce3f61f commit 83a8b66
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 25 deletions.
4 changes: 2 additions & 2 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from vyper.codegen.expr import Expr
from vyper.codegen.ir_node import IRnode
from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode, VyperException
from vyper.semantics.analysis.base import VariableConstancy
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.analysis.utils import (
check_variable_constancy,
get_exact_type_from_node,
Expand Down Expand Up @@ -112,7 +112,7 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None:
for kwarg in node.keywords:
kwarg_settings = self._kwargs[kwarg.arg]
if kwarg_settings.require_literal and not check_variable_constancy(
kwarg.value, VariableConstancy.RUNTIME_CONSTANT
kwarg.value, Modifiability.IMMUTABLE
):
raise TypeMismatch("Value must be literal or environment variable", kwarg.value)
self._validate_single(kwarg.value, kwarg_settings.typ)
Expand Down
16 changes: 9 additions & 7 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,12 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability":
# specifying a state mutability modifier at all. Do the same here.


class VariableConstancy(enum.IntEnum):
MUTABLE = enum.auto()
RUNTIME_CONSTANT = enum.auto()
COMPILE_TIME_CONSTANT = enum.auto()
class Modifiability(enum.IntEnum):
MODIFIABLE = enum.auto()
IMMUTABLE = enum.auto()
NOT_MODIFIABLE = enum.auto()
CONSTANT_IN_CURRENT_TX = enum.auto()
ALWAYS_CONSTANT = enum.auto()


class DataPosition:
Expand Down Expand Up @@ -198,7 +200,7 @@ class VarInfo:

typ: VyperType
location: DataLocation = DataLocation.UNSET
constancy: VariableConstancy = VariableConstancy.MUTABLE
constancy: Modifiability = Modifiability.MODIFIABLE
is_public: bool = False
is_immutable: bool = False
is_transient: bool = False
Expand Down Expand Up @@ -231,7 +233,7 @@ class ExprInfo:
typ: VyperType
var_info: Optional[VarInfo] = None
location: DataLocation = DataLocation.UNSET
constancy: VariableConstancy = VariableConstancy.MUTABLE
constancy: Modifiability = Modifiability.MODIFIABLE
is_immutable: bool = False

def __post_init__(self):
Expand Down Expand Up @@ -283,7 +285,7 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil

if self.location == DataLocation.CALLDATA:
raise ImmutableViolation("Cannot write to calldata", node)
if self.constancy == VariableConstancy.COMPILE_TIME_CONSTANT:
if self.constancy == Modifiability.ALWAYS_CONSTANT:
raise ImmutableViolation("Constant value cannot be written to", node)
if self.is_immutable:
if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__":
Expand Down
8 changes: 4 additions & 4 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
VariableDeclarationException,
VyperException,
)
from vyper.semantics.analysis.base import VariableConstancy, VarInfo
from vyper.semantics.analysis.base import Modifiability, VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.utils import (
get_common_types,
Expand Down Expand Up @@ -192,9 +192,9 @@ def __init__(
def analyze(self):
# allow internal function params to be mutable
location, is_immutable, constancy = (
(DataLocation.MEMORY, False, VariableConstancy.MUTABLE)
(DataLocation.MEMORY, False, Modifiability.MODIFIABLE)
if self.func.is_internal
else (DataLocation.CALLDATA, True, VariableConstancy.RUNTIME_CONSTANT)
else (DataLocation.CALLDATA, True, Modifiability.NOT_MODIFIABLE)
)
for arg in self.func.arguments:
self.namespace[arg.name] = VarInfo(
Expand Down Expand Up @@ -494,7 +494,7 @@ def visit_For(self, node):

with self.namespace.enter_scope():
self.namespace[iter_name] = VarInfo(
possible_target_type, constancy=VariableConstancy.COMPILE_TIME_CONSTANT
possible_target_type, constancy=Modifiability.ALWAYS_CONSTANT
)

try:
Expand Down
10 changes: 5 additions & 5 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
VariableDeclarationException,
VyperException,
)
from vyper.semantics.analysis.base import ImportInfo, ModuleInfo, VariableConstancy, VarInfo
from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.import_graph import ImportGraph
from vyper.semantics.analysis.local import ExprVisitor, validate_functions
Expand Down Expand Up @@ -263,11 +263,11 @@ def visit_VariableDecl(self, node):
)

constancy = (
VariableConstancy.RUNTIME_CONSTANT
Modifiability.IMMUTABLE
if node.is_immutable
else VariableConstancy.COMPILE_TIME_CONSTANT
else Modifiability.ALWAYS_CONSTANT
if node.is_constant
else VariableConstancy.MUTABLE
else Modifiability.MODIFIABLE
)

type_ = type_from_annotation(node.annotation, data_loc)
Expand Down Expand Up @@ -317,7 +317,7 @@ def _validate_self_namespace():

ExprVisitor().visit(node.value, type_)

if not check_variable_constancy(node.value, VariableConstancy.COMPILE_TIME_CONSTANT):
if not check_variable_constancy(node.value, Modifiability.ALWAYS_CONSTANT):
raise StateAccessViolation("Value must be a literal", node.value)

validate_expected_type(node.value, type_)
Expand Down
6 changes: 3 additions & 3 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ZeroDivisionException,
)
from vyper.semantics import types
from vyper.semantics.analysis.base import ExprInfo, ModuleInfo, VariableConstancy, VarInfo
from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
Expand Down Expand Up @@ -201,7 +201,7 @@ def _raise_invalid_reference(name, node):
if isinstance(s, (VyperType, TYPE_T)):
# ex. foo.bar(). bar() is a ContractFunctionT
return [s]
if is_self_reference and s.constancy >= VariableConstancy.RUNTIME_CONSTANT:
if is_self_reference and s.constancy >= Modifiability.IMMUTABLE:
_raise_invalid_reference(name, node)
# general case. s is a VarInfo, e.g. self.foo
return [s.typ]
Expand Down Expand Up @@ -639,7 +639,7 @@ def _check_literal(node: vy_ast.VyperNode) -> bool:
return False


def check_variable_constancy(node: vy_ast.VyperNode, constancy: VariableConstancy) -> bool:
def check_variable_constancy(node: vy_ast.VyperNode, constancy: Modifiability) -> bool:
"""
Check if the given node is a literal or constant value.
"""
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/environment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict

from vyper.semantics.analysis.base import VariableConstancy, VarInfo
from vyper.semantics.analysis.base import Modifiability, VarInfo
from vyper.semantics.types import AddressT, BytesT, VyperType
from vyper.semantics.types.shortcuts import BYTES32_T, UINT256_T

Expand Down Expand Up @@ -52,7 +52,7 @@ def get_constant_vars() -> Dict:
"""
result = {}
for k, v in CONSTANT_ENVIRONMENT_VARS.items():
result[k] = VarInfo(v, constancy=VariableConstancy.RUNTIME_CONSTANT)
result[k] = VarInfo(v, constancy=Modifiability.CONSTANT_IN_CURRENT_TX)

return result

Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
)
from vyper.semantics.analysis.base import (
FunctionVisibility,
Modifiability,
StateMutability,
StorageSlot,
VariableConstancy,
)
from vyper.semantics.analysis.utils import (
check_variable_constancy,
Expand Down Expand Up @@ -702,7 +702,7 @@ def _parse_args(
positional_args.append(PositionalArg(argname, type_, ast_source=arg))
else:
value = funcdef.args.defaults[i - n_positional_args]
if not check_variable_constancy(value, VariableConstancy.RUNTIME_CONSTANT):
if not check_variable_constancy(value, Modifiability.IMMUTABLE):
raise StateAccessViolation("Value must be literal or environment variable", value)
validate_expected_type(value, type_)
keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg))
Expand Down

0 comments on commit 83a8b66

Please sign in to comment.