From bbf05102721d46d7b1aa41438846e3b187a10416 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Tue, 14 Nov 2023 11:24:14 +0800 Subject: [PATCH] fix minmax --- vyper/builtins/functions.py | 187 +++++++++++++++++----- vyper/semantics/analysis/local.py | 4 +- vyper/semantics/analysis/pre_typecheck.py | 2 +- vyper/semantics/analysis/utils.py | 2 + 4 files changed, 152 insertions(+), 43 deletions(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 6eda338d73..2a7d997d59 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -128,7 +128,7 @@ def fetch_call_return(self, node): type_ = self.infer_arg_types(node)[0].typedef return type_ - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): validate_call_args(node, 1) input_typedef = TYPE_T(type_from_annotation(node.args[0])) return [input_typedef] @@ -140,6 +140,9 @@ class Floor(BuiltinFunction): # TODO: maybe use int136? _return_type = INT256_T + def prefold(self, node): + return self.evaluate(node) + def evaluate(self, node): validate_call_args(node, 1) value = node.args[0]._metadata.get("folded_value") @@ -171,6 +174,9 @@ class Ceil(BuiltinFunction): # TODO: maybe use int136? _return_type = INT256_T + def prefold(self, node): + return self.evaluate(node) + def evaluate(self, node): validate_call_args(node, 1) value = node.args[0]._metadata.get("folded_value") @@ -206,7 +212,7 @@ def fetch_call_return(self, node): return target_typedef.typedef # TODO: push this down into convert.py for more consistency - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): validate_call_args(node, 2) target_type = type_from_annotation(node.args[1]) @@ -342,7 +348,7 @@ def fetch_call_return(self, node): return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `b` b_type = get_possible_types_from_node(node.args[0]).pop() @@ -466,6 +472,9 @@ class Len(BuiltinFunction): _inputs = [("b", (StringT.any(), BytesT.any(), DArrayT.any()))] _return_type = UINT256_T + def prefold(self, node): + return self.evaluate(node) + def evaluate(self, node): validate_call_args(node, 1) arg = node.args[0] @@ -479,7 +488,7 @@ def evaluate(self, node): return vy_ast.Int.from_node(node, value=length) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type typ = get_possible_types_from_node(node.args[0]).pop() @@ -509,7 +518,7 @@ def fetch_call_return(self, node): return_type.set_length(length) return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): if len(node.args) < 2: raise ArgumentException("Invalid argument count: expected at least 2", node) @@ -603,6 +612,9 @@ class Keccak256(BuiltinFunction): _inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))] _return_type = BYTES32_T + def prefold(self, node): + return self.evaluate(node) + def evaluate(self, node): validate_call_args(node, 1) if isinstance(node.args[0], vy_ast.Bytes): @@ -618,7 +630,7 @@ def evaluate(self, node): hash_ = f"0x{keccak256(value).hex()}" return vy_ast.Hex.from_node(node, value=hash_) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `value` value_type = get_possible_types_from_node(node.args[0]).pop() @@ -650,6 +662,9 @@ class Sha256(BuiltinFunction): _inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))] _return_type = BYTES32_T + def prefold(self, node): + return self.evaluate(node) + def evaluate(self, node): validate_call_args(node, 1) if isinstance(node.args[0], vy_ast.Bytes): @@ -665,7 +680,7 @@ def evaluate(self, node): hash_ = f"0x{hashlib.sha256(value).hexdigest()}" return vy_ast.Hex.from_node(node, value=hash_) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `value` value_type = get_possible_types_from_node(node.args[0]).pop() @@ -720,6 +735,12 @@ def build_IR(self, expr, args, kwargs, context): class MethodID(FoldedFunction): _id = "method_id" + def prefold(self, node): + try: + return self.evaluate(node) + except (InvalidType, InvalidLiteral): + return + def evaluate(self, node): validate_call_args(node, 1, ["output_type"]) @@ -767,7 +788,7 @@ class ECRecover(BuiltinFunction): ] _return_type = AddressT() - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) v_t, r_t, s_t = [get_possible_types_from_node(arg).pop() for arg in node.args[1:]] return [BYTES32_T, v_t, r_t, s_t] @@ -865,7 +886,7 @@ def fetch_call_return(self, node): return_type = self.infer_kwarg_types(node)["output_type"].typedef return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) input_type = get_possible_types_from_node(node.args[0]).pop() return [input_type, UINT256_T] @@ -993,6 +1014,12 @@ def get_denomination(self, node): return denom + def prefold(self, node): + try: + return self.evaluate(node) + except InvalidLiteral: + return + def evaluate(self, node): validate_call_args(node, 2) denom = self.get_denomination(node) @@ -1015,7 +1042,7 @@ def fetch_call_return(self, node): self.infer_arg_types(node) return self._return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type instead of abstract type value_type = get_possible_types_from_node(node.args[0]).pop() @@ -1104,7 +1131,7 @@ def fetch_call_return(self, node): return return_type return TupleT([BoolT(), return_type]) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type for `data` data_type = get_possible_types_from_node(node.args[1]).pop() @@ -1281,7 +1308,7 @@ class RawRevert(BuiltinFunction): def fetch_call_return(self, node): return None - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) data_type = get_possible_types_from_node(node.args[0]).pop() return [data_type] @@ -1301,7 +1328,7 @@ class RawLog(BuiltinFunction): def fetch_call_return(self, node): self.infer_arg_types(node) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) if not isinstance(node.args[0], vy_ast.List) or len(node.args[0].elements) > 4: @@ -1351,6 +1378,12 @@ class BitwiseAnd(BuiltinFunction): _return_type = UINT256_T _warned = False + def prefold(self, node): + try: + return self.evaluate(node) + except (InvalidLiteral, UnfoldableNode): + return + def evaluate(self, node): if not self.__class__._warned: vyper_warn("`bitwise_and()` is deprecated! Please use the & operator instead.") @@ -1377,6 +1410,12 @@ class BitwiseOr(BuiltinFunction): _return_type = UINT256_T _warned = False + def prefold(self, node): + try: + return self.evaluate(node) + except (UnfoldableNode, InvalidLiteral): + return + def evaluate(self, node): if not self.__class__._warned: vyper_warn("`bitwise_or()` is deprecated! Please use the | operator instead.") @@ -1403,6 +1442,12 @@ class BitwiseXor(BuiltinFunction): _return_type = UINT256_T _warned = False + def prefold(self, node): + try: + return self.evaluate(node) + except InvalidLiteral: + return + def evaluate(self, node): if not self.__class__._warned: vyper_warn("`bitwise_xor()` is deprecated! Please use the ^ operator instead.") @@ -1429,6 +1474,12 @@ class BitwiseNot(BuiltinFunction): _return_type = UINT256_T _warned = False + def prefold(self, node): + try: + return self.evaluate(node) + except InvalidLiteral: + return + def evaluate(self, node): if not self.__class__._warned: vyper_warn("`bitwise_not()` is deprecated! Please use the ~ operator instead.") @@ -1456,17 +1507,16 @@ class Shift(BuiltinFunction): _return_type = UINT256_T _warned = False - def evaluate(self, node): + def prefold(self, node): if not self.__class__._warned: vyper_warn("`shift()` is deprecated! Please use the << or >> operator instead.") self.__class__._warned = True validate_call_args(node, 2) + args = [i._metadata.get("folded_value") for i in node.args] if [i for i in node.args if not isinstance(i, vy_ast.Int)]: raise UnfoldableNode value, shift = [i.value for i in node.args] - if value < 0 or value >= 2**256: - raise InvalidLiteral("Value out of range for uint256", node.args[0]) if shift < -256 or shift > 256: # this validation is performed to prevent the compiler from hanging # rather than for correctness because the post-folded constant would @@ -1479,11 +1529,21 @@ def evaluate(self, node): value = (value << shift) % (2**256) return vy_ast.Int.from_node(node, value=value) + def evaluate(self, node): + value = args[0]._metadata.get("folded_value") + if not isinstance(value, vy_ast.Int): + raise UnfoldableNode + + if value < 0 or value >= 2**256: + raise InvalidLiteral("Value out of range for uint256", node.args[0]) + + return self.prefold(node) + def fetch_call_return(self, node): # return type is the type of the first argument return self.infer_arg_types(node)[0] - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) # return a concrete type instead of SignedIntegerAbstractType arg_ty = get_possible_types_from_node(node.args[0])[0] @@ -1508,17 +1568,30 @@ class _AddMulMod(BuiltinFunction): _inputs = [("a", UINT256_T), ("b", UINT256_T), ("c", UINT256_T)] _return_type = UINT256_T - def evaluate(self, node): + def prefold(self, node): validate_call_args(node, 3) - if isinstance(node.args[2], vy_ast.Int) and node.args[2].value == 0: - raise ZeroDivisionException("Modulo by 0", node.args[2]) + args = [i._metadata.get("folded_value") for i in node.args] + if not all(isinstance(i, vy_ast.Int) for i in args): + raise UnfoldableNode + if isinstance(args[2], vy_ast.Int) and args[2].value == 0: + raise UnfoldableNode("Modulo by 0", node.args[2]) for arg in node.args: if not isinstance(arg, vy_ast.Int): raise UnfoldableNode + + value = self._eval_fn(node.args[0].value, node.args[1].value) % node.args[2].value + return vy_ast.Int.from_node(node, value=value) + + def evaluate(self, node): + validate_call_args(node, 3) + args = [i._metadata.get("folded_value") for i in node.args] + if isinstance(args[2], vy_ast.Int) and args[2].value == 0: + raise ZeroDivisionException("Modulo by 0", node.args[2]) + for arg in node.args: if arg.value < 0 or arg.value >= 2**256: raise InvalidLiteral("Value out of range for uint256", arg) - value = self._eval_fn(node.args[0].value, node.args[1].value) % node.args[2].value + value = self._eval_fn(args[0].value, args[1].value) % args[2].value return vy_ast.Int.from_node(node, value=value) @process_inputs @@ -1959,7 +2032,7 @@ def fetch_call_return(self, node): return_type = self.infer_arg_types(node).pop() return return_type - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) types_list = get_common_types(*node.args, filter_fn=lambda x: isinstance(x, IntegerT)) @@ -2017,34 +2090,41 @@ class UnsafeDiv(_UnsafeMath): class _MinMax(BuiltinFunction): _inputs = [("a", (DecimalT(), IntegerT.any())), ("b", (DecimalT(), IntegerT.any()))] - def evaluate(self, node): + def prefold(self, node): validate_call_args(node, 2) - if not isinstance(node.args[0], type(node.args[1])): + arg0 = node.args[0]._metadata.get("folded_value") + arg1 = node.args[1]._metadata.get("folded_value") + if not isinstance(arg0, (vy_ast.Decimal, vy_ast.Int)): raise UnfoldableNode - if not isinstance(node.args[0], (vy_ast.Decimal, vy_ast.Int)): + if not isinstance(arg0, type(arg1)): raise UnfoldableNode - left, right = (i.value for i in node.args) - if isinstance(left, Decimal) and ( - min(left, right) < SizeLimits.MIN_AST_DECIMAL - or max(left, right) > SizeLimits.MAX_AST_DECIMAL + left = arg0.value + right = arg1.value + + value = self._eval_fn(left, right) + return type(node.args[0]).from_node(node, value=value) + + def evaluate(self, node): + new_node = self.prefold(node) + + left = node.args[0]._metadata.get("folded_value") + right = node.args[1]._metadata.get("folded_value") + if isinstance(left.value, Decimal) and ( + min(left.value, right.value) < SizeLimits.MIN_AST_DECIMAL + or max(left.value, right.value) > SizeLimits.MAX_AST_DECIMAL ): raise InvalidType("Decimal value is outside of allowable range", node) types_list = get_common_types( - *node.args, filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT)) + *(left, right), filter_fn=lambda x: isinstance(x, (IntegerT, DecimalT)) ) if not types_list: raise TypeMismatch("Cannot perform action between dislike numeric types", node) - value = self._eval_fn(left, right) - return type(node.args[0]).from_node(node, value=value) + return new_node def fetch_call_return(self, node): - return_type = self.infer_arg_types(node).pop() - return return_type - - def infer_arg_types(self, node): self._validate_arg_types(node) types_list = get_common_types( @@ -2053,8 +2133,20 @@ def infer_arg_types(self, node): if not types_list: raise TypeMismatch("Cannot perform action between dislike numeric types", node) - type_ = types_list.pop() - return [type_, type_] + return types_list + + def infer_arg_types(self, node, expected_return_typ=None): + types_list = self.fetch_call_return(node) + + if expected_return_typ is not None: + if expected_return_typ not in types_list: + raise TypeMismatch("Cannot perform action between dislike numeric types", node) + + arg_typ = expected_return_typ + else: + arg_typ = types_list.pop() + + return [arg_typ, arg_typ] @process_inputs def build_IR(self, expr, args, kwargs, context): @@ -2098,6 +2190,9 @@ def fetch_call_return(self, node): len_needed = math.ceil(bits * math.log(2) / math.log(10)) return StringT(len_needed) + def prefold(self, node): + return self.evaluate(node) + def evaluate(self, node): validate_call_args(node, 1) if not isinstance(node.args[0], vy_ast.Int): @@ -2106,7 +2201,7 @@ def evaluate(self, node): value = str(node.args[0].value) return vy_ast.Str.from_node(node, value=value) - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) input_type = get_possible_types_from_node(node.args[0]).pop() return [input_type] @@ -2506,7 +2601,7 @@ def fetch_call_return(self, node): _, output_type = self.infer_arg_types(node) return output_type.typedef - def infer_arg_types(self, node): + def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) validate_call_args(node, 2, ["unwrap_tuple"]) @@ -2585,6 +2680,12 @@ def build_IR(self, expr, args, kwargs, context): class _MinMaxValue(TypenameFoldedFunction): + def prefold(self, node): + try: + return self.evaluate(node) + except InvalidType: + return + def evaluate(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) @@ -2621,6 +2722,12 @@ def _eval(self, type_): class Epsilon(TypenameFoldedFunction): _id = "epsilon" + def prefold(self, node): + try: + return self.evaluate(node) + except InvalidType: + return + def evaluate(self, node): self._validate_arg_types(node) input_type = type_from_annotation(node.args[0]) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index ce1ab983cf..e3c6c3dc28 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -607,7 +607,7 @@ def visit(self, node, typ): super().visit(node, typ) folded_value = node._metadata.get("folded_value") - if folded_value: + if isinstance(folded_value, vy_ast.Constant): # print("folded value: ", folded_value) validate_expected_type(folded_value, typ) @@ -689,7 +689,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: return # builtin functions - arg_types = call_type.infer_arg_types(node) + arg_types = call_type.infer_arg_types(node, typ) # `infer_arg_types` already calls `validate_expected_type` for arg, arg_type in zip(node.args, arg_types): self.visit(arg, arg_type) diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index 425bf9d57e..3ef89b227e 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -79,6 +79,6 @@ def prefold(node: vy_ast.VyperNode, constants: dict) -> None: call_type = DISPATCH_TABLE.get(func_name) if call_type and hasattr(call_type, "evaluate"): try: - node._metadata["folded_value"] = call_type.evaluate(node) # type: ignore + node._metadata["folded_value"] = call_type.prefold(node) # type: ignore except UnfoldableNode: pass diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index baa3aa67c2..4663676779 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -281,6 +281,8 @@ def types_from_Call(self, node): var = self.get_exact_type_from_node(node.func, include_type_exprs=True) return_value = var.fetch_call_return(node) if return_value: + if isinstance(return_value, list): + return return_value return [return_value] raise InvalidType(f"{var} did not return a value", node)