"""
Record known compressors
Includes utilities for determining whether or not to compress
"""
from contextlib import suppress
from functools import partial
import logging
import random
import dask
from tlz import identity
try:
import blosc
n = blosc.set_nthreads(2)
if hasattr("blosc", "releasegil"):
blosc.set_releasegil(True)
except ImportError:
blosc = False
from ..utils import ensure_bytes
compressions = {None: {"compress": identity, "decompress": identity}}
compressions[False] = compressions[None] # alias
default_compression = None
logger = logging.getLogger(__name__)
with suppress(ImportError):
import zlib
compressions["zlib"] = {"compress": zlib.compress, "decompress": zlib.decompress}
with suppress(ImportError):
import snappy
def _fixed_snappy_decompress(data):
# snappy.decompress() doesn't accept memoryviews
if isinstance(data, (memoryview, bytearray)):
data = bytes(data)
return snappy.decompress(data)
compressions["snappy"] = {
"compress": snappy.compress,
"decompress": _fixed_snappy_decompress,
}
default_compression = "snappy"
with suppress(ImportError):
import lz4
try:
# try using the new lz4 API
import lz4.block
lz4_compress = lz4.block.compress
lz4_decompress = lz4.block.decompress
except ImportError:
# fall back to old one
lz4_compress = lz4.LZ4_compress
lz4_decompress = lz4.LZ4_uncompress
# helper to bypass missing memoryview support in current lz4
# (fixed in later versions)
def _fixed_lz4_compress(data):
try:
return lz4_compress(data)
except TypeError:
if isinstance(data, (memoryview, bytearray)):
return lz4_compress(bytes(data))
else:
raise
def _fixed_lz4_decompress(data):
try:
return lz4_decompress(data)
except (ValueError, TypeError):
if isinstance(data, (memoryview, bytearray)):
return lz4_decompress(bytes(data))
else:
raise
compressions["lz4"] = {
"compress": _fixed_lz4_compress,
"decompress": _fixed_lz4_decompress,
}
default_compression = "lz4"
with suppress(ImportError):
import zstandard
zstd_compressor = zstandard.ZstdCompressor(
level=dask.config.get("distributed.comm.zstd.level"),
threads=dask.config.get("distributed.comm.zstd.threads"),
)
zstd_decompressor = zstandard.ZstdDecompressor()
def zstd_compress(data):
return zstd_compressor.compress(data)
def zstd_decompress(data):
return zstd_decompressor.decompress(data)
compressions["zstd"] = {"compress": zstd_compress, "decompress": zstd_decompress}
with suppress(ImportError):
import blosc
compressions["blosc"] = {
"compress": partial(blosc.compress, clevel=5, cname="lz4"),
"decompress": blosc.decompress,
}
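
# Extension sketch: ``compressions`` is a plain dict mapping a codec name to
# matching compress/decompress callables, so downstream code could register
# another codec. ``my_compress`` and ``my_decompress`` here are hypothetical
# placeholders, not part of this module:
#
#     compressions["mycodec"] = {
#         "compress": my_compress,      # bytes -> bytes
#         "decompress": my_decompress,  # bytes -> bytes
#     }
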
def get_default_compression():
default = dask.config.get("distributed.comm.compression")
if default != "auto":
if default in compressions:
return default
else:
raise ValueError(
"Default compression '%s' not found.\n"
"Choices include auto, %s"
% (default, ", ".join(sorted(map(str, compressions))))
)
else:
return default_compression
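
# Usage sketch: "zlib" is always registered (stdlib), so this is safe to try;
# an unknown name would raise ValueError instead:
#
#     >>> import dask
#     >>> with dask.config.set({"distributed.comm.compression": "zlib"}):
#     ...     get_default_compression()
#     'zlib'

# Validate the configured default once at import time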
get_default_compression()
def byte_sample(b, size, n):
"""Sample a bytestring from many locations
Parameters
----------
b: bytes or memoryview
size: int
size of each sample to collect
n: int
number of samples to collect
"""
starts = [random.randint(0, len(b) - size) for j in range(n)]
ends = []
for i, start in enumerate(starts[:-1]):
ends.append(min(start + size, starts[i + 1]))
ends.append(starts[-1] + size)
parts = [b[start:end] for start, end in zip(starts, ends)]
return b"".join(map(ensure_bytes, parts))
def maybe_compress(
payload,
min_size=1e4,
sample_size=1e4,
nsamples=5,
compression=dask.config.get("distributed.comm.compression"),
):
"""
Maybe compress payload
1. We don't compress small messages
2. We sample the payload in a few spots, compress that, and if it doesn't
do any good we return the original
    3. We then compress the full original, if it doesn't compress well then
       we return the original
4. We return the compressed result
"""
if compression == "auto":
compression = default_compression
if not compression:
return None, payload
if len(payload) < min_size:
return None, payload
if len(payload) > 2 ** 31: # Too large, compression libraries often fail
return None, payload
min_size = int(min_size)
sample_size = int(sample_size)
compress = compressions[compression]["compress"]
# Compress a sample, return original if not very compressed
sample = byte_sample(payload, sample_size, nsamples)
if len(compress(sample)) > 0.9 * len(sample): # sample not very compressible
return None, payload
if type(payload) is memoryview:
nbytes = payload.itemsize * len(payload)
else:
nbytes = len(payload)
if default_compression and blosc and type(payload) is memoryview:
# Blosc does itemsize-aware shuffling, resulting in better compression
compressed = blosc.compress(
payload, typesize=payload.itemsize, cname="lz4", clevel=5
)
compression = "blosc"
else:
compressed = compress(ensure_bytes(payload))
if len(compressed) > 0.9 * nbytes: # full data not very compressible
return None, payload
else:
return compression, compressed
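
# Round-trip sketch, assuming lz4 or snappy is importable so that a default
# compressor exists (with no compressor, maybe_compress returns
# ``(None, payload)`` and ``compressions[None]["decompress"]`` is the
# identity, so the check below still holds):
#
#     >>> payload = b"0" * 100000  # large and highly compressible
#     >>> compression, compressed = maybe_compress(payload)
#     >>> compressions[compression]["decompress"](compressed) == payload
#     True
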
def decompress(header, frames):
""" Decompress frames according to information in the header """
return [
compressions[c]["decompress"](frame)
for c, frame in zip(header["compression"], frames)
]
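
# Sketch pairing maybe_compress with decompress; the header layout assumed
# here (one compression name per frame under the "compression" key) is what
# this function expects:
#
#     >>> c, frame = maybe_compress(b"x" * 50000)
#     >>> decompress({"compression": [c]}, [frame]) == [b"x" * 50000]
#     True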