From 23ef9d1682d4bc581de30ba1688be851907d063e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 8 Dec 2023 10:12:21 -0500 Subject: [PATCH] fix some signatures --- vyper/builtins/_signatures.py | 12 ++++++++---- vyper/semantics/types/base.py | 9 ++++++--- vyper/semantics/types/function.py | 8 ++++++-- vyper/semantics/types/user.py | 6 ++++-- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 2fad51268e..01fd4dafc4 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Optional +from typing import Any from vyper.ast import nodes as vy_ast from vyper.ast.validation import validate_call_args @@ -79,7 +79,7 @@ class BuiltinFunctionT(VyperType): _has_varargs = False _inputs: list[tuple[str, Any]] = [] _kwargs: dict[str, KwargSettings] = {} - _return_type: Optional[VyperType] = None + _return_type: VyperType | None = None # helper function to deal with TYPE_DEFINITIONs def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: @@ -121,7 +121,9 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: # ensures the type can be inferred exactly. get_exact_type_from_node(arg) - def get_return_type(self, node: vy_ast.Call, expected_type: Optional[VyperType] = None) -> Optional[VyperType]: + def get_return_type( + self, node: vy_ast.Call, expected_type: VyperType | None = None + ) -> VyperType | None: self._validate_arg_types(node) ret = self._return_type @@ -131,7 +133,9 @@ def get_return_type(self, node: vy_ast.Call, expected_type: Optional[VyperType] return ret - def infer_arg_types(self, node: vy_ast.Call, return_type: Optional[VyperType] = None) -> list[VyperType]: + def infer_arg_types( + self, node: vy_ast.Call, return_type: VyperType | None = None + ) -> list[VyperType]: # validate arg types and sanity check the return type self.get_return_type(node, expected_type=return_type) diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index f483bf362c..5be6735663 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -244,7 +244,10 @@ def compare_type(self, other: "VyperType") -> bool: """ return isinstance(other, type(self)) - def get_return_type(self, node: vy_ast.Call) -> Optional["VyperType"]: + def get_return_type( + self, node: vy_ast.Call, expected_type: "VyperType" | None = None + ) -> "VyperType" | None: + # TODO will be cleaner to separate into validate_call and get_return_type """ Validate a call to this type and return the result. @@ -325,7 +328,7 @@ def __repr__(self): return f"type({self.typedef})" # dispatch into ctor if it's called - def get_return_type(self, node, expected_type: Optional[VyperType] = None): + def get_return_type(self, node, expected_type=None): if expected_type is not None: if not self.typedef.compare_type(expected_type): raise CompilerPanic("bad type passed to {self.typedef} ctor", node) @@ -333,7 +336,7 @@ def get_return_type(self, node, expected_type: Optional[VyperType] = None): return self.typedef._ctor_call_return(node) raise StructureException("Value is not callable", node) - def infer_arg_types(self, node, return_type: Optional[VyperType] = None): + def infer_arg_types(self, node, return_type=None): if hasattr(self.typedef, "_ctor_arg_types"): return self.typedef._ctor_arg_types(node) raise StructureException("Value is not callable", node) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 1733cf65c6..32dcfdfc80 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -491,7 +491,9 @@ def method_ids(self) -> Dict[str, int]: method_ids.update(_generate_method_id(self.name, arg_types[:i])) return method_ids - def get_return_type(self, node: vy_ast.Call) -> Optional[VyperType]: + def get_return_type( + self, node: vy_ast.Call, expected_type: VyperType | None = None + ) -> VyperType | None: if node.get("func.value.id") == "self" and self.visibility == FunctionVisibility.EXTERNAL: raise CallViolation("Cannot call external functions via 'self'", node) @@ -627,7 +629,9 @@ def __init__( def __repr__(self): return f"{self.underlying_type._id} member function '{self.name}'" - def get_return_type(self, node: vy_ast.Call) -> Optional[VyperType]: + def get_return_type( + self, node: vy_ast.Call, expected_type: VyperType | None = None + ) -> VyperType | None: validate_call_args(node, len(self.arg_types)) assert len(node.args) == len(self.arg_types) # validate_call_args postcondition diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 47a2fca167..037f63f4d3 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple, Union from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABI_GIntM, ABI_Tuple, ABIType @@ -131,7 +131,9 @@ def from_EnumDef(cls, base_node: vy_ast.EnumDef) -> "EnumT": return cls(base_node.name, members) - def get_return_type(self, node: vy_ast.Call, expected_type: Optional[VyperType] = None) -> Optional[VyperType]: + def get_return_type( + self, node: vy_ast.Call, expected_type: VyperType | None = None + ) -> VyperType | None: # TODO return None