Skip to content

Commit

Permalink
more fixes wip
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Nov 1, 2023
1 parent 55359e7 commit a5d19d8
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 48 deletions.
3 changes: 2 additions & 1 deletion vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type
from vyper.semantics.types import TYPE_T, KwargSettings, VyperType
from vyper.semantics.types.utils import type_from_annotation
from vyper.semantics.utils import get_folded_value


def process_arg(arg, expected_arg_type, context):
Expand Down Expand Up @@ -103,7 +104,7 @@ def _validate_arg_types(self, node):

for kwarg in node.keywords:
kwarg_settings = self._kwargs[kwarg.arg]
is_literal_value = kwarg.value._metadata.get("folded_value") is not None
is_literal_value = get_folded_value(kwarg.value) is not None

if kwarg_settings.require_literal and not is_literal_value:
raise TypeMismatch("Value for kwarg must be a literal", kwarg.value)
Expand Down
50 changes: 23 additions & 27 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
ZeroDivisionException,
)
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.pre_typecheck import get_folded_value
from vyper.semantics.analysis.utils import (
get_common_types,
get_exact_type_from_node,
Expand Down Expand Up @@ -141,7 +142,7 @@ class Floor(BuiltinFunction):

def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 1)
arg = node.args[0]._metadata.get("folded_value")
arg = get_folded_value(node.args[0])
if not isinstance(arg, vy_ast.Decimal):
raise UnfoldableNode

Expand Down Expand Up @@ -172,7 +173,7 @@ class Ceil(BuiltinFunction):

def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 1)
arg = node.args[0]._metadata.get("folded_value")
arg = get_folded_value(node.args[0])
if not isinstance(arg, vy_ast.Decimal):
raise UnfoldableNode

Expand Down Expand Up @@ -467,9 +468,9 @@ class Len(BuiltinFunction):

def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 1)
arg = node.args[0]._metadata.get("folded_value")
arg = get_folded_value(node.args[0])
if isinstance(arg, (vy_ast.Str, vy_ast.Bytes)):
length = len(arg)
length = len(arg.value)
elif isinstance(arg, vy_ast.Hex):
# 2 characters represent 1 byte and we subtract 1 to ignore the leading `0x`
length = len(arg.value) // 2 - 1
Expand Down Expand Up @@ -604,7 +605,7 @@ class Keccak256(BuiltinFunction):

def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 1)
arg = node.args[0]._metadata.get("folded_value")
arg = get_folded_value(node.args[0])
if isinstance(arg, vy_ast.Bytes):
value = arg.value
elif isinstance(arg, vy_ast.Str):
Expand Down Expand Up @@ -652,7 +653,7 @@ class Sha256(BuiltinFunction):

def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 1)
arg = node.args[0]._metadata.get("folded_value")
arg = get_folded_value(node.args[0])
if isinstance(arg, vy_ast.Bytes):
value = arg.value
elif isinstance(arg, vy_ast.Str):
Expand Down Expand Up @@ -980,7 +981,7 @@ class AsWeiValue(BuiltinFunction):
}

def get_denomination(self, node):
arg = node.args[1]._metadata.get("folded_value")
arg = get_folded_value(node.args[1])
if not isinstance(arg, vy_ast.Str):
raise ArgumentException(
"Wei denomination must be given as a literal string", node.args[1]
Expand All @@ -996,7 +997,7 @@ def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 2)
denom = self.get_denomination(node)

arg = node.args[0]._metadata.get("folded_value")
arg = get_folded_value(node.args[0])
if not isinstance(arg, (vy_ast.Decimal, vy_ast.Int)):
raise UnfoldableNode

Expand Down Expand Up @@ -1082,10 +1083,10 @@ def fetch_call_return(self, node):

outsize = kwargz.get("max_outsize")
if outsize is not None:
outsize = outsize._metadata.get("folded_value")
outsize = get_folded_value(outsize)
revert_on_failure = kwargz.get("revert_on_failure")
if revert_on_failure is not None:
revert_on_failure = revert_on_failure._metadata.get("folded_value")
revert_on_failure = get_folded_value(revert_on_failure)

revert_on_failure = revert_on_failure if revert_on_failure is not None else True

Expand Down Expand Up @@ -1358,7 +1359,7 @@ def evaluate(self, node, skip_typecheck=False):
self.__class__._warned = True

validate_call_args(node, 2)
args = [i._metadata.get("folded_value") for i in node.args]
args = [get_folded_value(i) for i in node.args]
for v, arg in zip(args, node.args):
if not isinstance(v, vy_ast.Int):
raise UnfoldableNode
Expand All @@ -1385,7 +1386,7 @@ def evaluate(self, node, skip_typecheck=False):
self.__class__._warned = True

validate_call_args(node, 2)
args = [i._metadata.get("folded_value") for i in node.args]
args = [get_folded_value(i) for i in node.args]
for v, arg in zip(args, node.args):
if not isinstance(arg, vy_ast.Int):
raise UnfoldableNode
Expand All @@ -1412,7 +1413,7 @@ def evaluate(self, node, skip_typecheck=False):
self.__class__._warned = True

validate_call_args(node, 2)
args = [i._metadata.get("folded_value") for i in node.args]
args = [get_folded_value(i) for i in node.args]
for v, arg in zip(args, node.args):
if not isinstance(arg, vy_ast.Int):
raise UnfoldableNode
Expand All @@ -1439,7 +1440,7 @@ def evaluate(self, node, skip_typecheck=False):
self.__class__._warned = True

validate_call_args(node, 1)
arg = node.args[0]._metadata.get("folded_value")
arg = get_folded_value(node.args[0])
if not isinstance(arg, vy_ast.Int):
raise UnfoldableNode

Expand All @@ -1466,7 +1467,7 @@ def evaluate(self, node, skip_typecheck=False):
self.__class__._warned = True

validate_call_args(node, 2)
value, shift = [i._metadata.get("folded_value") for i in node.args]
value, shift = [get_folded_value(i) for i in node.args]
if any(not isinstance(i, int) for i in [value, shift]):
raise UnfoldableNode
if value < 0 or value >= 2**256:
Expand Down Expand Up @@ -1514,11 +1515,11 @@ class _AddMulMod(BuiltinFunction):

def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 3)
args = [i._metadata.get("folded_value") for i in node.args]
args = [get_folded_value(i) for i in node.args]
if isinstance(args[2], vy_ast.Int) and args[2] == 0:
raise ZeroDivisionException("Modulo by 0", node.args[2])
for v, arg in zip(args, node.args):
if not isinstance(v, int):
if not isinstance(v, vy_ast.Int):
raise UnfoldableNode
if v.value < 0 or v.value >= 2**256:
raise InvalidLiteral("Value out of range for uint256", arg)
Expand Down Expand Up @@ -1557,7 +1558,7 @@ class PowMod256(BuiltinFunction):

def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 2)
args = [i._metadata.get("folded_value") for i in node.args]
args = [get_folded_value(i) for i in node.args]
if any(not isinstance(i, vy_ast.Int) for i in args):
raise UnfoldableNode

Expand All @@ -1581,7 +1582,7 @@ class Abs(BuiltinFunction):

def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 1)
arg = node.args[0]._metadata.get("folded_value")
arg = get_folded_value(node.args[0])
if not isinstance(arg, vy_ast.Int):
raise UnfoldableNode

Expand Down Expand Up @@ -2025,7 +2026,7 @@ class _MinMax(BuiltinFunction):

def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 2)
args = [i._metadata.get("folded_value") for i in node.args]
args = [get_folded_value(i) for i in node.args]
if not isinstance(args[0], type(args[1])):
raise UnfoldableNode
if not isinstance(args[0], (vy_ast.Decimal, vy_ast.Int)):
Expand All @@ -2047,12 +2048,7 @@ def evaluate(self, node, skip_typecheck=False):
raise TypeMismatch("Cannot perform action between dislike numeric types", node)

value = self._eval_fn(left.value, right.value)

if isinstance(left, Decimal):
node = vy_ast.Decimal.from_node(node, value=value)
elif isinstance(left, int):
node = vy_ast.Int.from_node(node, value=value)
return node
return type(left).from_node(node, value=value)

def fetch_call_return(self, node):
self._validate_arg_types(node)
Expand Down Expand Up @@ -2119,7 +2115,7 @@ def fetch_call_return(self, node):

def evaluate(self, node, skip_typecheck=False):
validate_call_args(node, 1)
arg = node.args[0]._metadata.get("folded_value")
arg = get_folded_value(node.args[0])
if not isinstance(arg, vy_ast.Int):
raise UnfoldableNode

Expand Down
13 changes: 7 additions & 6 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
)
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability
from vyper.semantics.types.utils import type_from_annotation
from vyper.semantics.utils import get_folded_value


def validate_functions(vy_module: vy_ast.Module) -> None:
Expand Down Expand Up @@ -353,7 +354,7 @@ def visit_For(self, node):
if len(args) == 1:
# range(CONSTANT)
n = args[0]
folded_n = n._metadata.get("folded_value")
folded_n = get_folded_value(n)

bound = kwargs.pop("bound", None)
validate_expected_type(n, IntegerT.any())
Expand All @@ -366,7 +367,7 @@ def visit_For(self, node):
type_list = get_possible_types_from_node(n)

else:
folded_bound = bound._metadata.get("folded_value")
folded_bound = get_folded_value(bound)
if folded_bound is None:
raise StateAccessViolation("bound must be a literal", bound)
if folded_bound.value <= 0:
Expand All @@ -383,7 +384,7 @@ def visit_For(self, node):

validate_expected_type(args[0], IntegerT.any())
type_list = get_common_types(*args)
folded_arg0 = args[0]._metadata.get("folded_value")
folded_arg0 = get_folded_value(args[0])
if not isinstance(folded_arg0, vy_ast.Constant):
# range(x, x + CONSTANT)
if not isinstance(args[1], vy_ast.BinOp) or not isinstance(
Expand All @@ -397,7 +398,7 @@ def visit_For(self, node):
"First and second variable must be the same", args[1].left
)

folded_right = args[1].right._metadata.get("folded_value")
folded_right = get_folded_value(args[1].right)
if not isinstance(folded_right, vy_ast.Int):
raise InvalidLiteral("Literal must be an integer", args[1].right)
if folded_right.value < 1:
Expand All @@ -408,7 +409,7 @@ def visit_For(self, node):
)
else:
# range(CONSTANT, CONSTANT)
folded_arg1 = args[1]._metadata.get("folded_value")
folded_arg1 = get_folded_value(args[1])
if not isinstance(folded_arg1, vy_ast.Int):
raise InvalidType("Value must be a literal integer", args[1])
validate_expected_type(folded_arg1, IntegerT.any())
Expand All @@ -420,7 +421,7 @@ def visit_For(self, node):

else:
# iteration over a variable or literal list
folded_iter = node.iter._metadata.get("folded_value")
folded_iter = get_folded_value(node.iter)
if isinstance(folded_iter, vy_ast.List) and len(folded_iter.elements) == 0:
raise StructureException("For loop must have at least 1 iteration", node.iter)

Expand Down
12 changes: 2 additions & 10 deletions vyper/semantics/analysis/pre_typecheck.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Optional

from vyper import ast as vy_ast
from vyper.exceptions import UnfoldableNode, VyperException
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.utils import get_folded_value


def pre_typecheck(node: vy_ast.VyperNode) -> None:
Expand Down Expand Up @@ -231,7 +230,7 @@ def visit_Compare(self, node):
node._metadata["folded_value"] = vy_ast.NameConstant.from_node(value=value)

def visit_Constant(self, node):
node._metadata["folded_value"] = node
pass

def visit_Dict(self, node):
for v in node.values:
Expand Down Expand Up @@ -288,10 +287,3 @@ def visit_IfExp(self, node):
self.visit(node.test)
self.visit(node.body)
self.visit(node.orelse)


def get_folded_value(node: vy_ast.VyperNode) -> Optional[vy_ast.VyperNode]:
if isinstance(node, vy_ast.Constant):
return node

return node._metadata.get("folded_value")
3 changes: 2 additions & 1 deletion vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from vyper.semantics.types.bytestrings import BytesT, StringT
from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT
from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT
from vyper.semantics.utils import get_folded_value
from vyper.utils import checksum_encode, int_to_fourbytes


Expand Down Expand Up @@ -643,7 +644,7 @@ def check_constant(node: vy_ast.VyperNode) -> bool:
"""
Check if the given node is a literal or constant value.
"""
if node._metadata.get("folded_value") is not None:
if get_folded_value(node) is not None:
return True
if isinstance(node, vy_ast.Call):
call_type = get_exact_type_from_node(node.func)
Expand Down
5 changes: 3 additions & 2 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vyper.semantics.types.primitives import IntegerT
from vyper.semantics.types.shortcuts import UINT256_T
from vyper.semantics.types.utils import get_index_value, type_from_annotation
from vyper.semantics.utils import get_folded_value


class _SubscriptableT(VyperType):
Expand Down Expand Up @@ -128,7 +129,7 @@ def validate_index_type(self, node):
# TODO break this cycle
from vyper.semantics.analysis.utils import validate_expected_type

index = node._metadata.get("folded_value")
index = get_folded_value(node)
if isinstance(index, vy_ast.Int):
value = index.value
if value < 0:
Expand Down Expand Up @@ -287,7 +288,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT":
node,
)

folded_max_length = node.slice.value.elements[1]._metadata.get("folded_value")
folded_max_length = get_folded_value(node.slice.value.elements[1])
if not isinstance(folded_max_length, vy_ast.Int):
raise StructureException(
"DynArray must have a max length of integer type, e.g. DynArray[bool, 5]", node
Expand Down
3 changes: 2 additions & 1 deletion vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import VyperType
from vyper.semantics.utils import get_folded_value

# TODO maybe this should be merged with .types/base.py

Expand Down Expand Up @@ -139,7 +140,7 @@ def get_index_value(node: vy_ast.Index) -> int:
int
Literal integer value.
"""
folded_node = node.value._metadata.get("folded_value")
folded_node = get_folded_value(node.value)

if not isinstance(folded_node, vy_ast.Int):
raise InvalidType("Subscript must be a literal integer", node)
Expand Down
10 changes: 10 additions & 0 deletions vyper/semantics/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Optional

from vyper import ast as vy_ast


def get_folded_value(node: vy_ast.VyperNode) -> Optional[vy_ast.VyperNode]:
if isinstance(node, vy_ast.Constant):
return node

return node._metadata.get("folded_value")

0 comments on commit a5d19d8

Please sign in to comment.