Skip to content

Commit

Permalink
Add type annotations to _create_ssl_connection
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielNoord committed Oct 23, 2024
1 parent a31d182 commit 76105cc
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio
import collections
from collections.abc import Callable
import enum
import functools
import getpass
Expand Down Expand Up @@ -803,8 +804,23 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
self.on_data.set_exception(exc)


async def _create_ssl_connection(protocol_factory, host, port, *,
loop, ssl_context, ssl_is_advisory=False):
_ProctolFactoryR = typing.TypeVar(
"_ProctolFactoryR", bound=asyncio.protocols.Protocol
)


async def _create_ssl_connection(
# TODO: The return type is a specific combination of subclasses of
# asyncio.protocols.Protocol that we can't express. For now, having the
# return type be dependent on signature of the factory is an improvement
protocol_factory: "Callable[[], _ProctolFactoryR]",
host: str,
port: int,
*,
loop: asyncio.AbstractEventLoop,
ssl_context: ssl_module.SSLContext,
ssl_is_advisory: bool = False,
) -> typing.Tuple[asyncio.Transport, _ProctolFactoryR]:

tr, pr = await loop.create_connection(
lambda: TLSUpgradeProto(loop, host, port,
Expand All @@ -824,6 +840,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
try:
new_tr = await loop.start_tls(
tr, pr, ssl_context, server_hostname=host)
assert new_tr is not None
except (Exception, asyncio.CancelledError):
tr.close()
raise
Expand Down

0 comments on commit 76105cc

Please sign in to comment.