Skip to content

Commit

Permalink
move all check_modifiability checks to be after a validate_expected_t…
Browse files Browse the repository at this point in the history
…ype (which calls validate_call_args)
  • Loading branch information
charles-cooper committed Jan 15, 2024
1 parent 4924a85 commit bf7b346
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 23 deletions.
19 changes: 12 additions & 7 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ class BuiltinFunctionT(VyperType):
_is_terminus = False

# helper function to deal with TYPE_DEFINITIONs
def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None:
def _validate_single(
self, arg: vy_ast.VyperNode, expected_type: VyperType, modifiability: Modifiability
) -> None:
# TODO using "TYPE_DEFINITION" is a kludge in derived classes,
# refactor me.
if expected_type == "TYPE_DEFINITION":
Expand All @@ -97,6 +99,9 @@ def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> N
type_from_annotation(arg)
else:
validate_expected_type(arg, expected_type)
if not check_modifiability(arg, modifiability):
# CMC 2024-01-15 TODO: change to StateAccessViolation
raise TypeMismatch("Value must be literal", arg)

def _validate_arg_types(self, node: vy_ast.Call) -> None:
num_args = len(self._inputs) # the number of args the signature indicates
Expand All @@ -109,15 +114,15 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None:
validate_call_args(node, expect_num_args, list(self._kwargs.keys()))

for arg, (_, expected) in zip(node.args, self._inputs):
self._validate_single(arg, expected)
self._validate_single(arg, expected, Modifiability.MODIFIABLE)

for kwarg in node.keywords:
kwarg_settings = self._kwargs[kwarg.arg]
if kwarg_settings.require_literal and not check_modifiability(
kwarg.value, Modifiability.CONSTANT
):
raise TypeMismatch("Value must be literal", kwarg.value)
self._validate_single(kwarg.value, kwarg_settings.typ)

modifiability = Modifiability.MODIFIABLE
if kwarg_settings.require_literal:
modifiability = Modifiability.CONSTANT
self._validate_single(kwarg.value, kwarg_settings.typ, modifiability)

# typecheck varargs. we don't have type info from the signature,
# so ensure that the types of the args can be inferred exactly.
Expand Down
7 changes: 6 additions & 1 deletion vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vyper.semantics.analysis.base import Modifiability, VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.utils import (
check_modifiability,
get_common_types,
get_exact_type_from_node,
get_expr_info,
Expand Down Expand Up @@ -214,7 +215,11 @@ def analyze(self):
# visit default args
assert self.func.n_keyword_args == len(self.fn_node.args.defaults)
for kwarg in self.func.keyword_args:
self.expr_visitor.visit(kwarg.default_value, kwarg.typ)
value = kwarg.default_value
self.expr_visitor.visit(value, kwarg.typ)
# CMC 2024-01-15 move these check_modifiability checks into expr visitor
if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT):
raise StateAccessViolation("Value must be literal or environment variable", value)

def visit(self, node):
super().visit(node)
Expand Down
17 changes: 2 additions & 15 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,10 @@
CompilerPanic,
FunctionDeclarationException,
InvalidType,
StateAccessViolation,
StructureException,
)
from vyper.semantics.analysis.base import (
FunctionVisibility,
Modifiability,
StateMutability,
StorageSlot,
)
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
validate_expected_type,
)
from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot
from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import KwargSettings, VyperType
from vyper.semantics.types.primitives import BoolT
Expand Down Expand Up @@ -701,9 +691,6 @@ def _parse_args(
positional_args.append(PositionalArg(argname, type_, ast_source=arg))
else:
value = funcdef.args.defaults[i - n_positional_args]
if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT):
raise StateAccessViolation("Value must be literal or environment variable", value)
validate_expected_type(value, type_)
keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg))

argnames.add(argname)
Expand Down

0 comments on commit bf7b346

Please sign in to comment.