1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127 | import logging
import math
import socket
import dask
from dask.sizeof import sizeof
from dask.utils import parse_bytes
from .. import protocol
from ..utils import get_ip, get_ipv6, nbytes, offload
logger = logging.getLogger(__name__)

# Size threshold (in bytes) above which to_frames / from_frames run the
# (de)serialization through ``offload`` instead of inline, so that large
# messages do not block the event loop.  A falsy value disables offloading.
# NOTE(review): how many threads ``offload`` uses is defined in ..utils
# (previously documented here as "at most 4") -- confirm there.
_offload_setting = dask.config.get("distributed.comm.offload")
# The config value may be a human-readable size string; normalize it to bytes.
FRAME_OFFLOAD_THRESHOLD = (
    parse_bytes(_offload_setting)
    if isinstance(_offload_setting, str)
    else _offload_setting
)
del _offload_setting
async def to_frames(
    msg, serializers=None, on_error="message", context=None, allow_offload=True
):
    """
    Serialize a message into a list of Distributed protocol frames.

    Parameters
    ----------
    msg : object
        The message to serialize; passed to ``protocol.dumps``.
    serializers, on_error, context
        Forwarded unchanged to ``protocol.dumps``.
    allow_offload : bool
        When true and ``FRAME_OFFLOAD_THRESHOLD`` is set, messages whose
        estimated size (``dask.sizeof``) exceeds the threshold are serialized
        through ``offload`` rather than inline on the calling coroutine.

    Returns
    -------
    list
        The frames produced by ``protocol.dumps``.

    Raises
    ------
    Exception
        Any error from ``protocol.dumps`` is logged (with the message and
        traceback) and re-raised unchanged.
    """

    def _serialize():
        try:
            frames = protocol.dumps(
                msg, serializers=serializers, on_error=on_error, context=context
            )
            return list(frames)
        except Exception as exc:
            logger.info("Unserializable Message: %s", msg)
            logger.exception(exc)
            raise

    # Fast path: offloading disabled, so skip the size estimate entirely.
    if not (FRAME_OFFLOAD_THRESHOLD and allow_offload):
        return _serialize()

    try:
        estimated_size = sizeof(msg)
    except RecursionError:
        # Deeply nested messages defeat sizeof; treat them as "large".
        estimated_size = math.inf

    if estimated_size > FRAME_OFFLOAD_THRESHOLD:
        return await offload(_serialize)
    return _serialize()
async def from_frames(frames, deserialize=True, deserializers=None, allow_offload=True):
    """
    Unserialize a list of Distributed protocol frames.

    Parameters
    ----------
    frames : list
        The frames to decode; passed to ``protocol.loads``.
    deserialize, deserializers
        Forwarded unchanged to ``protocol.loads``.
    allow_offload : bool
        When true, ``deserialize`` is true and ``FRAME_OFFLOAD_THRESHOLD`` is
        set, payloads whose total byte size exceeds the threshold are decoded
        through ``offload`` rather than inline on the calling coroutine.

    Returns
    -------
    object
        Whatever ``protocol.loads`` returns.

    Raises
    ------
    EOFError
        Re-raised from ``protocol.loads`` after logging a diagnostic that
        includes the payload size (and the frames themselves if small).
    """
    offload_enabled = allow_offload and deserialize and FRAME_OFFLOAD_THRESHOLD
    # Total payload size in bytes. Only needed up front when it decides whether
    # to offload; otherwise it is computed lazily in the (rare) error path.
    size = sum(map(nbytes, frames)) if offload_enabled else None

    def _from_frames():
        try:
            return protocol.loads(
                frames, deserialize=deserialize, deserializers=deserializers
            )
        except EOFError:
            # Always report the real size, so the log never claims "0 bytes"
            # and never dumps a huge payload when offloading was disabled.
            total = size if size is not None else sum(map(nbytes, frames))
            if total > 1000:
                datastr = "[too large to display]"
            else:
                datastr = frames
            # Aid diagnosing
            logger.error("truncated data stream (%d bytes): %s", total, datastr)
            raise

    if offload_enabled and size > FRAME_OFFLOAD_THRESHOLD:
        return await offload(_from_frames)
    return _from_frames()
def get_tcp_server_address(tcp_server):
    """
    Get the bound address of a started Tornado TCPServer.

    Reads the server's (private) ``_sockets`` mapping and returns
    ``getsockname()`` of the first IPv4 socket found, falling back to the
    first IPv6 socket. With the stdlib ``socket`` module this is a
    ``(host, port)`` tuple for IPv4 and ``(host, port, flowinfo, scope_id)``
    for IPv6.

    Raises
    ------
    RuntimeError
        If the server has no sockets yet, or none of them is IPv4/IPv6.
    """
    bound = list(tcp_server._sockets.values())
    if not bound:
        raise RuntimeError("TCP Server %r not started yet?" % (tcp_server,))

    # If listening on both IPv4 and IPv6, prefer IPv4 as defective IPv6
    # is common (e.g. Travis-CI).
    for family in (socket.AF_INET, socket.AF_INET6):
        match = next((s for s in bound if s.family == family), None)
        if match is not None:
            return match.getsockname()

    raise RuntimeError("No Internet socket found on TCPServer??")
def ensure_concrete_host(host):
    """
    Ensure the given host string (or IP) denotes a concrete host, not a
    wildcard listening address.

    ``"0.0.0.0"`` and ``""`` are replaced by ``get_ip()``; ``"::"`` is replaced
    by ``get_ipv6()`` (presumably this machine's IPv4 / IPv6 address -- see
    ..utils). Any other value is returned unchanged.
    """
    # IPv4 "any" address (or empty host, which also binds to all interfaces).
    if host in ("0.0.0.0", ""):
        return get_ip()
    # IPv6 "any" address.
    if host == "::":
        return get_ipv6()
    return host
|