Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2369 graceful shutdown of signal handlers #2444

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,48 @@ def shutdown_tasks(
self.purge_tasks()
timeout -= increment

def shutdown_signal_handlers(
self, timeout: Optional[float] = None, increment: float = 0.1
) -> None:
"""Cancel running signal handler tasks.

Any running ``asyncio.Task`` with a name starting with "signal" will be
cancelled. If a :param:`timeout` is not provided, it will be set to the
``GRACEFUL_SHUTDOWN_TIMEOUT`` config.

:param timeout: the max amount of time to wait for the tasks to be
cancelled. Defaults to None.
:type timeout: Optional[float], optional
:param increment: the amount of time to wait between checking that the
tasks have been cancelled. Defaults to 0.1.
:type increment: float, optional
"""
logger.info("Cancelling signal handlers")

if timeout is None:
timeout = self.config.GRACEFUL_SHUTDOWN_TIMEOUT

signal_handlers = [
task
for task in asyncio.all_tasks(self.loop)
if task.get_name().startswith("signal")
]

logger.debug("%d signal handlers found", len(signal_handlers))

for handler in signal_handlers:
logger.debug("Cancelling signal handler: %s", handler.get_name())
handler.cancel()

with suppress(RuntimeError):
while timeout and not all(
[handler.done() for handler in signal_handlers]
):
self.loop.run_until_complete(asyncio.sleep(increment))
timeout -= increment

logger.info("Signal handlers cancelled")

@property
def tasks(self):
return iter(self._task_registry.values())
Expand Down
1 change: 1 addition & 0 deletions sanic/mixins/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def stop(self):
"""
if self.state.stage is not ServerStage.STOPPED:
self.shutdown_tasks(timeout=0)
self.shutdown_signal_handlers()
for task in all_tasks():
with suppress(AttributeError):
if task.get_name() == "RunServer":
Expand Down
2 changes: 1 addition & 1 deletion sanic/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ async def dispatch(
if inline:
return await dispatch

task = asyncio.get_running_loop().create_task(dispatch)
task = asyncio.get_running_loop().create_task(dispatch, name="signal")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should add a more descriptive name? Also include the handler __name__?

await asyncio.sleep(0)
return task

Expand Down
33 changes: 33 additions & 0 deletions tests/test_graceful_shutdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,36 @@ def ping():
"Transport is closed."
)
assert info == 11


def test_no_exceptions_when_cancel_signal_handlers(app, caplog):
@app.signal("foo.bar.baz")
async def async_signal(*_):
await asyncio.sleep(5)

@app.get("/")
async def handler(request):
request.app.dispatch("foo.bar.baz")

def ping():
httpx.get("http://127.0.0.1:8000")

p = Process(target=ping)
p.start()

with caplog.at_level(logging.INFO):
app.run()

p.kill()

info = 0
for record in caplog.record_tuples:
assert record[1] != logging.ERROR

if record[1] == logging.INFO and (
record[2] == "Cancelling signal handlers"
or record[2] == "Signal handlers cancelled"
):
info += 1

assert info == 2
12 changes: 12 additions & 0 deletions tests/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,15 @@ def test_signal_reservation(app, event, expected):
app.signal(event)(lambda: ...)
else:
app.signal(event)(lambda: ...)


@pytest.mark.asyncio
async def test_signal_handler_task_name(app):
@app.signal("foo.bar.baz")
def sync_signal(*_):
...

app.signal_router.finalize()

signal_handler_task = await app.dispatch("foo.bar.baz")
assert signal_handler_task.get_name() == "signal"