diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml
index 3d35c74dfa8..0e45c7a8388 100644
--- a/continuous_integration/environment-3.9.yaml
+++ b/continuous_integration/environment-3.9.yaml
@@ -43,6 +43,7 @@ dependencies:
   - toolz
   - torchvision # Only tested here
   - tornado=6
+  - zeroconf # Only tested here
   - zict # overridden by git tip below
   - zstandard
   - pip:
diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml
index d10022f7574..b8923750745 100644
--- a/distributed/distributed-schema.yaml
+++ b/distributed/distributed-schema.yaml
@@ -162,6 +162,13 @@ properties:
               Whether or not to run consistency checks during execution.
               This is typically only used for debugging.
 
+          zeroconf:
+            type: boolean
+            description: |
+              Whether or not to advertise the scheduler via zeroconf
+              (mDNS service discovery). This has no effect when the
+              optional zeroconf package is not installed.
+
           dashboard:
             type: object
             description: |
diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml
index 929b58676a7..7651eb2dc9f 100644
--- a/distributed/distributed.yaml
+++ b/distributed/distributed.yaml
@@ -30,6 +30,7 @@ distributed:
       rechunk-split: 1us
       split-shuffle: 1us
     validate: False # Check scheduler state at every step for debugging
+    zeroconf: True # Advertise the scheduler via zeroconf (mDNS) when the package is installed
     dashboard:
       status:
         task-stream-length: 1000
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index f89e486356d..60152ba40ad 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -8,6 +8,7 @@
 import operator
 import os
 import random
+import socket
 import sys
 import uuid
 import warnings
@@ -49,7 +50,11 @@
     resolve_address,
     unparse_host_port,
 )
-from .comm.addressing import addresses_from_user_args
+from .comm.addressing import (
+    addresses_from_user_args,
+    get_address_host_port,
+    parse_address,
+)
 from .core import CommClosedError, Status, clean_exception, rpc, send_recv
 from .diagnostics.plugin import SchedulerPlugin, _get_plugin_name
 from .event import EventExtension
@@ -88,6 +93,13 @@
 except ImportError:
     compiled = False
 
+# zeroconf is optional; when it is missing the scheduler is simply not advertised
+try:
+    import zeroconf
+    from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
+except ImportError:
+    zeroconf = False
+
 if compiled:
     from cython import (
         Py_hash_t,
@@ -3555,6 +3567,14 @@ def __init__(
         self._lock = asyncio.Lock()
         self.bandwidth_workers = defaultdict(float)
         self.bandwidth_types = defaultdict(float)
+        # Scheduler advertisement over mDNS is controlled by config and needs
+        # the optional zeroconf package; self._zeroconf stays None otherwise
+        if zeroconf and dask.config.get("distributed.scheduler.zeroconf"):
+            self._zeroconf = AsyncZeroconf(ip_version=zeroconf.IPVersion.V4Only)
+        else:
+            self._zeroconf = None
+        self._zeroconf_services = []
+        self._zeroconf_registration_tasks = []
 
         if not preload:
             preload = dask.config.get("distributed.scheduler.preload")
@@ -3929,6 +3949,35 @@ async def start(self):
 
         for listener in self.listeners:
             logger.info(" Scheduler at: %25s", listener.contact_address)
+            if (
+                self._zeroconf is not None
+                and dask.config.get("distributed.scheduler.zeroconf")
+                and not self.address.startswith("inproc://")
+            ):
+                # Advertise service via mdns service discovery
+                try:
+                    host, port = get_address_host_port(listener.contact_address)
+                    packed_host = socket.inet_aton(host)
+                except (NotImplementedError, OSError):
+                    # Skip listeners that cannot be advertised over IPv4:
+                    # addresses that are not IP based, and hosts that are not
+                    # valid IPv4 strings (e.g. IPv6), which inet_aton rejects
+                    continue
+                protocol, _ = parse_address(listener.contact_address)
+                short_id = self.id.split("-")[1]
+                info = AsyncServiceInfo(
+                    "_dask._tcp.local.",
+                    f"_sched-{short_id}._dask._tcp.local.",
+                    addresses=[packed_host],
+                    port=port,
+                    properties={"protocol": protocol},
+                    server=f"sched-{short_id}.dask.local.",
+                )
+                self._zeroconf_services.append(info)
+                self._zeroconf_registration_tasks.append(
+                    await self._zeroconf.async_register_service(info)
+                )
+                logger.info(" Advertising as: %25s", info.server)
         for k, v in self.services.items():
             logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port))
 
@@ -4000,6 +4049,14 @@ async def close(self, comm=None, fast=False, close_workers=False):
 
         self.stop_services()
 
+        if self._zeroconf is not None:
+            # Cancel registration broadcasts that may still be pending before
+            # shutting zeroconf down. cancel() only requests cancellation and
+            # does not raise, so it needs no suppress().
+            for task in self._zeroconf_registration_tasks:
+                task.cancel()
+            await self._zeroconf.async_close()
+
         for ext in parent._extensions.values():
             with suppress(AttributeError):
                 ext.teardown()
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 260513e954e..795d9573295 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -22,7 +22,7 @@
 
 from distributed import Client, Nanny, Worker, fire_and_forget, wait
 from distributed.comm import Comm
-from distributed.compatibility import LINUX, WINDOWS
+from distributed.compatibility import LINUX, MACOS, WINDOWS
 from distributed.core import ConnectionPool, Status, connect, rpc
 from distributed.metrics import time
 from distributed.protocol.pickle import dumps
@@ -3156,6 +3156,27 @@ async def test_transition_counter(c, s, a, b):
     assert s.transition_counter > 1
 
 
+@gen_cluster(
+    config={"distributed.scheduler.zeroconf": True},
+)
+async def test_zeroconf(s, *_):
+    """The service registered by the scheduler resolves to its own address."""
+    zc_asyncio = pytest.importorskip("zeroconf.asyncio")
+    assert len(s._zeroconf_services) == 1
+    registered = s._zeroconf_services[0]
+    async with zc_asyncio.AsyncZeroconf(interfaces=["127.0.0.1"]) as aiozc:
+        discovered = await aiozc.async_get_service_info(
+            "_dask._tcp.local.", registered.name
+        )
+        # async_get_service_info returns None if the lookup times out
+        assert discovered is not None
+        [address] = discovered.parsed_addresses()
+        assert str(address) in s.address
+        assert str(discovered.port) in s.address
+
+
+@pytest.mark.skipif(MACOS and sys.version_info < (3, 9), reason="GH#5056")
+@pytest.mark.slow
 @gen_cluster(
     client=True,
     nthreads=[("127.0.0.1", 1) for _ in range(10)],