Source code distributed/worker_client.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from contextlib import contextmanager
import warnings

import dask
from .threadpoolexecutor import secede, rejoin
from .worker import thread_state, get_client, get_worker
from .utils import parse_timedelta


@contextmanager
def worker_client(timeout=None, separate_thread=True):
    """Get client for this thread

    This context manager is intended to be called within functions that we run
    on workers.  When run as a context manager it delivers a client
    ``Client`` object that can submit other tasks directly from that worker.

    Parameters
    ----------
    timeout: Number or String
        Timeout after which to error out. Defaults to the
        ``distributed.comm.timeouts.connect`` configuration value.
    separate_thread: bool, optional
        Whether to run this function outside of the normal thread pool.
        Defaults to True.

    Notes
    -----
    When ``separate_thread`` is true, the calling thread secedes from the
    worker's thread pool on entry and rejoins it on exit. The rejoin happens
    even if the body of the ``with`` block raises.

    Examples
    --------
    >>> def func(x):
    ...     with worker_client(timeout="10s") as c:  # connect from worker back to scheduler
    ...         a = c.submit(inc, x)     # this task can submit more tasks
    ...         b = c.submit(dec, x)
    ...         result = c.gather([a, b])  # and gather results
    ...     return result

    >>> future = client.submit(func, 1)  # submit func(1) on cluster

    See Also
    --------
    get_worker
    get_client
    secede
    """

    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")

    # Normalize numbers / strings such as "10s" into seconds.
    timeout = parse_timedelta(timeout, "s")

    worker = get_worker()
    client = get_client(timeout=timeout)
    if separate_thread:
        secede()  # have this thread secede from the thread pool

    # Everything after ``secede()`` runs inside ``try`` so that a failure
    # (an error in the task lookup below, or an exception thrown into this
    # generator at ``yield`` by the with-body) still rejoins the pool.
    # Without ``finally`` the code after ``yield`` would be skipped and the
    # thread would never rejoin.
    try:
        if separate_thread:
            # Ask the worker's event loop to transition the task keyed by
            # ``thread_state.key`` (presumably the task this thread is
            # executing) to "long-running".
            worker.loop.add_callback(
                worker.transition, worker.tasks[thread_state.key], "long-running"
            )

        yield client
    finally:
        if separate_thread:
            rejoin()


def local_client(*args, **kwargs):
    """Deprecated alias for ``worker_client``.

    Emits a warning and forwards all positional and keyword arguments
    unchanged to ``worker_client``, returning its context manager.
    """
    # stacklevel=2 attributes the warning to the caller's line, so users can
    # locate the outdated ``local_client`` call instead of seeing this module.
    warnings.warn("local_client has moved to worker_client", stacklevel=2)
    return worker_client(*args, **kwargs)