Source code distributed/node.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from contextlib import suppress
import logging
import warnings
import weakref

from tornado.httpserver import HTTPServer
import tlz
import dask

from .comm import get_tcp_server_address
from .comm import get_address_host
from .core import Server
from .http.routing import RoutingApplication
from .versions import get_versions
from .utils import DequeHandler, clean_dashboard_address


class ServerNode(Server):
    """
    Base class for server nodes in a distributed cluster.
    """

    # TODO factor out security, listening, services, etc. here

    # XXX avoid inheriting from Server? there is some large potential for confusion
    # between base and derived attribute namespaces...

    def versions(self, comm=None, packages=None):
        return get_versions(packages=packages)

    def start_services(self, default_listen_ip):
        if default_listen_ip == "0.0.0.0":
            default_listen_ip = ""  # for IPV6

        for k, v in self.service_specs.items():
            listen_ip = None
            if isinstance(k, tuple):
                k, port = k
            else:
                port = 0

            if isinstance(port, str):
                port = port.split(":")

            if isinstance(port, (tuple, list)):
                if len(port) == 2:
                    listen_ip, port = (port[0], int(port[1]))
                elif len(port) == 1:
                    [listen_ip], port = port, 0
                else:
                    raise ValueError(port)

            if isinstance(v, tuple):
                v, kwargs = v
            else:
                kwargs = {}

            try:
                service = v(self, io_loop=self.loop, **kwargs)
                service.listen(
                    (listen_ip if listen_ip is not None else default_listen_ip, port)
                )
                self.services[k] = service
            except Exception as e:
                warnings.warn(
                    "\nCould not launch service '%s' on port %s. " % (k, port)
                    + "Got the following message:\n\n"
                    + str(e),
                    stacklevel=3,
                )

    def stop_services(self):
        for service in self.services.values():
            service.stop()

    @property
    def service_ports(self):
        return {k: v.port for k, v in self.services.items()}

    def _setup_logging(self, logger):
        self._deque_handler = DequeHandler(
            n=dask.config.get("distributed.admin.log-length")
        )
        self._deque_handler.setFormatter(
            logging.Formatter(dask.config.get("distributed.admin.log-format"))
        )
        logger.addHandler(self._deque_handler)
        weakref.finalize(self, logger.removeHandler, self._deque_handler)

    def get_logs(self, comm=None, n=None):
        deque_handler = self._deque_handler
        if n is None:
            L = list(deque_handler.deque)
        else:
            L = deque_handler.deque
            L = [L[-i] for i in range(min(n, len(L)))]
        return [(msg.levelname, deque_handler.format(msg)) for msg in L]

    def start_http_server(
        self, routes, dashboard_address, default_port=0, ssl_options=None
    ):
        """ This creates an HTTP Server running on this node """

        self.http_application = RoutingApplication(routes)

        # TLS configuration
        tls_key = dask.config.get("distributed.scheduler.dashboard.tls.key")
        tls_cert = dask.config.get("distributed.scheduler.dashboard.tls.cert")
        tls_ca_file = dask.config.get("distributed.scheduler.dashboard.tls.ca-file")
        if tls_cert:
            import ssl

            ssl_options = ssl.create_default_context(
                cafile=tls_ca_file, purpose=ssl.Purpose.SERVER_AUTH
            )
            ssl_options.load_cert_chain(tls_cert, keyfile=tls_key)
            # We don't care about auth here, just encryption
            ssl_options.check_hostname = False
            ssl_options.verify_mode = ssl.CERT_NONE

        self.http_server = HTTPServer(self.http_application, ssl_options=ssl_options)
        http_address = clean_dashboard_address(dashboard_address or default_port)

        if http_address["address"] is None:
            address = self._start_address
            if isinstance(address, (list, tuple)):
                address = address[0]
            if address:
                with suppress(ValueError):
                    http_address["address"] = get_address_host(address)

        change_port = False
        retries_left = 3
        while True:
            try:
                if not change_port:
                    self.http_server.listen(**http_address)
                else:
                    self.http_server.listen(**tlz.merge(http_address, {"port": 0}))
                break
            except Exception:
                change_port = True
                retries_left = retries_left - 1
                if retries_left < 1:
                    raise

        self.http_server.port = get_tcp_server_address(self.http_server)[1]
        self.services["dashboard"] = self.http_server

        if change_port and dashboard_address:
            warnings.warn(
                "Port {} is already in use.\n"
                "Perhaps you already have a cluster running?\n"
                "Hosting the HTTP server on port {} instead".format(
                    http_address["port"], self.http_server.port
                )
            )