Skip to content

Commit

Permalink
fix some signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Dec 8, 2023
1 parent 832f17b commit 23ef9d1
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 11 deletions.
12 changes: 8 additions & 4 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Any, Optional
from typing import Any

from vyper.ast import nodes as vy_ast
from vyper.ast.validation import validate_call_args
Expand Down Expand Up @@ -79,7 +79,7 @@ class BuiltinFunctionT(VyperType):
_has_varargs = False
_inputs: list[tuple[str, Any]] = []
_kwargs: dict[str, KwargSettings] = {}
_return_type: Optional[VyperType] = None
_return_type: VyperType | None = None

# helper function to deal with TYPE_DEFINITIONs
def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None:
Expand Down Expand Up @@ -121,7 +121,9 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None:
# ensures the type can be inferred exactly.
get_exact_type_from_node(arg)

def get_return_type(self, node: vy_ast.Call, expected_type: Optional[VyperType] = None) -> Optional[VyperType]:
def get_return_type(
self, node: vy_ast.Call, expected_type: VyperType | None = None
) -> VyperType | None:
self._validate_arg_types(node)

ret = self._return_type
Expand All @@ -131,7 +133,9 @@ def get_return_type(self, node: vy_ast.Call, expected_type: Optional[VyperType]

return ret

def infer_arg_types(self, node: vy_ast.Call, return_type: Optional[VyperType] = None) -> list[VyperType]:
def infer_arg_types(
self, node: vy_ast.Call, return_type: VyperType | None = None
) -> list[VyperType]:
# validate arg types and sanity check the return type
self.get_return_type(node, expected_type=return_type)

Expand Down
9 changes: 6 additions & 3 deletions vyper/semantics/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,10 @@ def compare_type(self, other: "VyperType") -> bool:
"""
return isinstance(other, type(self))

def get_return_type(self, node: vy_ast.Call) -> Optional["VyperType"]:
def get_return_type(
self, node: vy_ast.Call, expected_type: "VyperType" | None = None
) -> "VyperType" | None:
# TODO will be cleaner to separate into validate_call and get_return_type
"""
Validate a call to this type and return the result.
Expand Down Expand Up @@ -325,15 +328,15 @@ def __repr__(self):
return f"type({self.typedef})"

# dispatch into ctor if it's called
def get_return_type(self, node, expected_type: Optional[VyperType] = None):
def get_return_type(self, node, expected_type=None):
if expected_type is not None:
if not self.typedef.compare_type(expected_type):
raise CompilerPanic("bad type passed to {self.typedef} ctor", node)
if hasattr(self.typedef, "_ctor_call_return"):
return self.typedef._ctor_call_return(node)
raise StructureException("Value is not callable", node)

def infer_arg_types(self, node, return_type: Optional[VyperType] = None):
def infer_arg_types(self, node, return_type=None):
if hasattr(self.typedef, "_ctor_arg_types"):
return self.typedef._ctor_arg_types(node)
raise StructureException("Value is not callable", node)
Expand Down
8 changes: 6 additions & 2 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,9 @@ def method_ids(self) -> Dict[str, int]:
method_ids.update(_generate_method_id(self.name, arg_types[:i]))
return method_ids

def get_return_type(self, node: vy_ast.Call) -> Optional[VyperType]:
def get_return_type(
self, node: vy_ast.Call, expected_type: VyperType | None = None
) -> VyperType | None:
if node.get("func.value.id") == "self" and self.visibility == FunctionVisibility.EXTERNAL:
raise CallViolation("Cannot call external functions via 'self'", node)

Expand Down Expand Up @@ -627,7 +629,9 @@ def __init__(
def __repr__(self):
return f"{self.underlying_type._id} member function '{self.name}'"

def get_return_type(self, node: vy_ast.Call) -> Optional[VyperType]:
def get_return_type(
self, node: vy_ast.Call, expected_type: VyperType | None = None
) -> VyperType | None:
validate_call_args(node, len(self.arg_types))

assert len(node.args) == len(self.arg_types) # validate_call_args postcondition
Expand Down
6 changes: 4 additions & 2 deletions vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple, Union

from vyper import ast as vy_ast
from vyper.abi_types import ABI_Address, ABI_GIntM, ABI_Tuple, ABIType
Expand Down Expand Up @@ -131,7 +131,9 @@ def from_EnumDef(cls, base_node: vy_ast.EnumDef) -> "EnumT":

return cls(base_node.name, members)

def get_return_type(self, node: vy_ast.Call, expected_type: Optional[VyperType] = None) -> Optional[VyperType]:
def get_return_type(
self, node: vy_ast.Call, expected_type: VyperType | None = None
) -> VyperType | None:
# TODO
return None

Expand Down

0 comments on commit 23ef9d1

Please sign in to comment.