Skip to content

Commit

Permalink
Don't throw away type arguments when matching protocols
Browse files Browse the repository at this point in the history
Summary:
I've been trying to do a typeshed update and ran into a problem with pyre no longer recognizing `TextIO` as matching the `SupportsWrite[str]` protocol.

Typeshed recently removed a redundant `def write(self: IO[str], s: str, /) -> int` overload for `typing.IO`, so pyre now needs to recognize that the `def write(self, s: AnyStr, /) -> int` overload allows `TextIO` to match `SupportsWrite[str]` because `TextIO` inherits from `IO[str]`.

Where things were going wrong is that pyre was first trying to match `IO` against `SupportsWrite` (throwing away the `str` type argument), picking an overload with `self: IO[bytes]` (because multiple overloads match when considering the bare classes), and concluding that the match failed because `IO[str]` and `IO[bytes]` are contradictory.

This change adds a `protocol_arguments` parameter to the `instantiate_protocol_parameters` function in `constraintsSet` so that type arguments are taken into account from the beginning.

Reviewed By: stroxler

Differential Revision: D64015614

fbshipit-source-id: f91b8ad13cbecc20fd6b1e63e884adb8559774af
  • Loading branch information
rchen152 authored and facebook-github-bot committed Oct 8, 2024
1 parent 5f2061e commit ad3e647
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 19 deletions.
68 changes: 53 additions & 15 deletions source/analysis/constraintsSet.ml
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ module type OrderedConstraintsSetType = sig
Type.Callable.parameters list

val instantiate_protocol_parameters
: order ->
candidate:Type.t ->
: candidate:Type.t ->
protocol:Ast.Identifier.t ->
?protocol_arguments:Type.Argument.t list ->
order ->
Type.Argument.t list option
end

Expand Down Expand Up @@ -857,7 +858,11 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct
let left_arguments = instantiate_successors_parameters ~source:left ~target:right_name in
match left_arguments with
| None when is_protocol right ->
instantiate_protocol_parameters order ~protocol:right_name ~candidate:left
instantiate_protocol_parameters
order
~protocol:right_name
~protocol_arguments:right_arguments
~candidate:left
| _ -> left_arguments
in
left_arguments
Expand Down Expand Up @@ -1144,14 +1149,15 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct
Note that classes that refer to themselves don't suffer from this since subtyping for two
classes just follows from the class hierarchy. *)
and instantiate_protocol_parameters_with_solve
~solve_candidate_less_or_equal_protocol
~candidate
~protocol
?protocol_arguments
({
class_hierarchy = { generic_parameters; _ };
cycle_detections = { assumed_protocol_instantiations; _ } as cycle_detections;
_;
} as order)
~solve_candidate_less_or_equal_protocol
~candidate
~protocol
: Type.Argument.t list option
=
match candidate with
Expand Down Expand Up @@ -1270,24 +1276,52 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct
>>| desanitize
|> Option.value ~default:[]
in
let protocol_annotation =
let generic_protocol_annotation =
protocol_generic_parameters
>>| Type.parametric protocol
|> Option.value ~default:(Type.Primitive protocol)
in
solve_candidate_less_or_equal_protocol
order_with_new_assumption
~candidate
~protocol_annotation
>>| instantiate_protocol_generics)
let protocol_annotations =
(* When protocol arguments are provided by the caller, we first try solving for them
before falling back to a protocol annotation with generic parameters. We keep only
the non-variable arguments because using the variable names from the protocol
definition produces better error messages. Falling back to the generic annotation
handles the case of `candidate` being an empty container. *)
match protocol_arguments, protocol_generic_parameters with
| Some arguments, Some generic_arguments
when Int.equal (List.length arguments) (List.length generic_arguments) ->
let map argument generic_argument =
match argument with
| Type.Argument.Single (Type.Variable _) -> generic_argument
| _ -> argument
in
let protocol_annotation =
Type.Parametric
{
name = protocol;
arguments = List.map2_exn ~f:map arguments generic_arguments;
}
in
[protocol_annotation; generic_protocol_annotation]
| _ -> [generic_protocol_annotation]
in
let find_first_solution sofar protocol_annotation =
match sofar with
| Some _ -> sofar
| None ->
solve_candidate_less_or_equal_protocol
order_with_new_assumption
~candidate
~protocol_annotation
>>| instantiate_protocol_generics
in
List.fold ~init:None ~f:find_first_solution protocol_annotations)


(** As with `instantiate_protocol_parameters_with_solve`, here `None` means a failure to match
`candidate` type with the protocol, whereas `Some []` means no generic constraints were
induced. *)
and instantiate_protocol_parameters
: order -> candidate:Type.t -> protocol:Ast.Identifier.t -> Type.Argument.t list option
=
and instantiate_protocol_parameters ~candidate ~protocol ?protocol_arguments order =
(* A candidate is less-or-equal to a protocol if candidate.x <: protocol.x for each attribute
`x` in the protocol. *)
let solve_all_protocol_attributes_less_or_equal
Expand Down Expand Up @@ -1346,7 +1380,11 @@ module Make (OrderedConstraints : OrderedConstraintsType) = struct
>>= List.hd
in
instantiate_protocol_parameters_with_solve
order
~solve_candidate_less_or_equal_protocol:solve_all_protocol_attributes_less_or_equal
~candidate
~protocol
~protocol_arguments:(Option.value ~default:[] protocol_arguments)


and instantiate_recursive_type_parameters order ~candidate ~recursive_type
Expand Down
5 changes: 3 additions & 2 deletions source/analysis/constraintsSet.mli
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ module type OrderedConstraintsSetType = sig

(* Only exposed for testing *)
val instantiate_protocol_parameters
: order ->
candidate:Type.t ->
: candidate:Type.t ->
protocol:Ast.Identifier.t ->
?protocol_arguments:Type.Argument.t list ->
order ->
Type.Argument.t list option
end

Expand Down
47 changes: 47 additions & 0 deletions source/analysis/test/integration/protocolTest.ml
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,53 @@ let test_check_generic_protocols =
"Revealed type [-1]: Revealed type for `a` is `PXY[int, str]`.";
"Revealed type [-1]: Revealed type for `b` is `PYX[str, int]`.";
];
labeled_test_case __FUNCTION__ __LINE__
@@ assert_type_errors
{|
from typing import Any, Generic, overload, Protocol, TypeVar

_T_contra = TypeVar("_T_contra", contravariant=True)
AnyStr = TypeVar("AnyStr", str, bytes)

class SupportsWrite(Protocol[_T_contra]):
def write(self, s: _T_contra, /) -> object:
pass

class IO(Generic[AnyStr]):
@overload
def write(self: "IO[bytes]", s: bytes, /) -> int: ...

@overload
def write(self, s: AnyStr, /) -> int: ...

# pyre-ignore[2]: Explicit Any.
def write(self: Any, s: Any, /) -> int:
return 0

def f(x: SupportsWrite[str]) -> None:
pass

def g(x: IO[str]) -> None:
f(x)
|}
[];
labeled_test_case __FUNCTION__ __LINE__
@@ assert_type_errors
{|
from typing import Iterable, Protocol, TypeVar

_T = TypeVar('_T')

class HasKeys(Protocol[_T]):
def keys(self) -> Iterable[_T]: ...

def f(x: HasKeys[str]) -> None:
pass

def g() -> None:
f({})
|}
[];
]


Expand Down
5 changes: 3 additions & 2 deletions source/analysis/typeOrder.mli
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ module OrderedConstraints : TypeConstraints.OrderedConstraintsType with type ord
module OrderedConstraintsSet : ConstraintsSet.OrderedConstraintsSetType

val instantiate_protocol_parameters
: order ->
candidate:Type.t ->
: candidate:Type.t ->
protocol:Ast.Identifier.t ->
?protocol_arguments:Type.Argument.t list ->
order ->
Type.Argument.t list option

0 comments on commit ad3e647

Please sign in to comment.