diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 4890d007..c65f68a6 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -4,9 +4,11 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio import collections +from collections.abc import Callable import enum import functools import getpass @@ -764,14 +766,21 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, class TLSUpgradeProto(asyncio.Protocol): - def __init__(self, loop, host, port, ssl_context, ssl_is_advisory): + def __init__( + self, + loop: asyncio.AbstractEventLoop, + host: str, + port: int, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool, + ) -> None: self.on_data = _create_future(loop) self.host = host self.port = port self.ssl_context = ssl_context self.ssl_is_advisory = ssl_is_advisory - def data_received(self, data): + def data_received(self, data: bytes) -> None: if data == b'S': self.on_data.set_result(True) elif (self.ssl_is_advisory and @@ -789,15 +798,30 @@ def data_received(self, data): 'rejected SSL upgrade'.format( host=self.host, port=self.port))) - def connection_lost(self, exc): + def connection_lost(self, exc: typing.Optional[Exception]) -> None: if not self.on_data.done(): if exc is None: exc = ConnectionError('unexpected connection_lost() call') self.on_data.set_exception(exc) -async def _create_ssl_connection(protocol_factory, host, port, *, - loop, ssl_context, ssl_is_advisory=False): +_ProctolFactoryR = typing.TypeVar( + "_ProctolFactoryR", bound=asyncio.protocols.Protocol +) + + +async def _create_ssl_connection( + # TODO: The return type is a specific combination of subclasses of + # asyncio.protocols.Protocol that we can't express. For now, having the + # return type be dependent on signature of the factory is an improvement + protocol_factory: Callable[[], _ProctolFactoryR], + host: str, + port: int, + *, + loop: asyncio.AbstractEventLoop, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool = False, +) -> typing.Tuple[asyncio.Transport, _ProctolFactoryR]: tr, pr = await loop.create_connection( lambda: TLSUpgradeProto(loop, host, port, @@ -817,6 +841,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *, try: new_tr = await loop.start_tls( tr, pr, ssl_context, server_hostname=host) + assert new_tr is not None except (Exception, asyncio.CancelledError): tr.close() raise