From a31d1824d03df92444bb94ba60c3946926f9c410 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Tue, 22 Oct 2024 22:28:47 +0200 Subject: [PATCH] Add type annotations to `TLSUpgradeProto` --- asyncpg/connect_utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 4890d007..6cefc020 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -764,14 +764,21 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, class TLSUpgradeProto(asyncio.Protocol): - def __init__(self, loop, host, port, ssl_context, ssl_is_advisory): + def __init__( + self, + loop: asyncio.AbstractEventLoop, + host: str, + port: int, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool, + ) -> None: self.on_data = _create_future(loop) self.host = host self.port = port self.ssl_context = ssl_context self.ssl_is_advisory = ssl_is_advisory - def data_received(self, data): + def data_received(self, data: bytes) -> None: if data == b'S': self.on_data.set_result(True) elif (self.ssl_is_advisory and @@ -789,7 +796,7 @@ def data_received(self, data): 'rejected SSL upgrade'.format( host=self.host, port=self.port))) - def connection_lost(self, exc): + def connection_lost(self, exc: typing.Optional[Exception]) -> None: if not self.on_data.done(): if exc is None: exc = ConnectionError('unexpected connection_lost() call')