Skip to content

Commit

Permalink
refactor; consolidate derive to types utils
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Sep 25, 2023
1 parent 9102536 commit a5de4bd
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 96 deletions.
18 changes: 14 additions & 4 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,12 @@ class Tuple(ExprNode):
__slots__ = ("elements",)
_translated_fields = {"elts": "elements"}

def derive(self, constants: dict):
val = [e.derive(constants) for e in self.elements]
if None in val:
return None
return val

def validate(self):
if not self.elements:
raise InvalidLiteral("Cannot have an empty tuple", self)
Expand All @@ -912,6 +918,12 @@ def validate(self):
class Dict(ExprNode):
__slots__ = ("keys", "values")

def derive(self, constants: dict):
values = [v.derive(constants) for v in self.args[0].values]
if any(v is None for v in values):
return None
return {k: v for (k, v) in zip(self.args[0].keys, values)}


class NameConstant(Constant):
__slots__ = ("value",)
Expand Down Expand Up @@ -1305,11 +1317,9 @@ class Call(ExprNode):
__slots__ = ("func", "args", "keywords", "keyword")

def derive(self, constants: dict):
# only return constant struct values
if len(self.args) == 1 and isinstance(self.args[0], Dict):
values = [v.derive(constants) for v in self.args[0].values]
if any(v is None for v in values):
return None
return {k: v for (k, v) in zip(self.args[0].keys, values)}
return self.args[0].derive(constants)
return None


Expand Down
17 changes: 3 additions & 14 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import functools
from typing import Dict

from vyper.ast import nodes as vy_ast
from vyper.ast.validation import validate_call_args
from vyper.codegen.expr import Expr
from vyper.codegen.ir_node import IRnode
from vyper.exceptions import CompilerPanic, TypeMismatch, UnfoldableNode, VyperException
from vyper.exceptions import CompilerPanic, TypeMismatch
from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types import TYPE_T, KwargSettings, VyperType
from vyper.semantics.types.utils import type_from_annotation
from vyper.semantics.types.utils import derive_folded_value, type_from_annotation


def process_arg(arg, expected_arg_type, context):
Expand Down Expand Up @@ -103,18 +101,9 @@ def _validate_arg_types(self, node):
for arg, (_, expected) in zip(node.args, self._inputs):
self._validate_single(arg, expected)

ns = get_namespace()
for kwarg in node.keywords:
kwarg_settings = self._kwargs[kwarg.arg]
is_literal_value = kwarg.value.derive(ns._constants) is not None
if isinstance(kwarg.value, vy_ast.Call):
call_type = get_exact_type_from_node(kwarg.value.func)
if hasattr(call_type, "evaluate"):
try:
call_type.evaluate(kwarg.value)
is_literal_value = True
except (UnfoldableNode, VyperException):
pass
is_literal_value = derive_folded_value(kwarg.value) is not None

if kwarg_settings.require_literal and not is_literal_value:
raise TypeMismatch("Value for kwarg must be a literal", kwarg.value)
Expand Down
27 changes: 3 additions & 24 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
StructureException,
TypeMismatch,
UnfoldableNode,
VyperException,
ZeroDivisionException,
)
from vyper.semantics.analysis.base import VarInfo
Expand All @@ -60,7 +59,6 @@
get_possible_types_from_node,
validate_expected_type,
)
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types import (
TYPE_T,
AddressT,
Expand All @@ -85,7 +83,7 @@
UINT8_T,
UINT256_T,
)
from vyper.semantics.types.utils import type_from_annotation
from vyper.semantics.types.utils import derive_folded_value, type_from_annotation
from vyper.utils import (
DECIMAL_DIVISOR,
EIP_170_LIMIT,
Expand Down Expand Up @@ -1062,25 +1060,6 @@ def build_IR(self, expr, args, kwargs, context):
empty_value = IRnode.from_list(0, typ=BYTES32_T)


def derive_kwarg_value(kwarg, call_type):
if kwarg is None:
return None

ns = get_namespace()
kwarg_val = kwarg.derive(ns._constants)
if kwarg_val is not None:
return kwarg_val

if isinstance(kwarg, vy_ast.Call):
try:
evaluated = call_type.evaluate(kwarg)
return evaluated.value
except (UnfoldableNode, VyperException):
pass

return None


class RawCall(BuiltinFunction):
_id = "raw_call"
_inputs = [("to", AddressT()), ("data", BytesT.any())]
Expand All @@ -1099,8 +1078,8 @@ def fetch_call_return(self, node):

kwargz = {i.arg: i.value for i in node.keywords}

outsize = derive_kwarg_value(kwargz.get("max_outsize"), self)
revert_on_failure = derive_kwarg_value(kwargz.get("revert_on_failure"), self)
outsize = derive_folded_value(kwargz.get("max_outsize"))
revert_on_failure = derive_folded_value(kwargz.get("revert_on_failure"))
revert_on_failure = revert_on_failure if revert_on_failure is not None else True

if outsize is None or outsize == 0:
Expand Down
14 changes: 7 additions & 7 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
is_type_t,
)
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability
from vyper.semantics.types.utils import type_from_annotation
from vyper.semantics.types.utils import derive_folded_value, type_from_annotation


def validate_functions(vy_module: vy_ast.Module) -> None:
Expand Down Expand Up @@ -358,15 +358,15 @@ def visit_For(self, node):
validate_expected_type(n, IntegerT.any())

if bound is None:
n_val = n.derive(self.namespace._constants)
n_val = derive_folded_value(n)
if n_val is None:
raise StateAccessViolation("Value must be a literal", n)
if n_val <= 0:
raise StructureException("For loop must have at least 1 iteration", args[0])
type_list = get_possible_types_from_node(n)

else:
bound_val = bound.derive(self.namespace._constants)
bound_val = derive_folded_value(bound)
if bound_val is None:
raise StateAccessViolation("bound must be a literal", bound)
if bound_val <= 0:
Expand All @@ -383,7 +383,7 @@ def visit_For(self, node):

validate_expected_type(args[0], IntegerT.any())
type_list = get_common_types(*args)
arg0_val = args[0].derive(self.namespace._constants)
arg0_val = derive_folded_value(args[0])
if arg0_val is None:
# range(x, x + CONSTANT)
if not isinstance(args[1], vy_ast.BinOp) or not isinstance(
Expand All @@ -397,7 +397,7 @@ def visit_For(self, node):
"First and second variable must be the same", args[1].left
)

right_val = args[1].right.derive(self.namespace._constants)
right_val = derive_folded_value(args[1].right)
if not isinstance(args[1].right, vy_ast.Int) and not (
isinstance(args[1].right, vy_ast.Name) and right_val
):
Expand All @@ -410,7 +410,7 @@ def visit_For(self, node):
)
else:
# range(CONSTANT, CONSTANT)
arg1_val = args[1].derive(self.namespace._constants)
arg1_val = derive_folded_value(args[1])
if not arg1_val:
raise InvalidType("Value must be a literal integer", args[1])
validate_expected_type(args[1], IntegerT.any())
Expand All @@ -422,7 +422,7 @@ def visit_For(self, node):

else:
# iteration over a variable or literal list
iter_ = node.iter.derive(self.namespace._constants)
iter_ = derive_folded_value(node.iter)
if isinstance(iter_, list) and len(iter_) == 0:
raise StructureException("For loop must have at least 1 iteration", node.iter)

Expand Down
6 changes: 3 additions & 3 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from vyper.semantics.namespace import Namespace, get_namespace
from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT
from vyper.semantics.types.function import ContractFunctionT
from vyper.semantics.types.utils import type_from_annotation
from vyper.semantics.types.utils import derive_folded_value, type_from_annotation
from vyper.typing import InterfaceDict


Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
if c.value is None:
continue

val = c.value.derive(self.namespace._constants)
val = derive_folded_value(c.value)
self.namespace.add_constant(name, val)

if val is not None:
Expand Down Expand Up @@ -269,7 +269,7 @@ def _validate_self_namespace():
if not node.value:
raise VariableDeclarationException("Constant must be declared with a value", node)
# TODO: move to check_constant
if not node.value.derive(self.namespace._constants) and not check_constant(node.value):
if not check_constant(node.value):
raise StateAccessViolation("Value must be a literal", node.value)

validate_expected_type(node.value, type_)
Expand Down
44 changes: 3 additions & 41 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
StructureException,
TypeMismatch,
UndeclaredDefinition,
UnfoldableNode,
UnknownAttribute,
VyperException,
ZeroDivisionException,
Expand All @@ -25,6 +24,7 @@
from vyper.semantics.types.bytestrings import BytesT, StringT
from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT
from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT
from vyper.semantics.types.utils import derive_folded_value
from vyper.utils import checksum_encode, int_to_fourbytes


Expand Down Expand Up @@ -624,65 +624,27 @@ def check_kwargable(node: vy_ast.VyperNode) -> bool:
"""
Check if the given node can be used as a default arg
"""
if _check_literal(node):
if derive_folded_value(node) is not None:
return True
if isinstance(node, (vy_ast.Tuple, vy_ast.List)):
return all(check_kwargable(item) for item in node.elements)
if isinstance(node, vy_ast.Call):
args = node.args
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
return all(check_kwargable(v) for v in args[0].values)

call_type = get_exact_type_from_node(node.func)
if getattr(call_type, "_kwargable", False):
return True

if getattr(call_type, "evaluate", False):
try:
call_type.evaluate(node)
return True
except (UnfoldableNode, VyperException):
return False

value_type = get_expr_info(node)
# is_constant here actually means not_assignable, and is to be renamed
return value_type.is_constant


def _check_literal(node: vy_ast.VyperNode) -> bool:
"""
Check if the given node is a literal value.
"""
ns = get_namespace()
val = node.derive(ns._constants)
if val is not None:
return True

return False


def check_constant(node: vy_ast.VyperNode) -> bool:
"""
Check if the given node is a literal or constant value.
"""
if _check_literal(node):
if derive_folded_value(node) is not None:
return True
if isinstance(node, (vy_ast.Tuple, vy_ast.List)):
return all(check_constant(item) for item in node.elements)
if isinstance(node, vy_ast.Call):
args = node.args
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
return all(check_constant(v) for v in args[0].values)

call_type = get_exact_type_from_node(node.func)
if getattr(call_type, "_kwargable", False):
return True

if getattr(call_type, "evaluate", False):
try:
call_type.evaluate(node)
return True
except (UnfoldableNode, VyperException):
return False

return False
4 changes: 2 additions & 2 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.primitives import IntegerT
from vyper.semantics.types.shortcuts import UINT256_T
from vyper.semantics.types.utils import get_index_value, type_from_annotation
from vyper.semantics.types.utils import derive_folded_value, get_index_value, type_from_annotation


class _SubscriptableT(VyperType):
Expand Down Expand Up @@ -287,7 +287,7 @@ def from_annotation(cls, node: vy_ast.Subscript, constants: dict) -> "DArrayT":
node,
)

max_length = node.slice.value.elements[1].derive(constants)
max_length = derive_folded_value(node.slice.value.elements[1])
if not max_length or not isinstance(max_length, int):
raise StructureException(
"DynArray must have a max length of integer type, e.g. DynArray[bool, 5]", node
Expand Down
31 changes: 30 additions & 1 deletion vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
InstantiationException,
InvalidType,
StructureException,
UnfoldableNode,
UnknownType,
VyperException,
)
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.data_locations import DataLocation
Expand Down Expand Up @@ -132,6 +134,33 @@ def _failwith(type_name):
return typ_


def derive_literal_value(node: vy_ast.VyperNode):
ns = get_namespace()
val = node.derive(ns._constants)
return val


def derive_folded_value(node: vy_ast.VyperNode):
if node is None:
return None

val = derive_literal_value(node)
if val is not None:
return val

if isinstance(node, vy_ast.Call):
from vyper.semantics.analysis.utils import get_exact_type_from_node

call_type = get_exact_type_from_node(node.func)
try:
evaluated = call_type.evaluate(node)
return evaluated.value
except (UnfoldableNode, VyperException):
pass

return None


def get_index_value(node: vy_ast.Index, constants: dict) -> int:
"""
Return the literal value for a `Subscript` index.
Expand All @@ -151,7 +180,7 @@ def get_index_value(node: vy_ast.Index, constants: dict) -> int:
# TODO: revisit this!
from vyper.semantics.analysis.utils import get_possible_types_from_node

val = node.value.derive(constants)
val = derive_folded_value(node.value)

if not isinstance(val, int):
if hasattr(node, "value"):
Expand Down

0 comments on commit a5de4bd

Please sign in to comment.