Skip to content

Commit

Permalink
Test typing round 6 (#9226)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer authored Sep 25, 2024
1 parent de997af commit 56aa261
Show file tree
Hide file tree
Showing 7 changed files with 585 additions and 411 deletions.
8 changes: 2 additions & 6 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,16 +708,12 @@ def _write(self, data: bytes) -> None:
raise ClientConnectionResetError("Cannot write to closing transport")
self.transport.write(data)

async def pong(self, message: Union[bytes, str] = b"") -> None:
async def pong(self, message: bytes = b"") -> None:
"""Send pong message."""
if isinstance(message, str):
message = message.encode("utf-8")
await self._send_frame(message, WSMsgType.PONG)

async def ping(self, message: Union[bytes, str] = b"") -> None:
async def ping(self, message: bytes = b"") -> None:
"""Send ping message."""
if isinstance(message, str):
message = message.encode("utf-8")
await self._send_frame(message, WSMsgType.PING)

async def send(
Expand Down
6 changes: 3 additions & 3 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,14 @@ async def pong(self, message: bytes = b"") -> None:
raise RuntimeError("Call .prepare() first")
await self._writer.pong(message)

async def send_str(self, data: str, compress: Optional[bool] = None) -> None:
async def send_str(self, data: str, compress: Optional[int] = None) -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if not isinstance(data, str):
raise TypeError("data argument must be str (%r)" % type(data))
await self._writer.send(data, binary=False, compress=compress)

async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None:
async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if not isinstance(data, (bytes, bytearray, memoryview)):
Expand All @@ -421,7 +421,7 @@ async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None
async def send_json(
self,
data: Any,
compress: Optional[bool] = None,
compress: Optional[int] = None,
*,
dumps: JSONEncoder = json.dumps,
) -> None:
Expand Down
100 changes: 55 additions & 45 deletions tests/test_web_runner.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
# type: ignore
import asyncio
import platform
import signal
from typing import Any
from unittest.mock import patch
from typing import Any, Iterator, NoReturn, Protocol, Union
from unittest import mock

import pytest

from aiohttp import web
from aiohttp.abc import AbstractAccessLogger
from aiohttp.test_utils import get_unused_port_socket
from aiohttp.web_log import AccessLogger


class _RunnerMaker(Protocol):
def __call__(self, handle_signals: bool = ..., **kwargs: Any) -> web.AppRunner: ...


@pytest.fixture
def app():
def app() -> web.Application:
return web.Application()


@pytest.fixture
def make_runner(loop: Any, app: Any):
def make_runner(
loop: asyncio.AbstractEventLoop, app: web.Application
) -> Iterator[_RunnerMaker]:
asyncio.set_event_loop(loop)
runners = []

def go(**kwargs):
runner = web.AppRunner(app, **kwargs)
def go(handle_signals: bool = False, **kwargs: Any) -> web.AppRunner:
runner = web.AppRunner(app, handle_signals=handle_signals, **kwargs)
runners.append(runner)
return runner

Expand All @@ -32,7 +38,7 @@ def go(**kwargs):
loop.run_until_complete(runner.cleanup())


async def test_site_for_nonfrozen_app(make_runner: Any) -> None:
async def test_site_for_nonfrozen_app(make_runner: _RunnerMaker) -> None:
runner = make_runner()
with pytest.raises(RuntimeError):
web.TCPSite(runner)
Expand All @@ -42,7 +48,7 @@ async def test_site_for_nonfrozen_app(make_runner: Any) -> None:
@pytest.mark.skipif(
platform.system() == "Windows", reason="the test is not valid for Windows"
)
async def test_runner_setup_handle_signals(make_runner: Any) -> None:
async def test_runner_setup_handle_signals(make_runner: _RunnerMaker) -> None:
runner = make_runner(handle_signals=True)
await runner.setup()
assert signal.getsignal(signal.SIGTERM) is not signal.SIG_DFL
Expand All @@ -53,15 +59,15 @@ async def test_runner_setup_handle_signals(make_runner: Any) -> None:
@pytest.mark.skipif(
platform.system() == "Windows", reason="the test is not valid for Windows"
)
async def test_runner_setup_without_signal_handling(make_runner: Any) -> None:
async def test_runner_setup_without_signal_handling(make_runner: _RunnerMaker) -> None:
runner = make_runner(handle_signals=False)
await runner.setup()
assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL
await runner.cleanup()
assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL


async def test_site_double_added(make_runner: Any) -> None:
async def test_site_double_added(make_runner: _RunnerMaker) -> None:
_sock = get_unused_port_socket("127.0.0.1")
runner = make_runner()
await runner.setup()
Expand All @@ -73,7 +79,7 @@ async def test_site_double_added(make_runner: Any) -> None:
assert len(runner.sites) == 1


async def test_site_stop_not_started(make_runner: Any) -> None:
async def test_site_stop_not_started(make_runner: _RunnerMaker) -> None:
runner = make_runner()
await runner.setup()
site = web.TCPSite(runner)
Expand All @@ -83,34 +89,35 @@ async def test_site_stop_not_started(make_runner: Any) -> None:
assert len(runner.sites) == 0


async def test_custom_log_format(make_runner: Any) -> None:
async def test_custom_log_format(make_runner: _RunnerMaker) -> None:
runner = make_runner(access_log_format="abc")
await runner.setup()
assert runner.server is not None
assert runner.server._kwargs["access_log_format"] == "abc"


async def test_unreg_site(make_runner: Any) -> None:
async def test_unreg_site(make_runner: _RunnerMaker) -> None:
runner = make_runner()
await runner.setup()
site = web.TCPSite(runner)
with pytest.raises(RuntimeError):
runner._unreg_site(site)


async def test_app_property(make_runner: Any, app: Any) -> None:
async def test_app_property(make_runner: _RunnerMaker, app: web.Application) -> None:
runner = make_runner()
assert runner.app is app


def test_non_app() -> None:
with pytest.raises(TypeError):
web.AppRunner(object())
web.AppRunner(object()) # type: ignore[arg-type]


def test_app_handler_args() -> None:
app = web.Application(handler_args={"test": True})
runner = web.AppRunner(app)
assert runner._kwargs == {"access_log_class": web.AccessLogger, "test": True}
assert runner._kwargs == {"access_log_class": AccessLogger, "test": True}


async def test_app_handler_args_failure() -> None:
Expand All @@ -132,7 +139,9 @@ async def test_app_handler_args_failure() -> None:
("2", 2),
),
)
async def test_app_handler_args_ceil_threshold(value: Any, expected: Any) -> None:
async def test_app_handler_args_ceil_threshold(
value: Union[int, str, None], expected: int
) -> None:
app = web.Application(handler_args={"timeout_ceil_threshold": value})
runner = web.AppRunner(app)
await runner.setup()
Expand All @@ -150,7 +159,7 @@ class Logger:
app = web.Application()

with pytest.raises(TypeError):
web.AppRunner(app, access_log_class=Logger)
web.AppRunner(app, access_log_class=Logger) # type: ignore[arg-type]


async def test_app_make_handler_access_log_class_bad_type2() -> None:
Expand All @@ -165,7 +174,9 @@ class Logger:

async def test_app_make_handler_access_log_class1() -> None:
class Logger(AbstractAccessLogger):
def log(self, request, response, time):
def log(
self, request: web.BaseRequest, response: web.StreamResponse, time: float
) -> None:
"""Pass log method."""

app = web.Application()
Expand All @@ -175,15 +186,17 @@ def log(self, request, response, time):

async def test_app_make_handler_access_log_class2() -> None:
class Logger(AbstractAccessLogger):
def log(self, request, response, time):
def log(
self, request: web.BaseRequest, response: web.StreamResponse, time: float
) -> None:
"""Pass log method."""

app = web.Application(handler_args={"access_log_class": Logger})
runner = web.AppRunner(app)
assert runner._kwargs["access_log_class"] is Logger


async def test_addresses(make_runner: Any, unix_sockname: Any) -> None:
async def test_addresses(make_runner: _RunnerMaker, unix_sockname: str) -> None:
_sock = get_unused_port_socket("127.0.0.1")
runner = make_runner()
await runner.setup()
Expand All @@ -200,7 +213,7 @@ async def test_addresses(make_runner: Any, unix_sockname: Any) -> None:
platform.system() != "Windows", reason="Proactor Event loop present only in Windows"
)
async def test_named_pipe_runner_wrong_loop(
app: Any, selector_loop: Any, pipe_name: Any
app: web.Application, selector_loop: asyncio.AbstractEventLoop, pipe_name: str
) -> None:
runner = web.AppRunner(app)
await runner.setup()
Expand All @@ -212,7 +225,7 @@ async def test_named_pipe_runner_wrong_loop(
platform.system() != "Windows", reason="Proactor Event loop present only in Windows"
)
async def test_named_pipe_runner_proactor_loop(
proactor_loop: Any, app: Any, pipe_name: Any
proactor_loop: asyncio.AbstractEventLoop, app: web.Application, pipe_name: str
) -> None:
runner = web.AppRunner(app)
await runner.setup()
Expand All @@ -221,45 +234,42 @@ async def test_named_pipe_runner_proactor_loop(
await runner.cleanup()


async def test_tcpsite_default_host(make_runner: Any) -> None:
async def test_tcpsite_default_host(make_runner: _RunnerMaker) -> None:
runner = make_runner()
await runner.setup()
site = web.TCPSite(runner)
assert site.name == "http://0.0.0.0:8080"

calls = []

async def mock_create_server(*args, **kwargs):
calls.append((args, kwargs))

with patch("asyncio.get_event_loop") as mock_get_loop:
mock_get_loop.return_value.create_server = mock_create_server
m = mock.create_autospec(asyncio.AbstractEventLoop, spec_set=True, instance=True)
m.create_server.return_value = mock.create_autospec(asyncio.Server, spec_set=True)
with mock.patch(
"asyncio.get_event_loop", autospec=True, spec_set=True, return_value=m
):
await site.start()

assert len(calls) == 1
server, host, port = calls[0][0]
assert server is runner.server
assert host is None
assert port == 8080
m.create_server.assert_called_once()
args, kwargs = m.create_server.call_args
assert args == (runner.server, None, 8080)


async def test_tcpsite_empty_str_host(make_runner: Any) -> None:
async def test_tcpsite_empty_str_host(make_runner: _RunnerMaker) -> None:
runner = make_runner()
await runner.setup()
site = web.TCPSite(runner, host="")
assert site.name == "http://0.0.0.0:8080"


def test_run_after_asyncio_run() -> None:
async def nothing():
pass
called = False

def spy():
spy.called = True
async def nothing() -> None:
pass

spy.called = False
def spy() -> None:
nonlocal called
called = True

async def shutdown():
async def shutdown() -> NoReturn:
spy()
raise web.GracefulExit()

Expand All @@ -271,4 +281,4 @@ async def shutdown():
app.on_startup.append(lambda a: asyncio.create_task(shutdown()))

web.run_app(app)
assert spy.called, "run_app() should work after asyncio.run()."
assert called, "run_app() should work after asyncio.run()."
Loading

0 comments on commit 56aa261

Please sign in to comment.