Skip to content

Commit

Permalink
Add typing to two objects in connection_utils (#1198)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielNoord authored Oct 29, 2024
1 parent bae282e commit a273e0e
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import asyncio
import collections
from collections.abc import Callable
import enum
import functools
import getpass
Expand Down Expand Up @@ -764,14 +766,21 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,


class TLSUpgradeProto(asyncio.Protocol):
def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
def __init__(
self,
loop: asyncio.AbstractEventLoop,
host: str,
port: int,
ssl_context: ssl_module.SSLContext,
ssl_is_advisory: bool,
) -> None:
self.on_data = _create_future(loop)
self.host = host
self.port = port
self.ssl_context = ssl_context
self.ssl_is_advisory = ssl_is_advisory

def data_received(self, data):
def data_received(self, data: bytes) -> None:
if data == b'S':
self.on_data.set_result(True)
elif (self.ssl_is_advisory and
Expand All @@ -789,15 +798,30 @@ def data_received(self, data):
'rejected SSL upgrade'.format(
host=self.host, port=self.port)))

def connection_lost(self, exc):
def connection_lost(self, exc: typing.Optional[Exception]) -> None:
if not self.on_data.done():
if exc is None:
exc = ConnectionError('unexpected connection_lost() call')
self.on_data.set_exception(exc)


async def _create_ssl_connection(protocol_factory, host, port, *,
loop, ssl_context, ssl_is_advisory=False):
_ProctolFactoryR = typing.TypeVar(
"_ProctolFactoryR", bound=asyncio.protocols.Protocol
)


async def _create_ssl_connection(
# TODO: The return type is a specific combination of subclasses of
# asyncio.protocols.Protocol that we can't express. For now, having the
# return type be dependent on signature of the factory is an improvement
protocol_factory: Callable[[], _ProctolFactoryR],
host: str,
port: int,
*,
loop: asyncio.AbstractEventLoop,
ssl_context: ssl_module.SSLContext,
ssl_is_advisory: bool = False,
) -> typing.Tuple[asyncio.Transport, _ProctolFactoryR]:

tr, pr = await loop.create_connection(
lambda: TLSUpgradeProto(loop, host, port,
Expand All @@ -817,6 +841,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
try:
new_tr = await loop.start_tls(
tr, pr, ssl_context, server_hostname=host)
assert new_tr is not None
except (Exception, asyncio.CancelledError):
tr.close()
raise
Expand Down

0 comments on commit a273e0e

Please sign in to comment.