Skip to content

Commit

Permalink
add compile time and runtime constants attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Nov 12, 2023
1 parent 7f75749 commit 415e123
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 24 deletions.
1 change: 1 addition & 0 deletions vyper/ast/folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int:
continue
try:
new_node = func.evaluate(node) # type: ignore
new_node._metadata["type"] = node._metadata["type"]
except UnfoldableNode:
continue

Expand Down
5 changes: 3 additions & 2 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import functools
from typing import Dict


from vyper.ast.validation import validate_call_args
from vyper.codegen.expr import Expr
from vyper.codegen.ir_node import IRnode
from vyper.exceptions import CompilerPanic, TypeMismatch
from vyper.semantics.analysis.utils import (
check_constant,
check_kwargable,
get_exact_type_from_node,
validate_expected_type,
)
Expand Down Expand Up @@ -107,7 +108,7 @@ def _validate_arg_types(self, node):

for kwarg in node.keywords:
kwarg_settings = self._kwargs[kwarg.arg]
if kwarg_settings.require_literal and not check_constant(kwarg.value):
if kwarg_settings.require_literal and not check_kwargable(kwarg.value):
raise TypeMismatch("Value for kwarg must be a literal", kwarg.value)
self._validate_single(kwarg.value, kwarg_settings.typ)

Expand Down
17 changes: 10 additions & 7 deletions vyper/semantics/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ class VarInfo:

typ: VyperType
location: DataLocation = DataLocation.UNSET
is_constant: bool = False
is_compile_time_constant: bool = False
is_runtime_constant: bool = False
is_public: bool = False
is_immutable: bool = False
is_transient: bool = False
Expand Down Expand Up @@ -192,11 +193,12 @@ class ExprInfo:
typ: VyperType
var_info: Optional[VarInfo] = None
location: DataLocation = DataLocation.UNSET
is_constant: bool = False
is_compile_time_constant: bool = False
is_runtime_constant: bool = False
is_immutable: bool = False

def __post_init__(self):
should_match = ("typ", "location", "is_constant", "is_immutable")
should_match = ("typ", "location", "is_compile_time_constant", "is_runtime_constant", "is_immutable")
if self.var_info is not None:
for attr in should_match:
if getattr(self.var_info, attr) != getattr(self, attr):
Expand All @@ -208,15 +210,16 @@ def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo":
var_info.typ,
var_info=var_info,
location=var_info.location,
is_constant=var_info.is_constant,
is_immutable=var_info.is_immutable,
is_compile_time_constant=var_info.is_compile_time_constant,
is_runtime_constant=var_info.is_runtime_constant,
is_immutable=var_info.is_immutable
)

def copy_with_type(self, typ: VyperType) -> "ExprInfo":
"""
Return a copy of the ExprInfo but with the type set to something else
"""
to_copy = ("location", "is_constant", "is_immutable")
to_copy = ("location", "is_compile_time_constant", "is_runtime_constant", "is_immutable")
fields = {k: getattr(self, k) for k in to_copy}
return self.__class__(typ=typ, **fields)

Expand All @@ -240,7 +243,7 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil

if self.location == DataLocation.CALLDATA:
raise ImmutableViolation("Cannot write to calldata", node)
if self.is_constant:
if self.is_compile_time_constant:
raise ImmutableViolation("Constant value cannot be written to", node)
if self.is_immutable:
if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__":
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
(DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True)
)
for arg in self.func.arguments:
namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable)
namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable, is_runtime_constant=is_immutable)

for node in fn_node.body:
self.visit(node)
Expand Down Expand Up @@ -473,7 +473,7 @@ def visit_For(self, node):

with self.namespace.enter_scope():
try:
self.namespace[iter_name] = VarInfo(possible_target_type, is_constant=True)
self.namespace[iter_name] = VarInfo(possible_target_type, is_compile_time_constant=True)
except VyperException as exc:
raise exc.with_annotation(node) from None

Expand Down
3 changes: 2 additions & 1 deletion vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,10 @@ def visit_VariableDecl(self, node):
type_,
decl_node=node,
location=data_loc,
is_constant=node.is_constant,
is_compile_time_constant=node.is_constant,
is_public=node.is_public,
is_immutable=node.is_immutable,
is_runtime_constant=node.is_immutable,
is_transient=node.is_transient,
)
node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace
Expand Down
22 changes: 13 additions & 9 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def _raise_invalid_reference(name, node):
if isinstance(s, VyperType):
# ex. foo.bar(). bar() is a ContractFunctionT
return [s]
if is_self_reference and (s.is_constant or s.is_immutable):
if is_self_reference and (s.is_compile_time_constant or s.is_immutable):
_raise_invalid_reference(name, node)
# general case. s is a VarInfo, e.g. self.foo
return [s.typ]
Expand Down Expand Up @@ -622,10 +622,16 @@ def check_kwargable(node: vy_ast.VyperNode) -> bool:
"""
if _check_literal(node):
return True
if isinstance(node, vy_ast.Attribute):
return check_kwargable(node.value)

if isinstance(node, vy_ast.BinOp):
return all(check_kwargable(i) for i in (node.left, node.right))

if isinstance(node, vy_ast.BoolOp):
return all(check_kwargable(i) for i in node.values)

if isinstance(node, (vy_ast.Tuple, vy_ast.List)):
return all(check_kwargable(item) for item in node.elements)

if isinstance(node, vy_ast.Call):
args = node.args
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
Expand All @@ -636,8 +642,7 @@ def check_kwargable(node: vy_ast.VyperNode) -> bool:
return True

value_type = get_expr_info(node)
# is_constant here actually means not_assignable, and is to be renamed
return value_type.is_constant
return value_type.is_runtime_constant


def _check_literal(node: vy_ast.VyperNode) -> bool:
Expand All @@ -662,10 +667,10 @@ def check_constant(node: vy_ast.VyperNode) -> bool:
return True

if isinstance(node, vy_ast.BinOp):
return all(check_kwargable(i) for i in (node.left, node.right))
return all(check_constant(i) for i in (node.left, node.right))

if isinstance(node, vy_ast.BoolOp):
return all(check_kwargable(i) for i in node.values)
return all(check_constant(i) for i in node.values)

if isinstance(node, (vy_ast.Tuple, vy_ast.List)):
return all(check_constant(item) for item in node.elements)
Expand All @@ -680,5 +685,4 @@ def check_constant(node: vy_ast.VyperNode) -> bool:
return True

value_type = get_expr_info(node)
# is_constant here actually means not_assignable, and is to be renamed
return value_type.is_constant
return value_type.is_compile_time_constant
2 changes: 1 addition & 1 deletion vyper/semantics/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_constant_vars() -> Dict:
"""
result = {}
for k, v in CONSTANT_ENVIRONMENT_VARS.items():
result[k] = VarInfo(v, is_constant=True)
result[k] = VarInfo(v, is_runtime_constant=True)

return result

Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
StructureException,
)
from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot
from vyper.semantics.analysis.utils import check_constant, validate_expected_type
from vyper.semantics.analysis.utils import check_kwargable, validate_expected_type
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import KwargSettings, VyperType
from vyper.semantics.types.primitives import BoolT
Expand Down Expand Up @@ -320,7 +320,7 @@ def from_FunctionDef(
positional_args.append(PositionalArg(argname, type_, ast_source=arg))
else:
value = node.args.defaults[i - n_positional_args]
if not check_constant(value):
if not check_kwargable(value):
raise StateAccessViolation(
"Value must be literal or environment variable", value
)
Expand Down

0 comments on commit 415e123

Please sign in to comment.