1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154 | from functools import partial
import gc
from operator import add
import weakref
import sys
import pytest
from distributed.protocol import deserialize, serialize
from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads
if sys.version_info < (3, 8):
try:
import pickle5 as pickle
except ImportError:
import pickle
else:
import pickle
def test_pickle_data():
    """Plain builtin values survive both pickle round-trip paths unchanged.

    Each sample goes through ``dumps``/``loads`` directly and through
    ``serialize``/``deserialize`` restricted to the ``"pickle"`` serializer.
    """
    samples = (1, b"123", "123", [123], {}, set())
    for obj in samples:
        # Direct pickle round trip.
        assert loads(dumps(obj)) == obj

        # Round trip via the serializer registry, pickle only.
        header, frames = serialize(obj, serializers=("pickle",))
        assert deserialize(header, frames) == obj
def test_pickle_out_of_band():
    """Objects exposing ``PickleBuffer`` data are pickled out-of-band.

    ``MemoryviewHolder`` reduces to a ``pickle.PickleBuffer`` when the pickle
    protocol is 5 or higher and to a ``bytes`` copy otherwise.  When
    ``HIGHEST_PROTOCOL >= 5`` the buffer must be handed to
    ``buffer_callback`` by ``dumps`` and must show up as a separate
    ``memoryview`` frame from ``serialize``; otherwise ``serialize`` yields a
    single ``bytes`` frame.  In every case the reconstructed holder must wrap
    a ``memoryview`` equal to the original data.
    """

    class MemoryviewHolder:
        def __init__(self, mv):
            self.mv = memoryview(mv)

        def __reduce_ex__(self, protocol):
            if protocol >= 5:
                # Protocol 5 can carry the buffer out-of-band via PickleBuffer.
                return MemoryviewHolder, (pickle.PickleBuffer(self.mv),)
            else:
                # Older protocols have no PickleBuffer support: copy in-band.
                return MemoryviewHolder, (self.mv.tobytes(),)

    mv = memoryview(b"123")
    mvh = MemoryviewHolder(mv)

    # --- Direct dumps/loads round trip ---
    if HIGHEST_PROTOCOL >= 5:
        buffers = []
        payload = dumps(mvh, buffer_callback=buffers.append)
        mvh2 = loads(payload, buffers=buffers)

        # Exactly the one memoryview should have been diverted out-of-band.
        assert len(buffers) == 1
        assert isinstance(buffers[0], pickle.PickleBuffer)
        assert memoryview(buffers[0]) == mv
    else:
        mvh2 = loads(dumps(mvh))

    assert isinstance(mvh2, MemoryviewHolder)
    assert isinstance(mvh2.mv, memoryview)
    assert mvh2.mv == mv

    # --- Round trip through the "pickle" serializer family ---
    h, f = serialize(mvh, serializers=("pickle",))
    mvh3 = deserialize(h, f)

    assert isinstance(mvh3, MemoryviewHolder)
    assert isinstance(mvh3.mv, memoryview)
    assert mvh3.mv == mv

    if HIGHEST_PROTOCOL >= 5:
        # Expected: a bytes frame plus the buffer as its own memoryview frame.
        assert len(f) == 2
        assert isinstance(f[0], bytes)
        assert isinstance(f[1], memoryview)
        assert f[1] == mv
    else:
        # Without protocol 5 everything is packed into a single bytes frame.
        assert len(f) == 1
        assert isinstance(f[0], bytes)
def test_pickle_numpy():
    """NumPy arrays round-trip through ``dumps``/``loads`` and ``serialize``.

    Covers:

    * float arrays of two sizes (5 and 5000 elements);
    * a 1-d object-dtype array whose elements are themselves arrays, including
      the number of frames ``serialize`` produces for it;
    * when ``HIGHEST_PROTOCOL >= 5``, out-of-band transfer of a float array's
      data through ``buffer_callback`` and as a separate ``serialize`` frame.

    Skipped when NumPy is not installed.
    """
    np = pytest.importorskip("numpy")

    def assert_same_object_array(expected, actual):
        # The elements are arrays themselves, so compare the container's
        # metadata first and then each element individually.
        assert expected.shape == actual.shape
        assert expected.dtype == actual.dtype
        assert expected.strides == actual.strides
        for e_expected, e_actual in zip(expected.flat, actual.flat):
            np.testing.assert_equal(e_expected, e_actual)

    # --- Plain float arrays, small and large ---
    for size in (5, 5000):
        x = np.ones(size)
        assert (loads(dumps(x)) == x).all()
        assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all()

    # --- Object-dtype array holding arrays of different lengths ---
    x = np.array([np.arange(3), np.arange(4, 6)], dtype=object)
    assert_same_object_array(x, loads(dumps(x)))

    h, f = serialize(x, serializers=("pickle",))
    if HIGHEST_PROTOCOL >= 5:
        # Presumably the pickle stream plus one out-of-band frame per element
        # array -- TODO confirm against the serializer implementation.
        assert len(f) == 3
    else:
        assert len(f) == 1
    assert_same_object_array(x, deserialize(h, f))

    # --- Out-of-band buffers (pickle protocol 5 only) ---
    if HIGHEST_PROTOCOL >= 5:
        x = np.ones(5000)

        buffers = []
        payload = dumps(x, buffer_callback=buffers.append)
        assert len(buffers) == 1
        assert isinstance(buffers[0], pickle.PickleBuffer)
        assert memoryview(buffers[0]) == memoryview(x)
        assert (loads(payload, buffers=buffers) == x).all()

        h, f = serialize(x, serializers=("pickle",))
        assert len(f) == 2
        assert isinstance(f[0], bytes)
        assert isinstance(f[1], memoryview)
        assert (deserialize(h, f) == x).all()
@pytest.mark.xfail(
    sys.version_info[:2] == (3, 8),
    reason="Sporadic failure on Python 3.8",
    strict=False,
)
def test_pickle_functions():
    """Pickled callables work after a round trip and leak no references.

    A closure, a lambda and a ``functools.partial`` are each round-tripped
    through ``dumps``/``loads`` and through ``serialize``/``deserialize``
    with the ``"pickle"`` serializer.  Both copies must return the same
    result as the original.  Once every strong name for the original and the
    copies is deleted and a collection is forced, all three weak references
    must be dead.
    """

    def make_closure():
        offset = 1

        def add_offset(x):  # closure over ``offset``
            return x + offset

        return add_offset

    def candidates():
        # Yield one callable at a time instead of building a list, so no
        # container keeps the callables alive after the ``del`` below.
        yield make_closure()
        yield lambda x: x + 1
        yield partial(add, 1)

    for original in candidates():
        expected = original(1)
        ref_original = weakref.ref(original)

        via_pickle = loads(dumps(original))
        ref_pickle = weakref.ref(via_pickle)
        assert via_pickle(1) == expected

        via_serialize = deserialize(*serialize(original, serializers=("pickle",)))
        ref_serialize = weakref.ref(via_serialize)
        assert via_serialize(1) == expected

        del original, via_pickle, via_serialize
        gc.collect()
        for ref in (ref_original, ref_pickle, ref_serialize):
            assert ref() is None
|