From 484419d3743b44d8ad0f10a3c068aa0c9d3f27a1 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 19 Jan 2024 12:28:41 +0100 Subject: [PATCH 1/6] decompose server class --- distributed/_async_taskgroup.py | 159 +++++++ distributed/core.py | 606 ++------------------------ distributed/event.py | 2 +- distributed/http/scheduler/json.py | 2 +- distributed/lock.py | 2 +- distributed/nanny.py | 23 +- distributed/node.py | 488 ++++++++++++++++++++- distributed/preloading.py | 8 +- distributed/pubsub.py | 4 +- distributed/queues.py | 4 +- distributed/scheduler.py | 72 +-- distributed/shuffle/_core.py | 2 +- distributed/shuffle/_rechunk.py | 2 +- distributed/shuffle/_shuffle.py | 4 +- distributed/shuffle/_worker_plugin.py | 14 +- distributed/stealing.py | 2 +- distributed/tests/test_core.py | 126 +----- distributed/tests/test_nanny.py | 8 +- distributed/tests/test_node.py | 104 +++++ distributed/tests/test_scheduler.py | 29 +- distributed/utils_test.py | 3 +- distributed/variable.py | 6 +- distributed/worker.py | 89 ++-- 23 files changed, 922 insertions(+), 837 deletions(-) create mode 100644 distributed/_async_taskgroup.py create mode 100644 distributed/tests/test_node.py diff --git a/distributed/_async_taskgroup.py b/distributed/_async_taskgroup.py new file mode 100644 index 00000000000..a048491d302 --- /dev/null +++ b/distributed/_async_taskgroup.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import asyncio +import threading +from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING, Any, TypeVar + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + P = ParamSpec("P") + R = TypeVar("R") + T = TypeVar("T") + Coro = Coroutine[Any, Any, T] + + +class _LoopBoundMixin: + """Backport of the private asyncio.mixins._LoopBoundMixin from 3.11""" + + _global_lock = threading.Lock() + + _loop = None + + def _get_loop(self): + loop = asyncio.get_running_loop() + + if self._loop is None: + with self._global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + + +class AsyncTaskGroupClosedError(RuntimeError): + pass + + +def _delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]: + """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" + + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + await asyncio.sleep(delay) + return await corofunc(*args, **kwargs) + + return wrapper + + +class AsyncTaskGroup(_LoopBoundMixin): + """Collection tracking all currently running asynchronous tasks within a group""" + + #: If True, the group is closed and does not allow adding new tasks. + closed: bool + + def __init__(self) -> None: + self.closed = False + self._ongoing_tasks: set[asyncio.Task[None]] = set() + + def call_soon( + self, afunc: Callable[P, Coro[None]], /, *args: P.args, **kwargs: P.kwargs + ) -> None: + """Schedule a coroutine function to be executed as an `asyncio.Task`. + + The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments + as an `asyncio.Task`. + + Parameters + ---------- + afunc + Coroutine function to schedule. + *args + Arguments to be passed to `afunc`. + **kwargs + Keyword arguments to be passed to `afunc` + + Returns + ------- + None + + Raises + ------ + AsyncTaskGroupClosedError + If the task group is closed. 
+ """ + if self.closed: # Avoid creating a coroutine + raise AsyncTaskGroupClosedError( + "Cannot schedule a new coroutine function as the group is already closed." + ) + task = self._get_loop().create_task(afunc(*args, **kwargs)) + task.add_done_callback(self._ongoing_tasks.remove) + self._ongoing_tasks.add(task) + return None + + def call_later( + self, + delay: float, + afunc: Callable[P, Coro[None]], + /, + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. + + The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments + as an `asyncio.Task` that is executed after `delay` seconds. + + Parameters + ---------- + delay + Delay in seconds. + afunc + Coroutine function to schedule. + *args + Arguments to be passed to `afunc`. + **kwargs + Keyword arguments to be passed to `afunc` + + Returns + ------- + The None + + Raises + ------ + AsyncTaskGroupClosedError + If the task group is closed. + """ + self.call_soon(_delayed(afunc, delay), *args, **kwargs) + + def close(self) -> None: + """Closes the task group so that no new tasks can be scheduled. + + Existing tasks continue to run. + """ + self.closed = True + + async def stop(self) -> None: + """Close the group and stop all currently running tasks. + + Closes the task group and cancels all tasks. All tasks are cancelled + an additional time for each time this task is cancelled. + """ + self.close() + + current_task = asyncio.current_task(self._get_loop()) + err = None + while tasks_to_stop := (self._ongoing_tasks - {current_task}): + for task in tasks_to_stop: + task.cancel() + try: + await asyncio.wait(tasks_to_stop) + except asyncio.CancelledError as e: + err = e + + if err is not None: + raise err + + def __len__(self): + return len(self._ongoing_tasks) diff --git a/distributed/core.py b/distributed/core.py index 90705e80515..0a59bfeb740 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -5,36 +5,24 @@ import inspect import logging import math -import os import sys -import tempfile -import threading import traceback import types -import uuid import warnings import weakref -from collections import defaultdict, deque -from collections.abc import ( - Awaitable, - Callable, - Container, - Coroutine, - Generator, - Hashable, -) +from collections import defaultdict +from collections.abc import Callable, Generator from enum import Enum -from functools import wraps -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, TypeVar, final +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict import tblib from tlz import merge from tornado.ioloop import IOLoop import dask -from dask.utils import parse_timedelta -from distributed import profile, protocol +from distributed import protocol +from distributed._async_taskgroup import AsyncTaskGroup, AsyncTaskGroupClosedError from distributed.comm import ( Comm, CommClosedError, @@ -45,33 +33,16 @@ unparse_host_port, ) from distributed.comm.core import Listener -from distributed.compatibility import PeriodicCallback -from distributed.counter import Counter -from distributed.diskutils import WorkDir, WorkSpace -from distributed.metrics import context_meter, time -from distributed.system_monitor import SystemMonitor from distributed.utils import ( NoOpAwaitable, get_traceback, has_keyword, - import_file, iscoroutinefunction, - offload, - recursive_to_dict, truncate_exception, - wait_for, - warn_on_duration, ) if TYPE_CHECKING: - from 
typing_extensions import ParamSpec, Self - - from distributed.counter import Digest - - P = ParamSpec("P") - R = TypeVar("R") - T = TypeVar("T") - Coro = Coroutine[Any, Any, T] + from typing_extensions import Self class Status(Enum): @@ -114,10 +85,6 @@ def _raise(*args, **kwargs): return _raise -tick_maximum_delay = parse_timedelta( - dask.config.get("distributed.admin.tick.limit"), default="ms" -) - LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") @@ -138,151 +105,6 @@ def _expects_comm(func: Callable) -> bool: return False -class _LoopBoundMixin: - """Backport of the private asyncio.mixins._LoopBoundMixin from 3.11""" - - _global_lock = threading.Lock() - - _loop = None - - def _get_loop(self): - loop = asyncio.get_running_loop() - - if self._loop is None: - with self._global_lock: - if self._loop is None: - self._loop = loop - if loop is not self._loop: - raise RuntimeError(f"{self!r} is bound to a different event loop") - return loop - - -class AsyncTaskGroupClosedError(RuntimeError): - pass - - -def _delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]: - """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" - - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - await asyncio.sleep(delay) - return await corofunc(*args, **kwargs) - - return wrapper - - -class AsyncTaskGroup(_LoopBoundMixin): - """Collection tracking all currently running asynchronous tasks within a group""" - - #: If True, the group is closed and does not allow adding new tasks. - closed: bool - - def __init__(self) -> None: - self.closed = False - self._ongoing_tasks: set[asyncio.Task[None]] = set() - - def call_soon( - self, afunc: Callable[P, Coro[None]], /, *args: P.args, **kwargs: P.kwargs - ) -> None: - """Schedule a coroutine function to be executed as an `asyncio.Task`. - - The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments - as an `asyncio.Task`. - - Parameters - ---------- - afunc - Coroutine function to schedule. - *args - Arguments to be passed to `afunc`. - **kwargs - Keyword arguments to be passed to `afunc` - - Returns - ------- - None - - Raises - ------ - AsyncTaskGroupClosedError - If the task group is closed. - """ - if self.closed: # Avoid creating a coroutine - raise AsyncTaskGroupClosedError( - "Cannot schedule a new coroutine function as the group is already closed." - ) - task = self._get_loop().create_task(afunc(*args, **kwargs)) - task.add_done_callback(self._ongoing_tasks.remove) - self._ongoing_tasks.add(task) - return None - - def call_later( - self, - delay: float, - afunc: Callable[P, Coro[None]], - /, - *args: P.args, - **kwargs: P.kwargs, - ) -> None: - """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. - - The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments - as an `asyncio.Task` that is executed after `delay` seconds. - - Parameters - ---------- - delay - Delay in seconds. - afunc - Coroutine function to schedule. - *args - Arguments to be passed to `afunc`. - **kwargs - Keyword arguments to be passed to `afunc` - - Returns - ------- - The None - - Raises - ------ - AsyncTaskGroupClosedError - If the task group is closed. - """ - self.call_soon(_delayed(afunc, delay), *args, **kwargs) - - def close(self) -> None: - """Closes the task group so that no new tasks can be scheduled. - - Existing tasks continue to run. 
- """ - self.closed = True - - async def stop(self) -> None: - """Close the group and stop all currently running tasks. - - Closes the task group and cancels all tasks. All tasks are cancelled - an additional time for each time this task is cancelled. - """ - self.close() - - current_task = asyncio.current_task(self._get_loop()) - err = None - while tasks_to_stop := (self._ongoing_tasks - {current_task}): - for task in tasks_to_stop: - task.cancel() - try: - await asyncio.wait(tasks_to_stop) - except asyncio.CancelledError as e: - err = e - - if err is not None: - raise err - - def __len__(self): - return len(self._ongoing_tasks) - - class Server: """Dask Distributed Server @@ -325,34 +147,12 @@ class Server: default_ip: ClassVar[str] = "" default_port: ClassVar[int] = 0 - id: str blocked_handlers: list[str] handlers: dict[str, Callable] stream_handlers: dict[str, Callable] listeners: list[Listener] - counters: defaultdict[str, Counter] deserialize: bool - local_directory: str - - monitor: SystemMonitor - io_loop: IOLoop - thread_id: int - - periodic_callbacks: dict[str, PeriodicCallback] - digests: defaultdict[Hashable, Digest] | None - digests_total: defaultdict[Hashable, float] - digests_total_since_heartbeat: defaultdict[Hashable, float] - digests_max: defaultdict[Hashable, float] - - _last_tick: float - _tick_counter: int - _last_tick_counter: int - _tick_interval: float - _tick_interval_observed: float - - _status: Status - _address: str | None _listen_address: str | None _host: str | None @@ -360,16 +160,7 @@ class Server: _comms: dict[Comm, str | None] - _ongoing_background_tasks: AsyncTaskGroup - _event_finished: asyncio.Event - - _original_local_dir: str - _updated_sys_path: bool - _workspace: WorkSpace - _workdir: None | WorkDir - - _startup_lock: asyncio.Lock - __startup_exc: Exception | None + _handle_comm_tasks: AsyncTaskGroup def __init__( self, @@ -382,150 +173,26 @@ def __init__( deserializers=None, connection_args=None, timeout=None, - io_loop=None, - local_directory=None, - needs_workdir=True, ): - if local_directory is None: - local_directory = ( - dask.config.get("temporary-directory") or tempfile.gettempdir() - ) - - if "dask-scratch-space" not in str(local_directory): - local_directory = os.path.join(local_directory, "dask-scratch-space") - - self._original_local_dir = local_directory - - with warn_on_duration( - "1s", - "Creating scratch directories is taking a surprisingly long time. ({duration:.2f}s) " - "This is often due to running workers on a network file system. " - "Consider specifying a local-directory to point workers to write " - "scratch data to a local disk.", - ): - self._workspace = WorkSpace(local_directory) - - if not needs_workdir: # eg. 
Nanny will not need a WorkDir - self._workdir = None - self.local_directory = self._workspace.base_dir - else: - name = type(self).__name__.lower() - self._workdir = self._workspace.new_work_dir(prefix=f"{name}-") - self.local_directory = self._workdir.dir_path - - self._updated_sys_path = False - if self.local_directory not in sys.path: - sys.path.insert(0, self.local_directory) - self._updated_sys_path = True - - if io_loop is not None: - warnings.warn( - "The io_loop kwarg to Server is ignored and will be deprecated", - DeprecationWarning, - stacklevel=2, - ) - - self._status = Status.init self.handlers = { - "identity": self.identity, "echo": self.echo, + "identity": self.identity, "connection_stream": self.handle_stream, - "dump_state": self._to_dict, } self.handlers.update(handlers) - if blocked_handlers is None: - blocked_handlers = dask.config.get( - "distributed.%s.blocked-handlers" % type(self).__name__.lower(), [] - ) - self.blocked_handlers = blocked_handlers + self.blocked_handlers = blocked_handlers or {} self.stream_handlers = {} self.stream_handlers.update(stream_handlers or {}) - self.id = type(self).__name__ + "-" + str(uuid.uuid4()) self._address = None self._listen_address = None self._port = None self._host = None self._comms = {} self.deserialize = deserialize - self.monitor = SystemMonitor() - self._ongoing_background_tasks = AsyncTaskGroup() - self._event_finished = asyncio.Event() + self._handle_comm_tasks = AsyncTaskGroup() self.listeners = [] - self.io_loop = self.loop = IOLoop.current() - - if not hasattr(self.io_loop, "profile"): - if dask.config.get("distributed.worker.profile.enabled"): - ref = weakref.ref(self.io_loop) - - def stop() -> bool: - loop = ref() - return loop is None or loop.asyncio_loop.is_closed() - - self.io_loop.profile = profile.watch( - omit=("profile.py", "selectors.py"), - interval=dask.config.get("distributed.worker.profile.interval"), - cycle=dask.config.get("distributed.worker.profile.cycle"), - stop=stop, - ) - else: - self.io_loop.profile = deque() - - self.periodic_callbacks = {} - - # Statistics counters for various events - try: - from distributed.counter import Digest - - self.digests = defaultdict(Digest) - except ImportError: - self.digests = None - - # Also log cumulative totals (reset at server restart) - # and local maximums (reset by prometheus poll) - # Don't cast int metrics to float - self.digests_total = defaultdict(int) - self.digests_total_since_heartbeat = defaultdict(int) - self.digests_max = defaultdict(int) - - self.counters = defaultdict(Counter) - pc = PeriodicCallback(self._shift_counters, 5000) - self.periodic_callbacks["shift_counters"] = pc - - pc = PeriodicCallback( - self.monitor.update, - parse_timedelta( - dask.config.get("distributed.admin.system-monitor.interval") - ) - * 1000, - ) - self.periodic_callbacks["monitor"] = pc - - self._last_tick = time() - self._tick_counter = 0 - self._last_tick_counter = 0 - self._last_tick_cycle = time() - self._tick_interval = parse_timedelta( - dask.config.get("distributed.admin.tick.interval"), default="ms" - ) - self._tick_interval_observed = self._tick_interval - self.periodic_callbacks["tick"] = PeriodicCallback( - self._measure_tick, self._tick_interval * 1000 - ) - self.periodic_callbacks["ticks"] = PeriodicCallback( - self._cycle_ticks, - parse_timedelta(dask.config.get("distributed.admin.tick.cycle")) * 1000, - ) - - self.thread_id = 0 - - def set_thread_ident(): - self.thread_id = threading.get_ident() - - self.io_loop.add_callback(set_thread_ident) - 
self._startup_lock = asyncio.Lock() - self.__startup_exc = None self.rpc = ConnectionPool( limit=connection_limit, @@ -538,54 +205,7 @@ def set_thread_ident(): ) self.__stopped = False - - async def upload_file( - self, filename: str, data: str | bytes, load: bool = True - ) -> dict[str, Any]: - out_filename = os.path.join(self.local_directory, filename) - - def func(data): - if isinstance(data, str): - data = data.encode() - with open(out_filename, "wb") as f: - f.write(data) - f.flush() - os.fsync(f.fileno()) - return data - - if len(data) < 10000: - data = func(data) - else: - data = await offload(func, data) - - if load: - try: - import_file(out_filename) - except Exception as e: - logger.exception(e) - raise e - - return {"status": "OK", "nbytes": len(data)} - - def _shift_counters(self): - for counter in self.counters.values(): - counter.shift() - if self.digests is not None: - for digest in self.digests.values(): - digest.shift() - - @property - def status(self) -> Status: - try: - return self._status - except AttributeError: - return Status.undefined - - @status.setter - def status(self, value: Status) -> None: - if not isinstance(value, Status): - raise TypeError(f"Expected Status; got {value!r}") - self._status = value + super().__init__() @property def incoming_comms_open(self) -> int: @@ -628,55 +248,13 @@ def get_connection_counters(self) -> dict[str, int]: ] } - async def finished(self): - """Wait until the server has finished""" - await self._event_finished.wait() - def __await__(self): - return self.start().__await__() - - async def start_unsafe(self): - """Attempt to start the server. This is not idempotent and not protected against concurrent startup attempts. + return self.start_pool().__await__() - This is intended to be overwritten or called by subclasses. For a safe - startup, please use ``Server.start`` instead. - - If ``death_timeout`` is configured, we will require this coroutine to - finish before this timeout is reached. If the timeout is reached we will - close the instance and raise an ``asyncio.TimeoutError`` - """ + async def start_pool(self): await self.rpc.start() return self - @final - async def start(self): - async with self._startup_lock: - if self.status == Status.failed: - assert self.__startup_exc is not None - raise self.__startup_exc - elif self.status != Status.init: - return self - timeout = getattr(self, "death_timeout", None) - - async def _close_on_failure(exc: Exception) -> None: - await self.close(reason=f"failure-to-start-{str(type(exc))}") - self.status = Status.failed - self.__startup_exc = exc - - try: - await wait_for(self.start_unsafe(), timeout=timeout) - except asyncio.TimeoutError as exc: - await _close_on_failure(exc) - raise asyncio.TimeoutError( - f"{type(self).__name__} start timed out after {timeout}s." - ) from exc - except Exception as exc: - await _close_on_failure(exc) - raise RuntimeError(f"{type(self).__name__} failed to start.") from exc - if self.status == Status.init: - self.status = Status.running - return self - async def __aenter__(self): await self return self @@ -684,50 +262,21 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): await self.close() - def start_periodic_callbacks(self): - """Start Periodic Callbacks consistently - - This starts all PeriodicCallbacks stored in self.periodic_callbacks if - they are not yet running. It does this safely by checking that it is using the - correct event loop. 
- """ - if self.io_loop.asyncio_loop is not asyncio.get_running_loop(): - raise RuntimeError(f"{self!r} is bound to a different event loop") - - self._last_tick = time() - for pc in self.periodic_callbacks.values(): - if not pc.is_running(): - pc.start() - - def _stop_listeners(self) -> asyncio.Future: - listeners_to_stop: set[Awaitable] = set() - + def _stop_listeners(self) -> None: for listener in self.listeners: - future = listener.stop() - if inspect.isawaitable(future): - warnings.warn( - f"{type(listener)} is using an asynchronous `stop` method. " - "Support for asynchronous `Listener.stop` has been deprecated and " - "will be removed in a future version", - DeprecationWarning, - ) - listeners_to_stop.add(future) - elif hasattr(listener, "abort_handshaking_comms"): + listener.stop() + if hasattr(listener, "abort_handshaking_comms"): listener.abort_handshaking_comms() - return asyncio.gather(*listeners_to_stop) + @property + def stopped(self) -> bool: + return self.__stopped def stop(self) -> None: if self.__stopped: return self.__stopped = True - self.monitor.close() - if not (stop_listeners := self._stop_listeners()).done(): - self._ongoing_background_tasks.call_soon( - asyncio.wait_for(stop_listeners, timeout=None) # type: ignore[arg-type] - ) - if self._workdir is not None: - self._workdir.release() + self._stop_listeners() @property def listener(self) -> Listener | None: @@ -736,33 +285,6 @@ def listener(self) -> Listener | None: else: return None - def _measure_tick(self): - now = time() - tick_duration = now - self._last_tick - self._last_tick = now - self._tick_counter += 1 - # This metric is exposed in Prometheus and is reset there during - # collection - if tick_duration > tick_maximum_delay: - logger.info( - "Event loop was unresponsive in %s for %.2fs. " - "This is often caused by long-running GIL-holding " - "functions or moving large chunks of data. " - "This can cause timeouts and instability.", - type(self).__name__, - tick_duration, - ) - self.digest_metric("tick-duration", tick_duration) - - def _cycle_ticks(self): - if not self._tick_counter: - return - now = time() - last_tick_cycle, self._last_tick_cycle = self._last_tick_cycle, now - count = self._tick_counter - self._last_tick_counter - self._last_tick_counter = self._tick_counter - self._tick_interval_observed = (now - last_tick_cycle) / (count or 1) - @property def address(self) -> str: """ @@ -824,27 +346,7 @@ def port(self): return self._port def identity(self) -> dict[str, str]: - return {"type": type(self).__name__, "id": self.id} - - def _to_dict(self, *, exclude: Container[str] = ()) -> dict[str, Any]: - """Dictionary representation for debugging purposes. - Not type stable and not intended for roundtrips. 
- - See also - -------- - Server.identity - Client.dump_cluster_state - distributed.utils.recursive_to_dict - """ - info: dict[str, Any] = self.identity() - extra = { - "address": self.address, - "status": self.status.name, - "thread_id": self.thread_id, - } - info.update(extra) - info = {k: v for k, v in info.items() if k not in exclude} - return recursive_to_dict(info, exclude=exclude) + return {"type": type(self).__name__, "id": str(id(self))} def echo(self, data=None): return data @@ -871,7 +373,7 @@ async def listen(self, port_or_addr=None, allow_offload=True, **kwargs): def handle_comm(self, comm: Comm) -> NoOpAwaitable: """Start a background task that dispatches new communications to coroutine-handlers""" try: - self._ongoing_background_tasks.call_soon(self._handle_comm, comm) + self._handle_comm_tasks.call_soon(self._handle_comm, comm) except AsyncTaskGroupClosedError: comm.abort() return NoOpAwaitable() @@ -930,8 +432,6 @@ async def _handle_comm(self, comm: Comm) -> None: raise ValueError( "Received unexpected message without 'op' key: " + str(msg) ) from e - if self.counters is not None: - self.counters["op"].add(op) self._comms[comm] = op serializers = msg.pop("serializers", None) close_desired = msg.pop("close", False) @@ -976,7 +476,7 @@ async def _handle_comm(self, comm: Comm) -> None: f"Comm handler returned unknown awaitable. Expected coroutine, instead got {type(result)}" ) except CommClosedError: - if self.status == Status.running: + if not self.__stopped: logger.info("Lost connection to %r", address, exc_info=True) break except Exception as e: @@ -1067,61 +567,15 @@ async def handle_stream( await comm.close() assert comm.closed() - async def close(self, timeout: float | None = None, reason: str = "") -> None: - try: - for pc in self.periodic_callbacks.values(): - pc.stop() - - self.__stopped = True - self.monitor.close() - await self._stop_listeners() - - # TODO: Deal with exceptions - await self._ongoing_background_tasks.stop() - - await self.rpc.close() - await asyncio.gather(*[comm.close() for comm in list(self._comms)]) - - # Remove scratch directory from global sys.path - if self._updated_sys_path and sys.path[0] == self.local_directory: - sys.path.remove(self.local_directory) - finally: - self._event_finished.set() - - def digest_metric(self, name: Hashable, value: float) -> None: - # Granular data (requires crick) - if self.digests is not None: - self.digests[name].add(value) - # Cumulative data (reset by server restart) - self.digests_total[name] += value - # Cumulative data sent to scheduler and reset on heartbeat - self.digests_total_since_heartbeat[name] += value - # Local maximums (reset by Prometheus poll) - self.digests_max[name] = max(self.digests_max[name], value) - - -def context_meter_to_server_digest(digest_tag: str) -> Callable: - """Decorator for an async method of a Server subclass that calls - ``distributed.metrics.context_meter.meter`` and/or ``digest_metric``. - It routes the calls from ``context_meter.digest_metric(label, value, unit)`` to - ``Server.digest_metric((digest_tag, label, unit), value)``. 
- """ - - def decorator(func: Callable) -> Callable: - @wraps(func) - async def wrapper(self: Server, *args: Any, **kwargs: Any) -> Any: - def metrics_callback(label: Hashable, value: float, unit: str) -> None: - if not isinstance(label, tuple): - label = (label,) - name = (digest_tag, *label, unit) - self.digest_metric(name, value) - - with context_meter.add_callback(metrics_callback, allow_offload=True): - return await func(self, *args, **kwargs) + async def close(self) -> None: + self.__stopped = True + self._stop_listeners() - return wrapper + # TODO: Deal with exceptions + await self._handle_comm_tasks.stop() - return decorator + await self.rpc.close() + await asyncio.gather(*[comm.close() for comm in list(self._comms)]) def pingpong(comm): diff --git a/distributed/event.py b/distributed/event.py index 145abbf3857..9cad4f0d59a 100644 --- a/distributed/event.py +++ b/distributed/event.py @@ -50,7 +50,7 @@ def __init__(self, scheduler): # we can remove the event self._waiter_count = defaultdict(int) - self.scheduler.handlers.update( + self.scheduler.server.handlers.update( { "event_wait": self.event_wait, "event_set": self.event_set, diff --git a/distributed/http/scheduler/json.py b/distributed/http/scheduler/json.py index 932734f56a7..086cfbbbaae 100644 --- a/distributed/http/scheduler/json.py +++ b/distributed/http/scheduler/json.py @@ -55,7 +55,7 @@ def get(self): class IdentityJSON(RequestHandler): def get(self): - self.write(self.server.identity()) + self.write(self.identity()) class IndexJSON(RequestHandler): diff --git a/distributed/lock.py b/distributed/lock.py index 99ec34cd6f7..42d326dd413 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -27,7 +27,7 @@ def __init__(self, scheduler): self.events = defaultdict(deque) self.ids = dict() - self.scheduler.handlers.update( + self.scheduler.server.handlers.update( {"lock_acquire": self.acquire, "lock_release": self.release} ) diff --git a/distributed/nanny.py b/distributed/nanny.py index 52e4ad5b360..ec1c547f5f3 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -268,7 +268,7 @@ def __init__( # type: ignore[no-untyped-def] self.silence_logs = silence_logs self.plugins: dict[str, NannyPlugin] = {} - self.scheduler = self.rpc(self.scheduler_addr) + self.scheduler = self.server.rpc(self.scheduler_addr) self.memory_manager = NannyMemoryManager(self, memory_limit=memory_limit) if ( @@ -334,7 +334,7 @@ async def start_unsafe(self): security=self.security, ) try: - await self.listen( + await self.server.listen( start_address, **self.security.get_listen_args("worker") ) except OSError as e: @@ -351,21 +351,21 @@ async def start_unsafe(self): f"with port {self._start_port}" ) - self.ip = get_address_host(self.address) + self.ip = get_address_host(self.server.address) await self.preloads.start() saddr = self.scheduler.addr - comm = await self.rpc.connect(saddr) + comm = await self.server.rpc.connect(saddr) comm.name = "Nanny->Scheduler (registration)" try: - await comm.write({"op": "register_nanny", "address": self.address}) + await comm.write({"op": "register_nanny", "address": self.server.address}) msg = await comm.read() try: for name, plugin in msg["nanny-plugins"].items(): await self.plugin_add(plugin=plugin, name=name) - logger.info(" Start Nanny at: %r", self.address) + logger.info(" Start Nanny at: %r", self.server.address) response = await self.instantiate() if response != Status.running: @@ -410,7 +410,7 @@ async def instantiate(self) -> Status: nthreads=self.nthreads, 
local_directory=self._original_local_dir, services=self.services, - nanny=self.address, + nanny=self.server.address, name=self.name, memory_limit=self.memory_manager.memory_limit, resources=self.resources, @@ -589,7 +589,9 @@ def close_gracefully(self, reason: str = "nanny-close-gracefully") -> None: """ self.status = Status.closing_gracefully logger.info( - "Closing Nanny gracefully at %r. Reason: %s", self.address_safe, reason + "Closing Nanny gracefully at %r. Reason: %s", + self.server.address_safe, + reason, ) async def close( # type:ignore[override] @@ -606,7 +608,7 @@ async def close( # type:ignore[override] return "OK" self.status = Status.closing - logger.info("Closing Nanny at %r. Reason: %s", self.address_safe, reason) + logger.info("Closing Nanny at %r. Reason: %s", self.server.address_safe, reason) await self.preloads.teardown() @@ -618,13 +620,10 @@ async def close( # type:ignore[override] await asyncio.gather(*(td for td in teardowns if isawaitable(td))) - self.stop() if self.process is not None: await self.kill(timeout=timeout, reason=reason) self.process = None - await self.rpc.close() - self.status = Status.closed await super().close() self.__exit_stack.__exit__(None, None, None) return "OK" diff --git a/distributed/node.py b/distributed/node.py index d41a3409850..5a3d052c901 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -1,32 +1,281 @@ from __future__ import annotations +import asyncio import logging +import os import ssl +import sys +import tempfile +import threading import warnings import weakref +from collections import defaultdict, deque +from collections.abc import Container, Coroutine, Hashable from contextlib import suppress +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, TypeVar, final import tlz from tornado.httpserver import HTTPServer +from tornado.ioloop import IOLoop import dask +from dask.utils import parse_timedelta +from distributed import profile +from distributed._async_taskgroup import AsyncTaskGroup from distributed.comm import get_address_host, get_tcp_server_addresses -from distributed.core import Server +from distributed.compatibility import PeriodicCallback +from distributed.core import Server, Status +from distributed.counter import Counter +from distributed.diskutils import WorkDir, WorkSpace from distributed.http.routing import RoutingApplication -from distributed.utils import DequeHandler, clean_dashboard_address +from distributed.metrics import context_meter, time +from distributed.system_monitor import SystemMonitor +from distributed.utils import ( + DequeHandler, + clean_dashboard_address, + import_file, + offload, + recursive_to_dict, + wait_for, + warn_on_duration, +) from distributed.versions import get_versions +if TYPE_CHECKING: + from typing_extensions import ParamSpec -class ServerNode(Server): - """ - Base class for server nodes in a distributed cluster. 
- """ + from distributed.counter import Digest + + P = ParamSpec("P") + R = TypeVar("R") + T = TypeVar("T") + Coro = Coroutine[Any, Any, T] + + +logger = logging.getLogger(__name__) +tick_maximum_delay = parse_timedelta( + dask.config.get("distributed.admin.tick.limit"), default="ms" +) + + +class Node: + _startup_lock: asyncio.Lock + __startup_exc: Exception | None + local_directory: str + monitor: SystemMonitor + + periodic_callbacks: dict[str, PeriodicCallback] + digests: defaultdict[Hashable, Digest] | None + digests_total: defaultdict[Hashable, float] + digests_total_since_heartbeat: defaultdict[Hashable, float] + digests_max: defaultdict[Hashable, float] + + _last_tick: float + _tick_counter: int + _last_tick_counter: int + _tick_interval: float + _tick_interval_observed: float + + _original_local_dir: str + _updated_sys_path: bool + _workspace: WorkSpace + _workdir: None | WorkDir + _ongoing_background_tasks: AsyncTaskGroup + + _event_finished: asyncio.Event + + def __init__( + self, + local_directory=None, + needs_workdir=True, + ): + if local_directory is None: + local_directory = ( + dask.config.get("temporary-directory") or tempfile.gettempdir() + ) + + if "dask-scratch-space" not in str(local_directory): + local_directory = os.path.join(local_directory, "dask-scratch-space") + self.monitor = SystemMonitor() + + self._original_local_dir = local_directory + self._ongoing_background_tasks = AsyncTaskGroup() + with warn_on_duration( + "1s", + "Creating scratch directories is taking a surprisingly long time. ({duration:.2f}s) " + "This is often due to running workers on a network file system. " + "Consider specifying a local-directory to point workers to write " + "scratch data to a local disk.", + ): + self._workspace = WorkSpace(local_directory) + + if not needs_workdir: # eg. 
Nanny will not need a WorkDir + self._workdir = None + self.local_directory = self._workspace.base_dir + else: + name = type(self).__name__.lower() + self._workdir = self._workspace.new_work_dir(prefix=f"{name}-") + self.local_directory = self._workdir.dir_path + + self._updated_sys_path = False + if self.local_directory not in sys.path: + sys.path.insert(0, self.local_directory) + self._updated_sys_path = True + + self.io_loop = self.loop = IOLoop.current() + + if not hasattr(self.io_loop, "profile"): + if dask.config.get("distributed.worker.profile.enabled"): + ref = weakref.ref(self.io_loop) + + def stop() -> bool: + loop = ref() + return loop is None or loop.asyncio_loop.is_closed() + + self.io_loop.profile = profile.watch( + omit=("profile.py", "selectors.py"), + interval=dask.config.get("distributed.worker.profile.interval"), + cycle=dask.config.get("distributed.worker.profile.cycle"), + stop=stop, + ) + else: + self.io_loop.profile = deque() + + self.periodic_callbacks = {} + + # Statistics counters for various events + try: + from distributed.counter import Digest + + self.digests = defaultdict(Digest) + except ImportError: + self.digests = None + + # Also log cumulative totals (reset at server restart) + # and local maximums (reset by prometheus poll) + # Don't cast int metrics to float + self.digests_total = defaultdict(int) + self.digests_total_since_heartbeat = defaultdict(int) + self.digests_max = defaultdict(int) + + self.counters = defaultdict(Counter) + pc = PeriodicCallback(self._shift_counters, 5000) + self.periodic_callbacks["shift_counters"] = pc + + pc = PeriodicCallback( + self.monitor.update, + parse_timedelta( + dask.config.get("distributed.admin.system-monitor.interval") + ) + * 1000, + ) + self.periodic_callbacks["monitor"] = pc + + self.__startup_exc = None + self._startup_lock = asyncio.Lock() + + self._last_tick = time() + self._tick_counter = 0 + self._last_tick_counter = 0 + self._last_tick_cycle = time() + self._tick_interval = parse_timedelta( + dask.config.get("distributed.admin.tick.interval"), default="ms" + ) + self._tick_interval_observed = self._tick_interval + self.periodic_callbacks["tick"] = PeriodicCallback( + self._measure_tick, self._tick_interval * 1000 + ) + self.periodic_callbacks["ticks"] = PeriodicCallback( + self._cycle_ticks, + parse_timedelta(dask.config.get("distributed.admin.tick.cycle")) * 1000, + ) + + self.thread_id = 0 + + def set_thread_ident(): + self.thread_id = threading.get_ident() + + self.io_loop.add_callback(set_thread_ident) + self.status = Status.init + + def __await__(self): + async def _(): + await self.start() + return self + + return _().__await__() + + async def __aenter__(self): + await self + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + async def close(self, reason: str | None = None) -> None: + for pc in self.periodic_callbacks.values(): + pc.stop() - # TODO factor out security, listening, services, etc. here + self.monitor.close() + await self._ongoing_background_tasks.stop() + if self._workdir is not None: + self._workdir.release() - # XXX avoid inheriting from Server? there is some large potential for confusion - # between base and derived attribute namespaces... 
+ # Remove scratch directory from global sys.path + if self._updated_sys_path and sys.path[0] == self.local_directory: + sys.path.remove(self.local_directory) + + async def upload_file( + self, filename: str, data: str | bytes, load: bool = True + ) -> dict[str, Any]: + out_filename = os.path.join(self.local_directory, filename) + + def func(data): + if isinstance(data, str): + data = data.encode() + with open(out_filename, "wb") as f: + f.write(data) + f.flush() + os.fsync(f.fileno()) + return data + + if len(data) < 10000: + data = func(data) + else: + data = await offload(func, data) + + if load: + try: + import_file(out_filename) + except Exception as e: + logger.exception(e) + raise e + + return {"status": "OK", "nbytes": len(data)} + + def _shift_counters(self): + for counter in self.counters.values(): + counter.shift() + if self.digests is not None: + for digest in self.digests.values(): + digest.shift() + + def start_periodic_callbacks(self): + """Start Periodic Callbacks consistently + + This starts all PeriodicCallbacks stored in self.periodic_callbacks if + they are not yet running. It does this safely by checking that it is using the + correct event loop. + """ + if self.io_loop.asyncio_loop is not asyncio.get_running_loop(): + raise RuntimeError(f"{self!r} is bound to a different event loop") + + self._last_tick = time() + for pc in self.periodic_callbacks.values(): + if not pc.is_running(): + pc.start() def versions(self, packages=None): return get_versions(packages=packages) @@ -184,3 +433,224 @@ def start_http_server( "Perhaps you already have a cluster running?\n" f"Hosting the HTTP server on port {actual} instead" ) + + def _measure_tick(self): + now = time() + tick_duration = now - self._last_tick + self._last_tick = now + self._tick_counter += 1 + # This metric is exposed in Prometheus and is reset there during + # collection + if tick_duration > tick_maximum_delay: + logger.info( + "Event loop was unresponsive in %s for %.2fs. " + "This is often caused by long-running GIL-holding " + "functions or moving large chunks of data. " + "This can cause timeouts and instability.", + type(self).__name__, + tick_duration, + ) + self.digest_metric("tick-duration", tick_duration) + + def _cycle_ticks(self): + if not self._tick_counter: + return + now = time() + last_tick_cycle, self._last_tick_cycle = self._last_tick_cycle, now + count = self._tick_counter - self._last_tick_counter + self._last_tick_counter = self._tick_counter + self._tick_interval_observed = (now - last_tick_cycle) / (count or 1) + + def digest_metric(self, name: Hashable, value: float) -> None: + # Granular data (requires crick) + if self.digests is not None: + self.digests[name].add(value) + # Cumulative data (reset by server restart) + self.digests_total[name] += value + # Cumulative data sent to scheduler and reset on heartbeat + self.digests_total_since_heartbeat[name] += value + # Local maximums (reset by Prometheus poll) + self.digests_max[name] = max(self.digests_max[name], value) + + @property + def status(self) -> Status: + try: + return self._status + except AttributeError: + return Status.undefined + + @status.setter + def status(self, value: Status) -> None: + if not isinstance(value, Status): + raise TypeError(f"Expected Status; got {value!r}") + self._status = value + + async def start_unsafe(self): + """Attempt to start the server. This is not idempotent and not protected against concurrent startup attempts. + + This is intended to be overwritten or called by subclasses. 
For a safe + startup, please use ``Node.start`` instead. + + If ``death_timeout`` is configured, we will require this coroutine to + finish before this timeout is reached. If the timeout is reached we will + close the instance and raise an ``asyncio.TimeoutError`` + """ + return self + + @final + async def start(self): + async with self._startup_lock: + if self.status == Status.failed: + assert self.__startup_exc is not None + raise self.__startup_exc + elif self.status != Status.init: + return self + timeout = getattr(self, "death_timeout", None) + + async def _close_on_failure(exc: Exception) -> None: + await self.close(reason=f"failure-to-start-{str(type(exc))}") + self.status = Status.failed + self.__startup_exc = exc + + try: + await wait_for(self.start_unsafe(), timeout=timeout) + except asyncio.TimeoutError as exc: + await _close_on_failure(exc) + raise asyncio.TimeoutError( + f"{type(self).__name__} start timed out after {timeout}s." + ) from exc + except Exception as exc: + await _close_on_failure(exc) + raise RuntimeError(f"{type(self).__name__} failed to start.") from exc + if self.status == Status.init: + self.status = Status.running + return self + + +class ServerNode(Node): + def __init__( + self, + handlers, + blocked_handlers=None, + stream_handlers=None, + connection_limit=512, + deserialize=True, + serializers=None, + deserializers=None, + connection_args=None, + timeout=None, + local_directory=None, + needs_workdir=True, + ): + self._event_finished = asyncio.Event() + + _handlers = { + "dump_state": self._to_dict, + "identity": self.identity, + } + if handlers: + _handlers.update(handlers) + import uuid + + self.id = type(self).__name__ + "-" + str(uuid.uuid4()) + + if blocked_handlers is None: + blocked_handlers = dask.config.get( + "distributed.%s.blocked-handlers" % type(self).__name__.lower(), [] + ) + self.server = Server( + handlers=_handlers, + blocked_handlers=blocked_handlers, + stream_handlers=stream_handlers, + connection_limit=connection_limit, + deserialize=deserialize, + serializers=serializers, + deserializers=deserializers, + connection_args=connection_args, + timeout=timeout, + ) + super().__init__( + local_directory=local_directory, + needs_workdir=needs_workdir, + ) + + def identity(self) -> dict[str, str]: + return {"type": type(self).__name__, "id": self.id} + + @property + def port(self): + return self.server.port + + @property + def listen_address(self): + return self.server.address + + @property + def address(self): + return self.server.address + + @property + def address_safe(self): + return self.server.address_safe + + async def start_unsafe(self): + await self.server + await super().start_unsafe() + return self + + async def finished(self) -> None: + """Wait until the server has finished""" + await self._event_finished.wait() + + async def close(self, reason: str | None = None) -> None: + try: + # Close network connections and background tasks + await self.server.close() + await Node.close(self, reason=reason) + self.status = Status.closed + finally: + self._event_finished.set() + + def _to_dict(self, *, exclude: Container[str] = ()) -> dict[str, Any]: + """Dictionary representation for debugging purposes. + Not type stable and not intended for roundtrips. 
+ + See also + -------- + Server.identity + Client.dump_cluster_state + distributed.utils.recursive_to_dict + """ + info: dict[str, Any] = self.identity() + extra = { + "address": self.server.address, + "status": self.status.name, + "thread_id": self.thread_id, + } + info.update(extra) + info = {k: v for k, v in info.items() if k not in exclude} + return recursive_to_dict(info, exclude=exclude) + + +def context_meter_to_node_digest(digest_tag: str) -> Callable: + """Decorator for an async method of a Node subclass that calls + ``distributed.metrics.context_meter.meter`` and/or ``digest_metric``. + It routes the calls from ``context_meter.digest_metric(label, value, unit)`` to + ``Node.digest_metric((digest_tag, label, unit), value)``. + """ + + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(self: Node, *args: Any, **kwargs: Any) -> Any: + def metrics_callback(label: Hashable, value: float, unit: str) -> None: + if not isinstance(label, tuple): + label = (label,) + name = (digest_tag, *label, unit) + self.digest_metric(name, value) + + with context_meter.add_callback(metrics_callback, allow_offload=True): + return await func(self, *args, **kwargs) + + return wrapper + + return decorator diff --git a/distributed/preloading.py b/distributed/preloading.py index 96c3889327e..0c810910206 100644 --- a/distributed/preloading.py +++ b/distributed/preloading.py @@ -15,7 +15,7 @@ from dask.utils import tmpfile -from distributed.core import Server +from distributed.node import ServerNode from distributed.utils import import_file if TYPE_CHECKING: @@ -167,7 +167,7 @@ class Preload: Path of a directory where files should be copied """ - dask_object: Server | Client + dask_object: ServerNode | Client name: str argv: list[str] file_dir: str | None @@ -175,7 +175,7 @@ class Preload: def __init__( self, - dask_object: Server | Client, + dask_object: ServerNode | Client, name: str, argv: Iterable[str], file_dir: str | None, @@ -250,7 +250,7 @@ def __len__(self) -> int: def process_preloads( - dask_server: Server | Client, + dask_server: ServerNode | Client, preload: str | list[str], preload_argv: list[str] | list[list[str]], *, diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 3a678f0b66b..42b3d917cf9 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -27,7 +27,7 @@ def __init__(self, scheduler): self.scheduler.handlers.update({"pubsub_add_publisher": self.add_publisher}) - self.scheduler.stream_handlers.update( + self.scheduler.server.stream_handlers.update( { "pubsub-add-subscriber": self.add_subscriber, "pubsub-remove-publisher": self.remove_publisher, @@ -122,7 +122,7 @@ class PubSubWorkerExtension: def __init__(self, worker): self.worker = worker - self.worker.stream_handlers.update( + self.worker.server.stream_handlers.update( { "pubsub-add-subscriber": self.add_subscriber, "pubsub-remove-subscriber": self.remove_subscriber, diff --git a/distributed/queues.py b/distributed/queues.py index c48616459da..2195c31b69e 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -32,7 +32,7 @@ def __init__(self, scheduler): self.client_refcount = dict() self.future_refcount = defaultdict(int) - self.scheduler.handlers.update( + self.scheduler.server.handlers.update( { "queue_create": self.create, "queue_put": self.put, @@ -41,7 +41,7 @@ def __init__(self, scheduler): } ) - self.scheduler.stream_handlers.update( + self.scheduler.server.stream_handlers.update( {"queue-future-release": self.future_release, "queue_release": self.release} ) diff 
--git a/distributed/scheduler.py b/distributed/scheduler.py index 0273d333da3..b5977084c6f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -50,6 +50,7 @@ valmap, ) from tornado.ioloop import IOLoop +from typing_extensions import Self import dask import dask.utils @@ -3945,7 +3946,7 @@ async def post(self): setproctitle("dask scheduler [not started]") Scheduler._instances.add(self) - self.rpc.allow_offload = False + self.server.rpc.allow_offload = False ################## # Administration # @@ -3953,7 +3954,7 @@ async def post(self): def __repr__(self) -> str: return ( - f"" @@ -3961,7 +3962,7 @@ def __repr__(self) -> str: def _repr_html_(self) -> str: return get_template("scheduler.html.j2").render( - address=self.address, + address=self.server.address, workers=self.workers, threads=self.total_nthreads, tasks=self.tasks, @@ -3972,7 +3973,7 @@ def identity(self) -> dict[str, Any]: d = { "type": type(self).__name__, "id": str(self.id), - "address": self.address, + "address": self.server.address, "services": {key: v.port for (key, v) in self.services.items()}, "started": self.time_started, "workers": { @@ -4095,25 +4096,25 @@ async def start_unsafe(self) -> Self: self._clear_task_state() for addr in self._start_address: - await self.listen( + await self.server.listen( addr, allow_offload=False, handshake_overrides={"pickle-protocol": 4, "compression": None}, **self.security.get_listen_args("scheduler"), ) - self.ip = get_address_host(self.listen_address) + self.ip = get_address_host(self.server.listen_address) listen_ip = self.ip if listen_ip == "0.0.0.0": listen_ip = "" - if self.address.startswith("inproc://"): + if self.server.address.startswith("inproc://"): listen_ip = "localhost" # Services listen on all addresses self.start_services(listen_ip) - for listener in self.listeners: + for listener in self.server.listeners: logger.info(" Scheduler at: %25s", listener.contact_address) for name, server in self.services.items(): if name == "dashboard": @@ -4147,9 +4148,11 @@ def del_scheduler_file() -> None: if self.jupyter: # Allow insecure communications from local users - if self.address.startswith("tls://"): - await self.listen("tcp://localhost:0") - os.environ["DASK_SCHEDULER_ADDRESS"] = self.listeners[-1].contact_address + if self.server.address.startswith("tls://"): + await self.server.listen("tcp://localhost:0") + os.environ["DASK_SCHEDULER_ADDRESS"] = self.server.listeners[ + -1 + ].contact_address await asyncio.gather( *[plugin.start(self) for plugin in list(self.plugins.values())] @@ -4157,7 +4160,7 @@ def del_scheduler_file() -> None: self.start_periodic_callbacks() - setproctitle(f"dask scheduler [{self.address}]") + setproctitle(f"dask scheduler [{self.server.address}]") return self async def close(self, fast=None, close_workers=None, reason=""): @@ -4228,10 +4231,6 @@ async def log_errors(func): for comm in self.client_comms.values(): comm.abort() - await self.rpc.close() - - self.status = Status.closed - self.stop() await super().close() setproctitle("dask scheduler [closed]") @@ -5287,7 +5286,7 @@ async def remove_worker( if not dh_addresses: del self.host_info[host] - self.rpc.remove(address) + self.server.rpc.remove(address) del self.stream_comms[address] del self.aliases[ws.name] self.idle.pop(ws.address, None) @@ -5776,7 +5775,7 @@ async def add_client( bcomm.send(msg) try: - await self.handle_stream(comm=comm, extra={"client": client}) + await self.server.handle_stream(comm=comm, extra={"client": client}) finally: self.remove_client(client=client, 
stimulus_id=f"remove-client-{time()}") logger.debug("Finished handling client %s", client) @@ -5997,7 +5996,7 @@ async def handle_worker(self, comm: Comm, worker: str) -> None: worker_comm.start(comm) logger.info("Starting worker compute stream, %s", worker) try: - await self.handle_stream(comm=comm, extra={"worker": worker}) + await self.server.handle_stream(comm=comm, extra={"worker": worker}) finally: if worker in self.stream_comms: worker_comm.abort() @@ -6200,8 +6199,9 @@ async def scatter( assert isinstance(data, dict) - workers = list(ws.address for ws in wss) - keys, who_has, nbytes = await scatter_to_workers(workers, data, rpc=self.rpc) + keys, who_has, nbytes = await scatter_to_workers( + nthreads, data, rpc=self.server.rpc + ) self.update_data(who_has=who_has, nbytes=nbytes, client=client) @@ -6238,7 +6238,7 @@ async def gather( new_failed_keys, new_missing_workers, ) = await gather_from_workers( - who_has, rpc=self.rpc, serializers=serializers + who_has, rpc=self.server.rpc, serializers=serializers ) data.update(new_data) failed_keys += new_failed_keys @@ -6524,14 +6524,14 @@ async def broadcast( async def send_message(addr): try: - comm = await self.rpc.connect(addr) + comm = await self.server.rpc.connect(addr) comm.name = "Scheduler Broadcast" try: resp = await send_recv( comm, close=True, serializers=serializers, **msg ) finally: - self.rpc.reuse(addr, comm) + self.server.rpc.reuse(addr, comm) return resp except Exception as e: logger.error(f"broadcast to {addr} failed: {e.__class__.__name__}: {e}") @@ -6581,7 +6581,7 @@ async def gather_on_worker( """ try: result = await retry_operation( - self.rpc(addr=worker_address).gather, who_has=who_has + self.server.rpc(addr=worker_address).gather, who_has=who_has ) except OSError as e: # This can happen e.g. if the worker is going through controlled shutdown; @@ -6632,7 +6632,7 @@ async def delete_worker_data( """ try: await retry_operation( - self.rpc(addr=worker_address).free_keys, + self.server.rpc(addr=worker_address).free_keys, keys=list(keys), stimulus_id=f"delete-data-{time()}", ) @@ -7748,7 +7748,7 @@ async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, A return {} results = await asyncio.gather( - *(self.rpc(w).call_stack(keys=v) for w, v in workers.items()) + *(self.server.rpc(w).call_stack(keys=v) for w, v in workers.items()) ) response = {w: r for w, r in zip(workers, results) if r} return response @@ -7791,7 +7791,8 @@ async def benchmark_hardware(self) -> dict[str, dict[str, float]]: # Randomize the connections to even out the mean measures. 
random.shuffle(workers) futures = [ - self.rpc(a).benchmark_network(address=b) for a, b in partition(2, workers) + self.server.rpc(a).benchmark_network(address=b) + for a, b in partition(2, workers) ] responses = await asyncio.gather(*futures) @@ -8175,7 +8176,9 @@ async def get_profile( results = await asyncio.gather( *( - self.rpc(w).profile(start=start, stop=stop, key=key, server=server) + self.server.rpc(w).profile( + start=start, stop=stop, key=key, server=server + ) for w in workers ), return_exceptions=True, @@ -8206,7 +8209,10 @@ async def get_profile_metadata( else: workers = set(self.workers) & set(workers) results: Sequence[Any] = await asyncio.gather( - *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers), + *( + self.server.rpc(w).profile_metadata(start=start, stop=stop) + for w in workers + ), return_exceptions=True, ) @@ -8342,7 +8348,7 @@ def profile_to_figure(state): time=format_time(stop - start), ntasks=total_tasks, tasks_timings=tasks_timings, - address=self.address, + address=self.server.address, nworkers=len(self.workers), threads=sum(ws.nthreads for ws in self.workers.values()), memory=format_bytes(sum(ws.memory_limit for ws in self.workers.values())), @@ -8448,7 +8454,9 @@ async def get_worker_monitor_info(self, recent=False, starts=None): starts = {} results = await asyncio.gather( *( - self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) + self.server.rpc(w).get_monitor_info( + recent=recent, start=starts.get(w, 0) + ) for w in self.workers ) ) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 48510dfd41a..8a9577b8ea1 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -406,7 +406,7 @@ def get_worker_plugin() -> ShuffleWorkerPlugin: return worker.plugins["shuffle"] # type: ignore except KeyError as e: raise RuntimeError( - f"The worker {worker.address} does not have a P2P shuffle plugin." + f"The worker {worker.server.address} does not have a P2P shuffle plugin." 
) from e diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index b33e90730b7..394c8e06a97 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -789,7 +789,7 @@ def create_run_on_worker( ), executor=plugin._executor, local_address=plugin.worker.address, - rpc=plugin.worker.rpc, + rpc=plugin.worker.server.rpc, digest_metric=plugin.worker.digest_metric, scheduler=plugin.worker.scheduler, memory_limiter_disk=plugin.memory_limiter_disk, diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index d08d5797781..98c1709f936 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -594,8 +594,8 @@ def create_run_on_worker( f"shuffle-{self.id}-{run_id}", ), executor=plugin._executor, - local_address=plugin.worker.address, - rpc=plugin.worker.rpc, + local_address=plugin.worker.server.address, + rpc=plugin.worker.server.rpc, digest_metric=plugin.worker.digest_metric, scheduler=plugin.worker.scheduler, memory_limiter_disk=plugin.memory_limiter_disk diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 57d2cfe3696..77e113c3a92 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -189,13 +189,13 @@ async def _fetch( if spec is None: response = await self._plugin.worker.scheduler.shuffle_get( id=shuffle_id, - worker=self._plugin.worker.address, + worker=self._plugin.worker.server.address, ) else: response = await self._plugin.worker.scheduler.shuffle_get_or_create( spec=ToPickle(spec), key=key, - worker=self._plugin.worker.address, + worker=self._plugin.worker.server.address, ) status = response["status"] @@ -277,9 +277,9 @@ class ShuffleWorkerPlugin(WorkerPlugin): def setup(self, worker: Worker) -> None: # Attach to worker - worker.handlers["shuffle_receive"] = self.shuffle_receive - worker.handlers["shuffle_inputs_done"] = self.shuffle_inputs_done - worker.stream_handlers["shuffle-fail"] = self.shuffle_fail + worker.server.handlers["shuffle_receive"] = self.shuffle_receive + worker.server.handlers["shuffle_inputs_done"] = self.shuffle_inputs_done + worker.server.stream_handlers["shuffle-fail"] = self.shuffle_fail worker.extensions["shuffle"] = self # Initialize @@ -295,10 +295,10 @@ def setup(self, worker: Worker) -> None: self._executor = ThreadPoolExecutor(self.worker.state.nthreads) def __str__(self) -> str: - return f"ShuffleWorkerPlugin on {self.worker.address}" + return f"ShuffleWorkerPlugin on {self.worker}" def __repr__(self) -> str: - return f"" + return f"" # Handlers ########## diff --git a/distributed/stealing.py b/distributed/stealing.py index 1d72e58a22a..71b35993057 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -116,7 +116,7 @@ def __init__(self, scheduler: Scheduler): "request_cost_total": defaultdict(int), } self._request_counter = 0 - self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm + self.scheduler.server.stream_handlers["steal-response"] = self.move_task_confirm async def start(self, scheduler: Any = None) -> None: """Start the background coroutine to balance the tasks on the cluster. 
diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 93bb18f16d3..004d8756760 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -6,8 +6,6 @@ import os import random import socket -import sys -import threading import time as timemod import weakref from unittest import mock @@ -19,7 +17,6 @@ from distributed.batched import BatchedSend from distributed.comm.core import CommClosedError -from distributed.comm.registry import backends from distributed.comm.tcp import TCPBackend, TCPListener from distributed.core import ( AsyncTaskGroup, @@ -47,8 +44,6 @@ assert_can_connect_locally_4, assert_can_connect_locally_6, assert_cannot_connect, - captured_logger, - gen_cluster, gen_test, has_ipv6, inc, @@ -234,44 +229,6 @@ async def set_flag(): assert not flag -@gen_test() -async def test_server_status_is_always_enum(): - """Assignments with strings is forbidden""" - server = Server({}) - assert isinstance(server.status, Status) - assert server.status != Status.stopped - server.status = Status.stopped - assert server.status == Status.stopped - with pytest.raises(TypeError): - server.status = "running" - - -@gen_test() -async def test_server_assign_assign_enum_is_quiet(): - """That would be the default in user code""" - server = Server({}) - server.status = Status.running - - -@gen_test() -async def test_server_status_compare_enum_is_quiet(): - """That would be the default in user code""" - server = Server({}) - # Note: We only want to assert that this comparison does not - # raise an error/warning. We do not want to assert its result. - server.status == Status.running # noqa: B015 - - -@gen_test(config={"distributed.admin.system-monitor.gil.enabled": True}) -async def test_server_close_stops_gil_monitoring(): - pytest.importorskip("gilknocker") - - server = Server({}) - assert server.monitor._gilknocker.is_running - await server.close() - assert not server.monitor._gilknocker.is_running - - @gen_test() async def test_server(): """ @@ -1006,48 +963,6 @@ async def ping(comm, delay=0.01): await asyncio.gather(*[server.close() for server in servers]) -@gen_test() -async def test_counters(): - async with Server({"div": stream_div}) as server: - await server.listen("tcp://") - - async with rpc(server.address) as r: - for _ in range(2): - await r.identity() - with pytest.raises(ZeroDivisionError): - await r.div(x=1, y=0) - - c = server.counters - assert c["op"].components[0] == {"identity": 2, "div": 1} - - -@gen_cluster(config={"distributed.admin.tick.interval": "20 ms"}) -async def test_ticks(s, a, b): - pytest.importorskip("crick") - await asyncio.sleep(0.1) - c = s.digests["tick-duration"] - assert c.size() - assert 0.01 < c.components[0].quantile(0.5) < 0.5 - - -@gen_cluster(config={"distributed.admin.tick.interval": "20 ms"}) -async def test_tick_logging(s, a, b): - pytest.importorskip("crick") - from distributed import core - - old = core.tick_maximum_delay - core.tick_maximum_delay = 0.001 - try: - with captured_logger("distributed.core") as sio: - await asyncio.sleep(0.1) - - text = sio.getvalue() - assert "unresponsive" in text - assert "Scheduler" in text or "Worker" in text - finally: - core.tick_maximum_delay = old - - @pytest.mark.parametrize("compression", list(compressions)) @pytest.mark.parametrize("serialize", [echo_serialize, echo_no_serialize]) @gen_test() @@ -1076,11 +991,6 @@ async def test_rpc_serialization(): assert result == {"result": inc} -@gen_cluster() -async def test_thread_id(s, a, b): - assert s.thread_id == 
a.thread_id == b.thread_id == threading.get_ident() - - @gen_test() async def test_deserialize_error(): async with Server({"throws": throws}) as server: @@ -1146,7 +1056,7 @@ async def sleep(comm=None): await asyncio.sleep(2000000) server = await Server({"sleep": sleep}) - assert server.status == Status.running + assert not server.stopped ports = [8881, 8882, 8883] # Previously we close *one* listener, therefore ensure we always use more @@ -1175,7 +1085,7 @@ async def sleep(comm=None): await assert_cannot_connect(f"tcp://{ip}:{port}") # weakref set/dict should be cleaned up - assert not len(server._ongoing_background_tasks) + assert not len(server._handle_comm_tasks) @gen_test() @@ -1271,29 +1181,6 @@ def validate_dict(server): validate_dict(server) -@gen_test() -async def test_server_sys_path_local_directory_cleanup(tmp_path, monkeypatch): - local_directory = str(tmp_path / "dask-scratch-space") - - # Ensure `local_directory` is removed from `sys.path` as part of the - # `Server` shutdown process - assert not any(i.startswith(local_directory) for i in sys.path) - async with Server({}, local_directory=local_directory): - assert sys.path[0].startswith(local_directory) - assert not any(i.startswith(local_directory) for i in sys.path) - - # Ensure `local_directory` isn't removed from `sys.path` if it - # was already there before the `Server` started - monkeypatch.setattr("sys.path", [local_directory] + sys.path) - assert sys.path[0].startswith(local_directory) - # NOTE: `needs_workdir=False` is needed to make sure the same path added - # to `sys.path` above is used by the `Server` (a subdirectory is created - # by default). - async with Server({}, local_directory=local_directory, needs_workdir=False): - assert sys.path[0].startswith(local_directory) - assert sys.path[0].startswith(local_directory) - - @pytest.mark.parametrize("close_via_rpc", [True, False]) @gen_test() async def test_close_fast_without_active_handlers(close_via_rpc): @@ -1400,15 +1287,6 @@ class TCPAsyncListenerBackend(TCPBackend): _listener_class = AsyncStopTCPListener -@gen_test() -async def test_async_listener_stop(monkeypatch): - monkeypatch.setitem(backends, "tcp", TCPAsyncListenerBackend()) - with pytest.warns(DeprecationWarning): - async with Server({}) as s: - await s.listen(0) - assert s.listeners - - @gen_test() async def test_messages_are_ordered_bsend(): ledger = [] diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 52073c13d2b..bfcb86f4b79 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -78,7 +78,6 @@ async def test_nanny_process_failure(c, s): assert not os.path.exists(second_dir) assert not os.path.exists(first_dir) assert first_dir != n.worker_dir - s.stop() @gen_cluster(nthreads=[]) @@ -202,10 +201,9 @@ def func(dask_worker): @gen_test() async def test_scheduler_file(): with tmpfile() as fn: - s = await Scheduler(scheduler_file=fn, dashboard_address=":0") - async with Nanny(scheduler_file=fn) as n: - assert set(s.workers) == {n.worker_address} - s.stop() + async with Scheduler(scheduler_file=fn, dashboard_address=":0") as s: + async with Nanny(scheduler_file=fn) as n: + assert set(s.workers) == {n.worker_address} @pytest.mark.xfail( diff --git a/distributed/tests/test_node.py b/distributed/tests/test_node.py new file mode 100644 index 00000000000..f0f1fa1a6da --- /dev/null +++ b/distributed/tests/test_node.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import asyncio +import sys +import threading + +import pytest + +from 
distributed.core import Status +from distributed.node import Node +from distributed.utils_test import captured_logger, gen_cluster, gen_test + + +@gen_test(config={"distributed.admin.system-monitor.gil.enabled": True}) +async def test_server_close_stops_gil_monitoring(): + pytest.importorskip("gilknocker") + + node = Node() + assert node.monitor._gilknocker.is_running + await node.close() + assert not node.monitor._gilknocker.is_running + + +@gen_test() +async def test_server_sys_path_local_directory_cleanup(tmp_path, monkeypatch): + local_directory = str(tmp_path / "dask-scratch-space") + + # Ensure `local_directory` is removed from `sys.path` as part of the + # `Server` shutdown process + assert not any(i.startswith(local_directory) for i in sys.path) + async with Node(local_directory=local_directory): + assert sys.path[0].startswith(local_directory) + assert not any(i.startswith(local_directory) for i in sys.path) + + # Ensure `local_directory` isn't removed from `sys.path` if it + # was already there before the `Server` started + monkeypatch.setattr("sys.path", [local_directory] + sys.path) + assert sys.path[0].startswith(local_directory) + # NOTE: `needs_workdir=False` is needed to make sure the same path added + # to `sys.path` above is used by the `Server` (a subdirectory is created + # by default). + async with Node(local_directory=local_directory, needs_workdir=False): + assert sys.path[0].startswith(local_directory) + assert sys.path[0].startswith(local_directory) + + +@gen_test() +async def test_server_status_is_always_enum(): + """Assignments with strings is forbidden""" + server = Node() + assert isinstance(server.status, Status) + assert server.status != Status.stopped + server.status = Status.stopped + assert server.status == Status.stopped + with pytest.raises(TypeError): + server.status = "running" + + +@gen_test() +async def test_server_assign_assign_enum_is_quiet(): + """That would be the default in user code""" + node = Node() + node.status = Status.running + + +@gen_test() +async def test_server_status_compare_enum_is_quiet(): + """That would be the default in user code""" + node = Node() + # Note: We only want to assert that this comparison does not + # raise an error/warning. We do not want to assert its result. 
+ node.status == Status.running # noqa: B015 + + +@gen_cluster(config={"distributed.admin.tick.interval": "20 ms"}) +async def test_ticks(s, a, b): + pytest.importorskip("crick") + await asyncio.sleep(0.1) + c = s.digests["tick-duration"] + assert c.size() + assert 0.01 < c.components[0].quantile(0.5) < 0.5 + + +@gen_cluster(config={"distributed.admin.tick.interval": "20 ms"}) +async def test_tick_logging(s, a, b): + pytest.importorskip("crick") + from distributed import node + + old = node.tick_maximum_delay + node.tick_maximum_delay = 0.001 + try: + with captured_logger("distributed.core") as sio: + await asyncio.sleep(0.1) + + text = sio.getvalue() + assert "unresponsive" in text + assert "Scheduler" in text or "Worker" in text + finally: + node.tick_maximum_delay = old + + +@gen_cluster() +async def test_thread_id(s, a, b): + assert s.thread_id == a.thread_id == b.thread_id == threading.get_ident() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 9aaea1288a4..62770cefd2f 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -825,8 +825,8 @@ async def test_retire_workers_concurrently(c, s, w1, w2): async def test_server_listens_to_other_ops(s, a, b): async with rpc(s.address) as r: ident = await r.identity() - assert ident["type"] == "Scheduler" - assert ident["id"].lower().startswith("scheduler") + assert ident["type"] == "Scheduler", ident["type"] + assert ident["id"].lower().startswith("scheduler"), ident["id"] @gen_cluster(client=True) @@ -936,7 +936,7 @@ def func(scheduler): nthreads=[], config={"distributed.scheduler.blocked-handlers": ["test-handler"]} ) async def test_scheduler_init_pulls_blocked_handlers_from_config(s): - assert s.blocked_handlers == ["test-handler"] + assert s.server.blocked_handlers == ["test-handler"] @gen_cluster() @@ -1321,7 +1321,7 @@ async def test_broadcast_nanny(s, a, b): @gen_cluster(config={"distributed.comm.timeouts.connect": "200ms"}) async def test_broadcast_on_error(s, a, b): - a.stop() + a.server.stop() with pytest.raises(OSError): await s.broadcast(msg={"op": "ping"}, on_error="raise") @@ -2005,7 +2005,7 @@ async def test_profile_metadata_timeout(c, s, a, b): def raise_timeout(*args, **kwargs): raise TimeoutError - b.handlers["profile_metadata"] = raise_timeout + b.server.handlers["profile_metadata"] = raise_timeout futures = c.map(slowinc, range(10), delay=0.05, workers=a.address) await wait(futures) @@ -2069,7 +2069,7 @@ async def test_statistical_profiling_failure(c, s, a, b): def raise_timeout(*args, **kwargs): raise TimeoutError - b.handlers["profile"] = raise_timeout + b.server.handlers["profile"] = raise_timeout await wait(futures) profile = await s.get_profile() @@ -3128,7 +3128,7 @@ async def connect(self, *args, **kwargs): async def test_gather_failing_can_recover(c, s, a, b): x = await c.scatter({"x": 1}, workers=a.address) rpc = await FlakyConnectionPool(failing_connections=1) - with mock.patch.object(s, "rpc", rpc), dask.config.set( + with mock.patch.object(s.server, "rpc", rpc), dask.config.set( {"distributed.comm.retry.count": 1} ), captured_handler( logging.getLogger("distributed").handlers[0] @@ -3146,7 +3146,7 @@ async def test_gather_failing_can_recover(c, s, a, b): async def test_gather_failing_cnn_error(c, s, a, b): x = await c.scatter({"x": 1}, workers=a.address) rpc = await FlakyConnectionPool(failing_connections=10) - with mock.patch.object(s, "rpc", rpc): + with mock.patch.object(s.server, "rpc", rpc): res = await s.gather(keys=["x"]) 
assert res["status"] == "error" assert list(res["keys"]) == ["x"] @@ -3179,7 +3179,7 @@ async def test_gather_bad_worker(c, s, a, direct): """ x = c.submit(inc, 1, key="x") c.rpc = await FlakyConnectionPool(failing_connections=3) - s.rpc = await FlakyConnectionPool(failing_connections=1) + s.server.rpc = await FlakyConnectionPool(failing_connections=1) with captured_logger("distributed.scheduler") as sched_logger: with captured_logger("distributed.client") as client_logger: @@ -3194,12 +3194,12 @@ async def test_gather_bad_worker(c, s, a, direct): # 3. try direct=True again; fail # 4. fall back to direct=False again; success assert c.rpc.cnn_count == 2 - assert s.rpc.cnn_count == 2 + assert s.server.rpc.cnn_count == 2 else: # 1. try direct=False; fail # 2. try again direct=False; success assert c.rpc.cnn_count == 0 - assert s.rpc.cnn_count == 2 + assert s.server.rpc.cnn_count == 2 @gen_cluster(client=True) @@ -3230,8 +3230,8 @@ async def test_multiple_listeners(dashboard_link_template, expected_dashboard_li async with Scheduler( dashboard_address=":0", protocol=["inproc", "tcp"] ) as s: - async with Worker(s.listeners[0].contact_address) as a: - async with Worker(s.listeners[1].contact_address) as b: + async with Worker(s.server.listeners[0].contact_address) as a: + async with Worker(s.server.listeners[1].contact_address) as b: assert a.address.startswith("inproc") assert a.scheduler.address.startswith("inproc") assert b.address.startswith("tcp") @@ -4337,6 +4337,7 @@ async def test_get_cluster_state(s, *workers): _verify_cluster_state(state_no_workers, []) +@pytest.mark.xfail(reason="worked dict missing") @gen_cluster( nthreads=[("", 1)] * 2, config={"distributed.comm.timeouts.connect": "200ms"}, @@ -4703,7 +4704,7 @@ class BrokenGatherDep(Worker): async def gather_dep(self, worker, *args, **kwargs): w = workers.pop(worker, None) if w is not None and workers: - w.listener.stop() + w.server.listener.stop() s.stream_comms[worker].abort() return await super().gather_dep(worker, *args, **kwargs) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 1fd59b5525a..e8d59c3e75c 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -796,7 +796,7 @@ async def start_cluster( ): await asyncio.sleep(0.01) if time() > start + 30: - await asyncio.gather(*(w.close(timeout=1) for w in workers)) + await asyncio.gather(*(w.close() for w in workers)) await s.close() check_invalid_worker_transitions(s) check_invalid_task_states(s) @@ -857,7 +857,6 @@ async def end_worker(w): await asyncio.gather(*(end_worker(w) for w in workers)) await s.close() # wait until scheduler stops completely - s.stop() check_invalid_worker_transitions(s) check_invalid_task_states(s) check_worker_fail_hard(s) diff --git a/distributed/variable.py b/distributed/variable.py index 3df28ff3596..3557965b607 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -39,8 +39,10 @@ def __init__(self, scheduler): {"variable_set": self.set, "variable_get": self.get} ) - self.scheduler.stream_handlers["variable-future-release"] = self.future_release - self.scheduler.stream_handlers["variable_delete"] = self.delete + self.scheduler.server.stream_handlers[ + "variable-future-release" + ] = self.future_release + self.scheduler.server.stream_handlers["variable_delete"] = self.delete async def set(self, name=None, key=None, data=None, client=None): if key is not None: diff --git a/distributed/worker.py b/distributed/worker.py index 17d715507ea..d2f38f8a0fa 100644 --- a/distributed/worker.py +++ 
b/distributed/worker.py @@ -43,6 +43,7 @@ from tlz import keymap, pluck from tornado.ioloop import IOLoop +from typing_extensions import Self import dask from dask.core import istask @@ -73,7 +74,6 @@ PooledRPCCall, Status, coerce_to_address, - context_meter_to_server_digest, error_message, pingpong, ) @@ -85,7 +85,7 @@ from distributed.exceptions import Reschedule from distributed.http import get_handlers from distributed.metrics import context_meter, thread_time, time -from distributed.node import ServerNode +from distributed.node import ServerNode, context_meter_to_node_digest from distributed.proctitle import setproctitle from distributed.protocol import pickle, to_serialize from distributed.protocol.serialize import _is_dumpable @@ -784,7 +784,7 @@ def __init__( ) BaseWorker.__init__(self, state) - self.scheduler = self.rpc(scheduler_addr) + self.scheduler = self.server.rpc(scheduler_addr) self.execution_state = { "scheduler": self.scheduler.address, "ioloop": self.loop, @@ -939,9 +939,9 @@ def waiting_for_data_count(self) -> int: ################## def __repr__(self) -> str: - name = f", name: {self.name}" if self.name != self.address_safe else "" + name = f", name: {self.name}" if self.name != self.server.address_safe else "" return ( - f"<{self.__class__.__name__} {self.address_safe!r}{name}, " + f"<{self.__class__.__name__} {self.server.address_safe!r}{name}, " f"status: {self.status.name}, " f"stored: {len(self.data)}, " f"running: {self.state.executing_count}/{self.state.nthreads}, " @@ -986,7 +986,7 @@ def log_event(self, topic: str | Collection[str], msg: Any) -> None: @property def worker_address(self): """For API compatibility with Nanny""" - return self.address + return self.server.address @property def executor(self): @@ -1164,7 +1164,7 @@ async def _register_with_scheduler(self) -> None: self.periodic_callbacks["heartbeat"].stop() start = time() if self.contact_address is None: - self.contact_address = self.address + self.contact_address = self.server.address logger.info("-" * 49) # Worker reconnection is not supported @@ -1244,7 +1244,7 @@ def _update_latency(self, latency: float) -> None: self.digest_metric("latency", latency) async def heartbeat(self) -> None: - logger.debug("Heartbeat: %s", self.address) + logger.debug("Heartbeat: %s", self.server.address) try: start = time() response = await retry_operation( @@ -1297,7 +1297,7 @@ async def heartbeat(self) -> None: @fail_hard async def handle_scheduler(self, comm: Comm) -> None: try: - await self.handle_stream(comm) + await self.server.handle_stream(comm) finally: await self.close(reason="worker-handle-scheduler-connection-broken") @@ -1328,7 +1328,7 @@ async def gather(self, who_has: dict[Key, list[str]]) -> dict[Key, object]: new_failed_keys, new_missing_workers, ) = await gather_from_workers( - who_has=to_gather, rpc=self.rpc, who=self.address + who_has=to_gather, rpc=self.server.rpc, who=self.server.address ) self.update_data(data, stimulus_id=stimulus_id) del data @@ -1365,7 +1365,7 @@ def get_monitor_info(self, recent: bool = False, start: int = 0) -> dict[str, An # Lifecycle # ############# - async def start_unsafe(self): + async def start_unsafe(self) -> Self: await super().start_unsafe() enable_gc_diagnosis() @@ -1386,7 +1386,7 @@ async def start_unsafe(self): get_address_host(self.scheduler.address) ) try: - await self.listen(start_address, **kwargs) + await self.server.listen(start_address, **kwargs) except OSError as e: if len(ports) > 1 and e.errno == errno.EADDRINUSE: continue @@ -1420,10 +1420,10 @@ 
async def start_unsafe(self): self, prefix=self._http_prefix, ) - self.ip = get_address_host(self.address) + self.ip = get_address_host(self.server.address) if self.name is None: - self.name = self.address + self.name = self.server.address await self.preloads.start() @@ -1433,13 +1433,17 @@ async def start_unsafe(self): self.start_services(self.ip) try: - listening_address = "%s%s:%d" % (self.listener.prefix, self.ip, self.port) + listening_address = "%s%s:%d" % ( + self.server.listener.prefix, + self.ip, + self.server.port, + ) except Exception: - listening_address = f"{self.listener.prefix}{self.ip}" + listening_address = f"{self.server.listener.prefix}{self.ip}" - logger.info(" Start worker at: %26s", self.address) + logger.info(" Start worker at: %26s", self.server.address) logger.info(" Listening to: %26s", listening_address) - if self.name != self.address_safe: + if self.name != self.server.address_safe: # only if name was not None logger.info(" Worker name: %26s", self.name) for k, v in self.service_ports.items(): @@ -1454,7 +1458,7 @@ async def start_unsafe(self): ) logger.info(" Local Directory: %26s", self.local_directory) - setproctitle("dask worker [%s]" % self.address) + setproctitle("dask worker [%s]" % self.server.address) plugins_msgs = await asyncio.gather( *( @@ -1474,7 +1478,7 @@ async def start_unsafe(self): raise plugins_exceptions[0] self._pending_plugins = () - self.state.address = self.address + self.state.address = self.server.address await self._register_with_scheduler() self.start_periodic_callbacks() return self @@ -1535,13 +1539,17 @@ async def close( # type: ignore disable_gc_diagnosis() try: - self.log_event(self.address, {"action": "closing-worker", "reason": reason}) + self.log_event( + self.server.address, {"action": "closing-worker", "reason": reason} + ) except Exception: # This can happen when the Server is not up yet logger.exception("Failed to log closing event") try: - logger.info("Stopping worker at %s. Reason: %s", self.address, reason) + logger.info( + "Stopping worker at %s. Reason: %s", self.server.address, reason + ) except ValueError: # address not available if already closed logger.info("Stopping worker. Reason: %s", reason) if self.status not in WORKER_ANY_RUNNING: @@ -1554,7 +1562,7 @@ async def close( # type: ignore setproctitle("dask worker [closing]") if nanny and self.nanny: - with self.rpc(self.nanny) as r: + with self.server.rpc(self.nanny) as r: await r.close_gracefully(reason=reason) # Cancel async instructions @@ -1602,8 +1610,7 @@ async def close( # type: ignore # otherwise c.close() - await self._stop_listeners() - await self.rpc.close() + await self.server.close() # Give some time for a UCX scheduler to complete closing endpoints # before closing self.batched_stream, otherwise the local endpoint @@ -1654,7 +1661,6 @@ def _close(executor, wait): executor=executor, wait=executor_wait ) # Just run it directly - self.stop() self.status = Status.closed setproctitle("dask worker [closed]") @@ -1677,12 +1683,14 @@ async def close_gracefully( if self.status == Status.closed: return - logger.info("Closing worker gracefully: %s. Reason: %s", self.address, reason) + logger.info( + "Closing worker gracefully: %s. Reason: %s", self.server.address, reason + ) # Wait for all tasks to leave the worker and don't accept any new ones. # Scheduler.retire_workers will set the status to closing_gracefully and push it # back to this worker. 
await self.scheduler.retire_workers( - workers=[self.address], + workers=[self.server.address], close_workers=False, remove=True, stimulus_id=f"worker-close-gracefully-{time()}", @@ -1718,7 +1726,7 @@ async def batched_send_connect(): self.stream_comms[address].send(msg) - @context_meter_to_server_digest("get-data") + @context_meter_to_node_digest("get-data") async def get_data( self, comm: Comm, @@ -1728,7 +1736,7 @@ async def get_data( ) -> GetDataBusy | Literal[Status.dont_reply]: max_connections = self.transfer_outgoing_count_limit # Allow same-host connections more liberally - if get_address_host(comm.peer_address) == get_address_host(self.address): + if get_address_host(comm.peer_address) == get_address_host(self.server.address): max_connections = max_connections * 2 if self.status == Status.paused: @@ -1746,7 +1754,7 @@ async def get_data( logger.debug( "Worker %s has too many open connections to respond to data request " "from %s (%d/%d).%s", - self.address, + self.server.address, who, self.transfer_outgoing_count, max_connections, @@ -1766,7 +1774,7 @@ async def get_data( from distributed.actor import Actor data[k] = Actor( - type(self.state.actors[k]), self.address, k, worker=self + type(self.state.actors[k]), self.server.address, k, worker=self ) msg = {"status": "OK", "data": {k: to_serialize(v) for k, v in data.items()}} @@ -1785,7 +1793,7 @@ async def get_data( except OSError: logger.exception( "failed during get data with %s -> %s", - self.address, + self.server.address, who, ) comm.abort() @@ -2057,7 +2065,10 @@ async def gather_dep( try: with context_meter.meter("network", func=time) as m: response = await get_data_from_worker( - rpc=self.rpc, keys=to_gather, worker=worker, who=self.address + rpc=self.server.rpc, + keys=to_gather, + worker=worker, + who=self.server.address, ) if response["status"] == "busy": @@ -2400,7 +2411,9 @@ def _prepare_args_for_execution( except KeyError: from distributed.actor import Actor # TODO: create local actor - data[k] = Actor(type(self.state.actors[k]), self.address, k, self) + data[k] = Actor( + type(self.state.actors[k]), self.server.address, k, self + ) args2 = pack_data(args, data, key_types=(bytes, str, tuple)) kwargs2 = pack_data(kwargs, data, key_types=(bytes, str, tuple)) stop = time() @@ -2461,7 +2474,7 @@ async def get_profile( ): now = time() + self.scheduler_delay if server: - history = self.io_loop.profile # type: ignore[attr-defined] + history = self.io_loop.profile elif key is None: history = self.profile_history else: @@ -2540,7 +2553,7 @@ async def benchmark_memory(self) -> dict[str, float]: return await self.loop.run_in_executor(self.executor, benchmark_memory) async def benchmark_network(self, address: str) -> dict[str, float]: - return await benchmark_network(rpc=self.rpc, address=address) + return await benchmark_network(rpc=self.server.rpc, address=address) ####################################### # Worker Clients (advanced workloads) # @@ -2636,7 +2649,7 @@ def get_current_task(self) -> Key: return self.active_threads[threading.get_ident()] def _handle_remove_worker(self, worker: str, stimulus_id: str) -> None: - self.rpc.remove(worker) + self.server.rpc.remove(worker) self.handle_stimulus(RemoveWorkerEvent(worker=worker, stimulus_id=stimulus_id)) def validate_state(self) -> None: From fcaee9d3207bf4540dcc33f77c49d41650df6aad Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 19 Jan 2024 15:07:33 +0100 Subject: [PATCH 2/6] add Server to client --- distributed/client.py | 107 ++++++------------------------- 
distributed/pubsub.py | 2 +- distributed/tests/test_client.py | 2 +- 3 files changed, 22 insertions(+), 89 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 3cac5b3fc25..abe15e887b7 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -49,7 +49,7 @@ ) from dask.widgets import get_template -from distributed.core import OKMessage +from distributed.core import ErrorMessage, OKMessage, Server from distributed.protocol.serialize import _is_dumpable from distributed.utils import Deadline, wait_for @@ -68,11 +68,9 @@ from distributed.compatibility import PeriodicCallback from distributed.core import ( CommClosedError, - ConnectionPool, PooledRPCCall, Status, clean_exception, - connect, rpc, ) from distributed.diagnostics.plugin import ( @@ -1050,7 +1048,7 @@ def __init__( self._set_config = dask.config.set(scheduler="dask.distributed") self._event_handlers = {} - self._stream_handlers = { + stream_handlers = { "key-in-memory": self._handle_key_in_memory, "lost-data": self._handle_lost_data, "cancelled-keys": self._handle_cancelled_keys, @@ -1067,15 +1065,17 @@ def __init__( "erred": self._handle_task_erred, } - self.rpc = ConnectionPool( - limit=connection_limit, - serializers=serializers, - deserializers=deserializers, + self.server = Server( + {}, + stream_handlers=stream_handlers, + connection_limit=connection_limit, deserialize=True, - connection_args=self.connection_args, + deserializers=deserializers, + serializers=serializers, timeout=timeout, - server=self, + connection_args=self.connection_args, ) + self.rpc = self.server.rpc self.extensions = { name: extension(self) for name, extension in extensions.items() @@ -1321,7 +1321,7 @@ def _send_to_scheduler(self, msg): async def _start(self, timeout=no_default, **kwargs): self.status = "connecting" - await self.rpc.start() + await self.server if timeout is no_default: timeout = self._timeout @@ -1362,7 +1362,7 @@ async def _start(self, timeout=no_default, **kwargs): self._gather_semaphore = asyncio.Semaphore(5) if self.scheduler is None: - self.scheduler = self.rpc(address) + self.scheduler = self.server.rpc(address) self.scheduler_comm = None try: @@ -1379,7 +1379,9 @@ async def _start(self, timeout=no_default, **kwargs): await self.preloads.start() - self._handle_report_task = asyncio.create_task(self._handle_report()) + self._handle_report_task = asyncio.create_task( + self.server.handle_stream(self.scheduler_comm.comm) + ) return self @@ -1434,9 +1436,7 @@ async def _ensure_connected(self, timeout=None): self._connecting_to_scheduler = True try: - comm = await connect( - self.scheduler.address, timeout=timeout, **self.connection_args - ) + comm = await self.server.rpc.connect(self.scheduler.address) comm.name = "Client->Scheduler" if timeout is not None: await wait_for(self._update_scheduler_info(), timeout) @@ -1621,63 +1621,6 @@ def _release_key(self, key): {"op": "client-releases-keys", "keys": [key], "client": self.id} ) - @log_errors - async def _handle_report(self): - """Listen to scheduler""" - try: - while True: - if self.scheduler_comm is None: - break - try: - msgs = await self.scheduler_comm.comm.read() - except CommClosedError: - if self._is_finalizing(): - return - if self.status == "running": - if self.cluster and self.cluster.status in ( - Status.closed, - Status.closing, - ): - # Don't attempt to reconnect if cluster are already closed. - # Instead close down the client. 
- await self._close() - return - logger.info("Client report stream closed to scheduler") - logger.info("Reconnecting...") - self.status = "connecting" - await self._reconnect() - continue - else: - break - if not isinstance(msgs, (list, tuple)): - msgs = (msgs,) - - breakout = False - for msg in msgs: - logger.debug("Client %s receives message %s", self.id, msg) - - if "status" in msg and "error" in msg["status"]: - typ, exc, tb = clean_exception(**msg) - raise exc.with_traceback(tb) - - op = msg.pop("op") - - if op == "close" or op == "stream-closed": - breakout = True - break - - try: - handler = self._stream_handlers[op] - result = handler(**msg) - if inspect.isawaitable(result): - await result - except Exception as e: - logger.exception(e) - if breakout: - break - except (CancelledError, asyncio.CancelledError): - pass - def _handle_key_in_memory(self, key=None, type=None, workers=None): state = self.futures.get(key) if state is not None: @@ -1787,13 +1730,6 @@ async def _close(self, fast: bool = False) -> None: self._send_to_scheduler({"op": "close-client"}) self._send_to_scheduler({"op": "close-stream"}) async with self._wait_for_handle_report_task(fast=fast): - if ( - self.scheduler_comm - and self.scheduler_comm.comm - and not self.scheduler_comm.comm.closed() - ): - await self.scheduler_comm.close() - for key in list(self.futures): self._release_key(key=key) @@ -1801,15 +1737,12 @@ async def _close(self, fast: bool = False) -> None: with suppress(AttributeError): await self.cluster.close() - await self.rpc.close() - - self.status = "closed" + await self.server.close() - if _get_global_client() is self: - _set_global_client(None) + self.status = "closed" - with suppress(AttributeError): - await self.scheduler.close_rpc() + if _get_global_client() is self: + _set_global_client(None) self.scheduler = None self.status = "closed" diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 42b3d917cf9..ed9bb2960bf 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -175,7 +175,7 @@ class PubSubClientExtension: def __init__(self, client): self.client = client - self.client._stream_handlers.update({"pubsub-msg": self.handle_message}) + self.client.server.stream_handlers.update({"pubsub-msg": self.handle_message}) self.subscribers = defaultdict(weakref.WeakSet) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index f35de46bf94..d3796bc7f82 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4025,7 +4025,7 @@ async def test_get_versions_async(c, s, a, b): @gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "200ms"}) async def test_get_versions_rpc_error(c, s, a, b): - a.stop() + a.server.stop() v = await c.get_versions() assert v.keys() == {"scheduler", "client", "workers"} assert v["workers"].keys() == {b.address} From 78c8c6bbd6e9a83707d5057e278e2cbe406fd798 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 10 Jul 2024 15:26:33 +0200 Subject: [PATCH 3/6] linting fixes --- distributed/client.py | 2 +- distributed/node.py | 2 ++ distributed/scheduler.py | 5 +---- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index abe15e887b7..e0331e32826 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -49,7 +49,7 @@ ) from dask.widgets import get_template -from distributed.core import ErrorMessage, OKMessage, Server +from distributed.core import OKMessage, Server from distributed.protocol.serialize import _is_dumpable from 
distributed.utils import Deadline, wait_for diff --git a/distributed/node.py b/distributed/node.py index 5a3d052c901..eef6088ca47 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -86,6 +86,8 @@ class Node: _event_finished: asyncio.Event + _is_finalizing: staticmethod[[], bool] = staticmethod(sys.is_finalizing) + def __init__( self, local_directory=None, diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b5977084c6f..ec5322e18fa 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -50,7 +50,6 @@ valmap, ) from tornado.ioloop import IOLoop -from typing_extensions import Self import dask import dask.utils @@ -6199,9 +6198,7 @@ async def scatter( assert isinstance(data, dict) - keys, who_has, nbytes = await scatter_to_workers( - nthreads, data, rpc=self.server.rpc - ) + keys, who_has, nbytes = await scatter_to_workers(wss, data, rpc=self.server.rpc) self.update_data(who_has=who_has, nbytes=nbytes, client=client) From 9ab9924bbc13c22b762c048c98ed45f3b69647ed Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 10 Jul 2024 15:30:49 +0200 Subject: [PATCH 4/6] type checking --- distributed/worker.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index d2f38f8a0fa..c108ea59cfc 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -43,7 +43,6 @@ from tlz import keymap, pluck from tornado.ioloop import IOLoop -from typing_extensions import Self import dask from dask.core import istask @@ -153,7 +152,7 @@ if TYPE_CHECKING: # FIXME import from typing (needs Python >=3.10) - from typing_extensions import ParamSpec + from typing_extensions import ParamSpec, Self # Circular imports from distributed.client import Client From 04f5797461ddc1b93633632c326f212205c395a2 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 10 Jul 2024 15:32:14 +0200 Subject: [PATCH 5/6] fix pubsub ext --- distributed/publish.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/publish.py b/distributed/publish.py index e7887b8dc96..3dfe34f5a2c 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -33,8 +33,8 @@ def __init__(self, scheduler): "publish_flush_batched_send": self.flush_receive, } - self.scheduler.handlers.update(handlers) - self.scheduler.stream_handlers.update(stream_handlers) + self.scheduler.server.handlers.update(handlers) + self.scheduler.server.stream_handlers.update(stream_handlers) self._flush_received = defaultdict(asyncio.Event) def flush_receive(self, uid, **kwargs): From f8ccddf865efd68e90a6498f1d6c0b8f4cab9911 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 10 Jul 2024 15:37:22 +0200 Subject: [PATCH 6/6] minor fixes --- distributed/active_memory_manager.py | 2 +- distributed/scheduler.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 724bfc18923..d6c5344a773 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -110,7 +110,7 @@ def __init__( self.measure = measure if register: - scheduler.handlers["amm_handler"] = self.amm_handler + scheduler.server.handlers["amm_handler"] = self.amm_handler if interval is None: interval = parse_timedelta( diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ec5322e18fa..46349dc12bf 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6197,8 +6197,10 @@ async def scatter( await asyncio.sleep(0.1) assert 
isinstance(data, dict) - - keys, who_has, nbytes = await scatter_to_workers(wss, data, rpc=self.server.rpc) + workers = list(ws.address for ws in wss) + keys, who_has, nbytes = await scatter_to_workers( + workers, data, rpc=self.server.rpc + ) self.update_data(who_has=who_has, nbytes=nbytes, client=client)
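
Across the series above, the comm machinery (handlers, stream_handlers, rpc, listeners, addresses) moves behind a composed `.server` attribute on Scheduler, Worker and Client instead of living directly on those classes, and extensions register themselves against that object. Below is a minimal sketch of what that pattern looks like for extension code, assuming the post-refactor layout shown in the hunks for stealing.py, variable.py, publish.py and active_memory_manager.py; the class `EchoExtension` and the names `echo`, `echo-note` and `ping_worker` are illustrative and not part of this patch.

from __future__ import annotations

from typing import Any


class EchoExtension:
    """Illustrative scheduler extension (not part of this patch).

    Registers its comm handlers on the composed ``scheduler.server``
    rather than on the scheduler itself, mirroring the registration
    pattern applied in the diffs above to the stealing, variable,
    publish and active-memory-manager extensions.
    """

    def __init__(self, scheduler: Any) -> None:
        self.scheduler = scheduler
        # Request/response handler, reachable via rpc(scheduler_address).echo(...)
        scheduler.server.handlers["echo"] = self.echo
        # Fire-and-forget handler for messages arriving on batched streams
        scheduler.server.stream_handlers["echo-note"] = self.echo_note

    def echo(self, msg: Any = None) -> dict[str, Any]:
        # Plain request/response payload sent back over the comm
        return {"status": "OK", "echo": msg}

    def echo_note(self, msg: Any = None) -> None:
        # Stream handlers return nothing; record the message as an event
        self.scheduler.log_event("echo", msg)

    async def ping_worker(self, address: str) -> dict[str, Any]:
        # Outbound calls go through the connection pool owned by the
        # composed Server rather than through an rpc attribute on the
        # scheduler itself
        return await self.scheduler.server.rpc(address).identity()

Wired into a running scheduler, a caller would reach the new handler much the way the updated tests talk to the scheduler, e.g. `await rpc(s.address).echo(msg="hi")` alongside the existing `await r.identity()` call in test_server_listens_to_other_ops.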