diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d39a4a085f..1b37f4c556 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -1,14 +1,9 @@ import functools -from typing import Dict -from vyper.ast import nodes as vy_ast -from vyper.ast.validation import validate_call_args from vyper.codegen.expr import Expr from vyper.codegen.ir_node import IRnode -from vyper.exceptions import CompilerPanic, TypeMismatch -from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type -from vyper.semantics.types import TYPE_T, KwargSettings, VyperType -from vyper.semantics.types.utils import type_from_annotation +from vyper.exceptions import CompilerPanic +from vyper.semantics.types import TYPE_T, VyperType def process_arg(arg, expected_arg_type, context): @@ -72,72 +67,3 @@ def decorator_fn(self, node, context): return wrapped_fn(self, node, subs, kwsubs, context) return decorator_fn - - -class BuiltinFunction: - _has_varargs = False - _kwargs: Dict[str, KwargSettings] = {} - - # helper function to deal with TYPE_DEFINITIONs - def _validate_single(self, arg, expected_type): - # TODO using "TYPE_DEFINITION" is a kludge in derived classes, - # refactor me. - if expected_type == "TYPE_DEFINITION": - # try to parse the type - call type_from_annotation - # for its side effects (will throw if is not a type) - type_from_annotation(arg) - else: - validate_expected_type(arg, expected_type) - - def _validate_arg_types(self, node): - num_args = len(self._inputs) # the number of args the signature indicates - - expect_num_args = num_args - if self._has_varargs: - # note special meaning for -1 in validate_call_args API - expect_num_args = (num_args, -1) - - validate_call_args(node, expect_num_args, self._kwargs) - - for arg, (_, expected) in zip(node.args, self._inputs): - self._validate_single(arg, expected) - - for kwarg in node.keywords: - kwarg_settings = self._kwargs[kwarg.arg] - if kwarg_settings.require_literal and not isinstance(kwarg.value, vy_ast.Constant): - raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) - self._validate_single(kwarg.value, kwarg_settings.typ) - - # typecheck varargs. we don't have type info from the signature, - # so ensure that the types of the args can be inferred exactly. - varargs = node.args[num_args:] - if len(varargs) > 0: - assert self._has_varargs # double check validate_call_args - for arg in varargs: - # call get_exact_type_from_node for its side effects - - # ensures the type can be inferred exactly. - get_exact_type_from_node(arg) - - def fetch_call_return(self, node): - self._validate_arg_types(node) - - if self._return_type: - return self._return_type - - def infer_arg_types(self, node): - self._validate_arg_types(node) - ret = [expected for (_, expected) in self._inputs] - - # handle varargs. - n_known_args = len(self._inputs) - varargs = node.args[n_known_args:] - if len(varargs) > 0: - assert self._has_varargs - ret.extend(get_exact_type_from_node(arg) for arg in varargs) - return ret - - def infer_kwarg_types(self, node): - return {i.arg: self._kwargs[i.arg].typ for i in node.keywords} - - def __repr__(self): - return f"(builtin) {self._id}" diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 685d832c01..964db04baa 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -74,6 +74,7 @@ TupleT, ) from vyper.semantics.types.bytestrings import _BytestringT +from vyper.semantics.types.function import BuiltinFunctionT from vyper.semantics.types.shortcuts import ( BYTES4_T, BYTES32_T, @@ -98,14 +99,14 @@ ) from ._convert import convert -from ._signatures import BuiltinFunction, process_inputs +from ._signatures import process_inputs SHA256_ADDRESS = 2 SHA256_BASE_GAS = 60 SHA256_PER_WORD_GAS = 12 -class FoldedFunction(BuiltinFunction): +class FoldedFunctionT(BuiltinFunctionT): # Base class for nodes which should always be folded # Since foldable builtin functions are not folded before semantics validation, @@ -113,7 +114,7 @@ class FoldedFunction(BuiltinFunction): _kwargable = True -class TypenameFoldedFunction(FoldedFunction): +class TypenameFoldedFunctionT(FoldedFunctionT): # Base class for builtin functions that: # (1) take a typename as the only argument; and # (2) should always be folded. @@ -132,7 +133,7 @@ def infer_arg_types(self, node): return [input_typedef] -class Floor(BuiltinFunction): +class Floor(BuiltinFunctionT): _id = "floor" _inputs = [("value", DecimalT())] # TODO: maybe use int136? @@ -162,7 +163,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class Ceil(BuiltinFunction): +class Ceil(BuiltinFunctionT): _id = "ceil" _inputs = [("value", DecimalT())] # TODO: maybe use int136? @@ -192,7 +193,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(ret) -class Convert(BuiltinFunction): +class Convert(BuiltinFunctionT): _id = "convert" def fetch_call_return(self, node): @@ -285,7 +286,7 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: # note: this and a lot of other builtins could be refactored to accept any uint type -class Slice(BuiltinFunction): +class Slice(BuiltinFunctionT): _id = "slice" _inputs = [ ("b", (BYTES32_T, BytesT.any(), StringT.any())), @@ -457,7 +458,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(b2.resolve(b3.resolve(ret))) -class Len(BuiltinFunction): +class Len(BuiltinFunctionT): _id = "len" _inputs = [("b", (StringT.any(), BytesT.any(), DArrayT.any()))] _return_type = UINT256_T @@ -482,7 +483,7 @@ def build_IR(self, node, context): return get_bytearray_length(arg) -class Concat(BuiltinFunction): +class Concat(BuiltinFunctionT): _id = "concat" def fetch_call_return(self, node): @@ -587,7 +588,7 @@ def build_IR(self, expr, context): ) -class Keccak256(BuiltinFunction): +class Keccak256(BuiltinFunctionT): _id = "keccak256" # TODO allow any BytesM_T _inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))] @@ -635,7 +636,7 @@ def _make_sha256_call(inp_start, inp_len, out_start, out_len): ] -class Sha256(BuiltinFunction): +class Sha256(BuiltinFunctionT): _id = "sha256" _inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))] _return_type = BYTES32_T @@ -707,7 +708,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class MethodID(FoldedFunction): +class MethodID(FoldedFunctionT): _id = "method_id" def evaluate(self, node): @@ -747,7 +748,7 @@ def infer_kwarg_types(self, node): return BytesT(4) -class ECRecover(BuiltinFunction): +class ECRecover(BuiltinFunctionT): _id = "ecrecover" _inputs = [ ("hash", BYTES32_T), @@ -786,7 +787,7 @@ def _getelem(arg, ind): return unwrap_location(get_element_ptr(arg, IRnode.from_list(ind, typ=INT128_T))) -class ECAdd(BuiltinFunction): +class ECAdd(BuiltinFunctionT): _id = "ecadd" _inputs = [("a", SArrayT(UINT256_T, 2)), ("b", SArrayT(UINT256_T, 2))] _return_type = SArrayT(UINT256_T, 2) @@ -817,7 +818,7 @@ def build_IR(self, expr, args, kwargs, context): return b2.resolve(b1.resolve(o)) -class ECMul(BuiltinFunction): +class ECMul(BuiltinFunctionT): _id = "ecmul" _inputs = [("point", SArrayT(UINT256_T, 2)), ("scalar", UINT256_T)] _return_type = SArrayT(UINT256_T, 2) @@ -860,7 +861,7 @@ def _storage_element_getter(index): return IRnode.from_list(["sload", ["add", "_sub", ["add", 1, index]]], typ=INT128_T) -class Extract32(BuiltinFunction): +class Extract32(BuiltinFunctionT): _id = "extract32" _inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())] # "TYPE_DEFINITION" is a placeholder value for a type definition string, and @@ -972,7 +973,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(clamp_basetype(o), typ=ret_type) -class AsWeiValue(BuiltinFunction): +class AsWeiValue(BuiltinFunctionT): _id = "as_wei_value" _inputs = [("value", (IntegerT.any(), DecimalT())), ("unit", StringT.any())] _return_type = UINT256_T @@ -1071,7 +1072,7 @@ def build_IR(self, expr, args, kwargs, context): empty_value = IRnode.from_list(0, typ=BYTES32_T) -class RawCall(BuiltinFunction): +class RawCall(BuiltinFunctionT): _id = "raw_call" _inputs = [("to", AddressT()), ("data", BytesT.any())] _kwargs = { @@ -1225,7 +1226,7 @@ def build_IR(self, expr, args, kwargs, context): raise CompilerPanic("unreachable!") -class Send(BuiltinFunction): +class Send(BuiltinFunctionT): _id = "send" _inputs = [("to", AddressT()), ("value", UINT256_T)] # default gas stipend is 0 @@ -1242,7 +1243,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class SelfDestruct(BuiltinFunction): +class SelfDestruct(BuiltinFunctionT): _id = "selfdestruct" _inputs = [("to", AddressT())] _return_type = None @@ -1261,7 +1262,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class BlockHash(BuiltinFunction): +class BlockHash(BuiltinFunctionT): _id = "blockhash" _inputs = [("block_num", UINT256_T)] _return_type = BYTES32_T @@ -1274,7 +1275,7 @@ def build_IR(self, expr, args, kwargs, contact): ) -class RawRevert(BuiltinFunction): +class RawRevert(BuiltinFunctionT): _id = "raw_revert" _inputs = [("data", BytesT.any())] _return_type = None @@ -1296,7 +1297,7 @@ def build_IR(self, expr, args, kwargs, context): return b.resolve(IRnode.from_list(["revert", data, len_])) -class RawLog(BuiltinFunction): +class RawLog(BuiltinFunctionT): _id = "raw_log" _inputs = [("topics", DArrayT(BYTES32_T, 4)), ("data", (BYTES32_T, BytesT.any()))] @@ -1347,7 +1348,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class BitwiseAnd(BuiltinFunction): +class BitwiseAnd(BuiltinFunctionT): _id = "bitwise_and" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1373,7 +1374,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["and", args[0], args[1]], typ=UINT256_T) -class BitwiseOr(BuiltinFunction): +class BitwiseOr(BuiltinFunctionT): _id = "bitwise_or" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1399,7 +1400,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["or", args[0], args[1]], typ=UINT256_T) -class BitwiseXor(BuiltinFunction): +class BitwiseXor(BuiltinFunctionT): _id = "bitwise_xor" _inputs = [("x", UINT256_T), ("y", UINT256_T)] _return_type = UINT256_T @@ -1425,7 +1426,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["xor", args[0], args[1]], typ=UINT256_T) -class BitwiseNot(BuiltinFunction): +class BitwiseNot(BuiltinFunctionT): _id = "bitwise_not" _inputs = [("x", UINT256_T)] _return_type = UINT256_T @@ -1452,7 +1453,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(["not", args[0]], typ=UINT256_T) -class Shift(BuiltinFunction): +class Shift(BuiltinFunctionT): _id = "shift" _inputs = [("x", (UINT256_T, INT256_T)), ("_shift_bits", IntegerT.any())] _return_type = UINT256_T @@ -1506,7 +1507,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(b2.resolve(IRnode.from_list(ret, typ=argty))) -class _AddMulMod(BuiltinFunction): +class _AddMulMod(BuiltinFunctionT): _inputs = [("a", UINT256_T), ("b", UINT256_T), ("c", UINT256_T)] _return_type = UINT256_T @@ -1546,7 +1547,7 @@ class MulMod(_AddMulMod): _opcode = "mulmod" -class PowMod256(BuiltinFunction): +class PowMod256(BuiltinFunctionT): _id = "pow_mod256" _inputs = [("a", UINT256_T), ("b", UINT256_T)] _return_type = UINT256_T @@ -1569,7 +1570,7 @@ def build_IR(self, expr, context): return IRnode.from_list(["exp", left, right], typ=left.typ) -class Abs(BuiltinFunction): +class Abs(BuiltinFunctionT): _id = "abs" _inputs = [("value", INT256_T)] _return_type = INT256_T @@ -1718,7 +1719,7 @@ def _create_preamble(codesize): return ["or", bytes_to_int(evm), shl(shl_bits, codesize)], evm_len -class _CreateBase(BuiltinFunction): +class _CreateBase(BuiltinFunctionT): _kwargs = { "value": KwargSettings(UINT256_T, zero_value), "salt": KwargSettings(BYTES32_T, empty_value), @@ -1943,7 +1944,7 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar return b1.resolve(b2.resolve(b3.resolve(b4.resolve(b5.resolve(ir))))) -class _UnsafeMath(BuiltinFunction): +class _UnsafeMath(BuiltinFunctionT): # TODO add unsafe math for `decimal`s _inputs = [("a", IntegerT.any()), ("b", IntegerT.any())] @@ -2009,7 +2010,7 @@ class UnsafeDiv(_UnsafeMath): op = "div" -class _MinMax(BuiltinFunction): +class _MinMax(BuiltinFunctionT): _inputs = [("a", (DecimalT(), IntegerT.any())), ("b", (DecimalT(), IntegerT.any()))] def evaluate(self, node): @@ -2083,7 +2084,7 @@ class Max(_MinMax): _opcode = "gt" -class Uint2Str(BuiltinFunction): +class Uint2Str(BuiltinFunctionT): _id = "uint2str" _inputs = [("x", IntegerT.unsigneds())] @@ -2155,7 +2156,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(IRnode.from_list(ret, location=MEMORY, typ=return_t)) -class Sqrt(BuiltinFunction): +class Sqrt(BuiltinFunctionT): _id = "sqrt" _inputs = [("d", DecimalT())] _return_type = DecimalT() @@ -2211,7 +2212,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class ISqrt(BuiltinFunction): +class ISqrt(BuiltinFunctionT): _id = "isqrt" _inputs = [("d", UINT256_T)] _return_type = UINT256_T @@ -2261,7 +2262,7 @@ def build_IR(self, expr, args, kwargs, context): return b1.resolve(IRnode.from_list(ret, typ=UINT256_T)) -class Empty(TypenameFoldedFunction): +class Empty(TypenameFoldedFunctionT): _id = "empty" def fetch_call_return(self, node): @@ -2276,7 +2277,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode("~empty", typ=output_type) -class Breakpoint(BuiltinFunction): +class Breakpoint(BuiltinFunctionT): _id = "breakpoint" _inputs: list = [] @@ -2294,7 +2295,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list("breakpoint", annotation="breakpoint()") -class Print(BuiltinFunction): +class Print(BuiltinFunctionT): _id = "print" _inputs: list = [] _has_varargs = True @@ -2372,7 +2373,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(ret, annotation="print:" + sig) -class ABIEncode(BuiltinFunction): +class ABIEncode(BuiltinFunctionT): _id = "_abi_encode" # TODO prettier to rename this to abi.encode # signature: *, ensure_tuple= -> Bytes[] # (check the signature manually since we have no utility methods @@ -2491,7 +2492,7 @@ def build_IR(self, expr, args, kwargs, context): return IRnode.from_list(ret, location=MEMORY, typ=buf_t) -class ABIDecode(BuiltinFunction): +class ABIDecode(BuiltinFunctionT): _id = "_abi_decode" _inputs = [("data", BytesT.any()), ("output_type", "TYPE_DEFINITION")] _kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)} @@ -2577,7 +2578,7 @@ def build_IR(self, expr, args, kwargs, context): ) -class _MinMaxValue(TypenameFoldedFunction): +class _MinMaxValue(TypenameFoldedFunctionT): def evaluate(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) @@ -2612,7 +2613,7 @@ def _eval(self, type_): return type_.ast_bounds[1] -class Epsilon(TypenameFoldedFunction): +class Epsilon(TypenameFoldedFunctionT): _id = "epsilon" def evaluate(self, node): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 10711edc8e..d0868dab83 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -14,9 +14,14 @@ InvalidType, StateAccessViolation, StructureException, + TypeMismatch, ) from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot -from vyper.semantics.analysis.utils import check_kwargable, validate_expected_type +from vyper.semantics.analysis.utils import ( + check_kwargable, + get_exact_type_from_node, + validate_expected_type, +) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType from vyper.semantics.types.primitives import BoolT @@ -631,6 +636,75 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: return self.return_type +class BuiltinFunctionT(VyperType): + _has_varargs = False + _kwargs: Dict[str, KwargSettings] = {} + + # helper function to deal with TYPE_DEFINITIONs + def _validate_single(self, arg, expected_type): + # TODO using "TYPE_DEFINITION" is a kludge in derived classes, + # refactor me. + if expected_type == "TYPE_DEFINITION": + # try to parse the type - call type_from_annotation + # for its side effects (will throw if is not a type) + type_from_annotation(arg) + else: + validate_expected_type(arg, expected_type) + + def _validate_arg_types(self, node): + num_args = len(self._inputs) # the number of args the signature indicates + + expect_num_args = num_args + if self._has_varargs: + # note special meaning for -1 in validate_call_args API + expect_num_args = (num_args, -1) + + validate_call_args(node, expect_num_args, self._kwargs) + + for arg, (_, expected) in zip(node.args, self._inputs): + self._validate_single(arg, expected) + + for kwarg in node.keywords: + kwarg_settings = self._kwargs[kwarg.arg] + if kwarg_settings.require_literal and not isinstance(kwarg.value, vy_ast.Constant): + raise TypeMismatch("Value for kwarg must be a literal", kwarg.value) + self._validate_single(kwarg.value, kwarg_settings.typ) + + # typecheck varargs. we don't have type info from the signature, + # so ensure that the types of the args can be inferred exactly. + varargs = node.args[num_args:] + if len(varargs) > 0: + assert self._has_varargs # double check validate_call_args + for arg in varargs: + # call get_exact_type_from_node for its side effects - + # ensures the type can be inferred exactly. + get_exact_type_from_node(arg) + + def fetch_call_return(self, node): + self._validate_arg_types(node) + + if self._return_type: + return self._return_type + + def infer_arg_types(self, node): + self._validate_arg_types(node) + ret = [expected for (_, expected) in self._inputs] + + # handle varargs. + n_known_args = len(self._inputs) + varargs = node.args[n_known_args:] + if len(varargs) > 0: + assert self._has_varargs + ret.extend(get_exact_type_from_node(arg) for arg in varargs) + return ret + + def infer_kwarg_types(self, node): + return {i.arg: self._kwargs[i.arg].typ for i in node.keywords} + + def __repr__(self): + return f"(builtin) {self._id}" + + def _generate_method_id(name: str, canonical_abi_types: List[str]) -> Dict[str, int]: function_sig = f"{name}({','.join(canonical_abi_types)})" selector = keccak256(function_sig.encode())[:4].hex()