Skip to content

Commit

Permalink
combine check_kwargable and check_constant
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Nov 16, 2023
1 parent 8785e02 commit 926fef6
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 44 deletions.
7 changes: 4 additions & 3 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from vyper.codegen.expr import Expr
from vyper.codegen.ir_node import IRnode
from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode, VyperException
from vyper.semantics.analysis.base import VariableConstancy
from vyper.semantics.analysis.utils import (
check_kwargable,
check_variable_constancy,
get_exact_type_from_node,
validate_expected_type,
)
Expand Down Expand Up @@ -107,8 +108,8 @@ def _validate_arg_types(self, node):

for kwarg in node.keywords:
kwarg_settings = self._kwargs[kwarg.arg]
if kwarg_settings.require_literal and not check_kwargable(kwarg.value):
raise TypeMismatch("Value for kwarg must be a literal", kwarg.value)
if kwarg_settings.require_literal and not check_variable_constancy(kwarg.value, VariableConstancy.RUNTIME_CONSTANT):
raise TypeMismatch("Value must be literal or environment variable", kwarg.value)
self._validate_single(kwarg.value, kwarg_settings.typ)

# typecheck varargs. we don't have type info from the signature,
Expand Down
2 changes: 1 addition & 1 deletion vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class FoldedFunction(BuiltinFunction):
# Base class for nodes which should always be folded

# Since foldable builtin functions are not folded before semantics validation,
# this flag is used for `check_kwargable` in semantics validation.
# this flag is used for `check_variable_constancy` in semantics validation.
_kwargable = True
# Skip annotation of builtins if it will be folded before codegen
_always_folded_before_codegen = True
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from vyper.semantics.analysis.base import VariableConstancy, VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.local import ExprVisitor
from vyper.semantics.analysis.utils import check_constant, validate_expected_type
from vyper.semantics.analysis.utils import check_variable_constancy, validate_expected_type
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import Namespace, get_namespace
from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT
Expand Down Expand Up @@ -242,7 +242,7 @@ def _validate_self_namespace():
if node.is_constant:
if not node.value:
raise VariableDeclarationException("Constant must be declared with a value", node)
if not check_constant(node.value):
if not check_variable_constancy(node.value, VariableConstancy.COMPILE_TIME_CONSTANT):
raise StateAccessViolation("Value must be a literal", node.value)

validate_expected_type(node.value, type_)
Expand Down
41 changes: 6 additions & 35 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,35 +616,6 @@ def validate_unique_method_ids(functions: List) -> None:
seen.add(method_id)


def check_kwargable(node: vy_ast.VyperNode) -> bool:
"""
Check if the given node can be used as a default arg
"""
if _check_literal(node):
return True

if isinstance(node, vy_ast.BinOp):
return all(check_kwargable(i) for i in (node.left, node.right))

if isinstance(node, vy_ast.BoolOp):
return all(check_kwargable(i) for i in node.values)

if isinstance(node, (vy_ast.Tuple, vy_ast.List)):
return all(check_kwargable(item) for item in node.elements)

if isinstance(node, vy_ast.Call):
args = node.args
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
return all(check_kwargable(v) for v in args[0].values)

call_type = get_exact_type_from_node(node.func)
if getattr(call_type, "_kwargable", False):
return True

value_type = get_expr_info(node)
return value_type.constancy >= VariableConstancy.RUNTIME_CONSTANT


def _check_literal(node: vy_ast.VyperNode) -> bool:
"""
Check if the given node is a literal value.
Expand All @@ -659,30 +630,30 @@ def _check_literal(node: vy_ast.VyperNode) -> bool:
return False


def check_constant(node: vy_ast.VyperNode) -> bool:
def check_variable_constancy(node: vy_ast.VyperNode, constancy: VariableConstancy) -> bool:
"""
Check if the given node is a literal or constant value.
"""
if _check_literal(node):
return True

if isinstance(node, vy_ast.BinOp):
return all(check_constant(i) for i in (node.left, node.right))
return all(check_variable_constancy(i, constancy) for i in (node.left, node.right))

if isinstance(node, vy_ast.BoolOp):
return all(check_constant(i) for i in node.values)
return all(check_variable_constancy(i, constancy) for i in node.values)

if isinstance(node, (vy_ast.Tuple, vy_ast.List)):
return all(check_constant(item) for item in node.elements)
return all(check_variable_constancy(item, constancy) for item in node.elements)

if isinstance(node, vy_ast.Call):
args = node.args
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
return all(check_constant(v) for v in args[0].values)
return all(check_variable_constancy(v, constancy) for v in args[0].values)

call_type = get_exact_type_from_node(node.func)
if getattr(call_type, "_kwargable", False):
return True

value_type = get_expr_info(node)
return value_type.constancy == VariableConstancy.COMPILE_TIME_CONSTANT
return value_type.constancy >= constancy
6 changes: 3 additions & 3 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
StateAccessViolation,
StructureException,
)
from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot
from vyper.semantics.analysis.utils import check_kwargable, validate_expected_type
from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot, VariableConstancy
from vyper.semantics.analysis.utils import check_variable_constancy, validate_expected_type
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import KwargSettings, VyperType
from vyper.semantics.types.primitives import BoolT
Expand Down Expand Up @@ -320,7 +320,7 @@ def from_FunctionDef(
positional_args.append(PositionalArg(argname, type_, ast_source=arg))
else:
value = node.args.defaults[i - n_positional_args]
if not check_kwargable(value):
if not check_variable_constancy(value, VariableConstancy.RUNTIME_CONSTANT):
raise StateAccessViolation(
"Value must be literal or environment variable", value
)
Expand Down

0 comments on commit 926fef6

Please sign in to comment.