Source code distributed/protocol/tests/test_pickle.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from functools import partial
import gc
from operator import add
import weakref
import sys

import pytest

from distributed.protocol import deserialize, serialize
from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads

if sys.version_info < (3, 8):
    try:
        import pickle5 as pickle
    except ImportError:
        import pickle
else:
    import pickle


def test_pickle_data():
    """Round-trip simple builtin values through pickle and the serializer.

    Each sample must survive both the raw ``dumps``/``loads`` pair and a
    ``serialize``/``deserialize`` round trip restricted to the ``"pickle"``
    serializer, comparing equal to the original afterwards.
    """
    samples = (1, b"123", "123", [123], {}, set())
    for sample in samples:
        assert loads(dumps(sample)) == sample

        header, frames = serialize(sample, serializers=("pickle",))
        assert deserialize(header, frames) == sample


def test_pickle_out_of_band():
    """Check that buffer-holding objects are pickled out-of-band when possible.

    A small holder class reduces to a ``pickle.PickleBuffer`` under protocol
    5 and to an inline ``bytes`` copy otherwise. The test round-trips it with
    ``dumps``/``loads`` (collecting buffers via ``buffer_callback`` when
    protocol 5 is available) and with the ``"pickle"`` serializer. It then
    checks the restored objects and the shape of the emitted frames.
    """

    class MemoryviewHolder:
        def __init__(self, mv):
            self.mv = memoryview(mv)

        def __reduce_ex__(self, protocol):
            # Older protocols cannot carry a PickleBuffer, so inline a copy.
            if protocol < 5:
                return MemoryviewHolder, (self.mv.tobytes(),)
            # Protocol 5: let the buffer travel out-of-band.
            return MemoryviewHolder, (pickle.PickleBuffer(self.mv),)

    expected = memoryview(b"123")
    holder = MemoryviewHolder(expected)
    has_protocol_5 = HIGHEST_PROTOCOL >= 5

    def check_holder(obj):
        # A restored object must be a holder wrapping an equal memoryview.
        assert isinstance(obj, MemoryviewHolder)
        assert isinstance(obj.mv, memoryview)
        assert obj.mv == expected

    if has_protocol_5:
        buffers = []
        payload = dumps(holder, buffer_callback=buffers.append)
        restored = loads(payload, buffers=buffers)

        assert len(buffers) == 1
        assert isinstance(buffers[0], pickle.PickleBuffer)
        assert memoryview(buffers[0]) == expected
    else:
        restored = loads(dumps(holder))

    check_holder(restored)

    header, frames = serialize(holder, serializers=("pickle",))
    check_holder(deserialize(header, frames))

    # With protocol 5 the buffer appears as its own frame after the pickle
    # payload; otherwise everything is packed into a single bytes frame.
    assert len(frames) == (2 if has_protocol_5 else 1)
    assert isinstance(frames[0], bytes)
    if has_protocol_5:
        assert isinstance(frames[1], memoryview)
        assert frames[1] == expected


def test_pickle_numpy():
    """Round-trip NumPy arrays through pickle and the ``"pickle"`` serializer.

    Covers small and larger float arrays, plus an object-dtype array of
    ragged sub-arrays. Its frame count depends on the available pickle
    protocol. When protocol 5 is available, it also checks that a float
    array's data is emitted as a single out-of-band ``PickleBuffer`` and as
    a separate ``memoryview`` frame.
    """
    np = pytest.importorskip("numpy")

    def pickle_roundtrip(obj):
        return loads(dumps(obj))

    def serialize_roundtrip(obj):
        return deserialize(*serialize(obj, serializers=("pickle",)))

    for size in (5, 5000):
        arr = np.ones(size)
        assert (pickle_roundtrip(arr) == arr).all()
        assert (serialize_roundtrip(arr) == arr).all()

    ragged = np.array([np.arange(3), np.arange(4, 6)], dtype=object)

    def assert_matches_ragged(other):
        # Object arrays cannot be compared elementwise with ``==``, so
        # compare metadata first and then each contained sub-array.
        assert ragged.shape == other.shape
        assert ragged.dtype == other.dtype
        assert ragged.strides == other.strides
        for left, right in zip(ragged.flat, other.flat):
            np.testing.assert_equal(left, right)

    assert_matches_ragged(pickle_roundtrip(ragged))

    header, frames = serialize(ragged, serializers=("pickle",))
    assert len(frames) == (3 if HIGHEST_PROTOCOL >= 5 else 1)
    assert_matches_ragged(deserialize(header, frames))

    if HIGHEST_PROTOCOL < 5:
        return

    arr = np.ones(5000)

    buffers = []
    payload = dumps(arr, buffer_callback=buffers.append)
    assert len(buffers) == 1
    assert isinstance(buffers[0], pickle.PickleBuffer)
    assert memoryview(buffers[0]) == memoryview(arr)
    assert (loads(payload, buffers=buffers) == arr).all()

    header, frames = serialize(arr, serializers=("pickle",))
    assert len(frames) == 2
    assert isinstance(frames[0], bytes)
    assert isinstance(frames[1], memoryview)
    assert (deserialize(header, frames) == arr).all()


@pytest.mark.xfail(
    sys.version_info[:2] == (3, 8),
    reason="Sporadic failure on Python 3.8",
    strict=False,
)
def test_pickle_functions():
    """Round-trip several kinds of callables and check that none are leaked.

    Each callable (a closure, a lambda and a ``functools.partial``) is
    round-tripped with ``dumps``/``loads`` and with the ``"pickle"``
    serializer. Both copies are called and compared against the original.
    Weak references then confirm that the original and both copies can be
    collected once this frame drops its own references. Marked as a
    non-strict xfail on Python 3.8 because of sporadic failures there.
    """

    def make_closure():
        value = 1

        def f(x):  # closure
            return x + value

        return f

    def funcs():
        # Yield lazily instead of building a list: a list would keep every
        # callable alive and defeat the weakref checks in the loop below.
        yield make_closure()
        yield (lambda x: x + 1)
        yield partial(add, 1)

    for func in funcs():
        wr = weakref.ref(func)

        func2 = loads(dumps(func))
        wr2 = weakref.ref(func2)
        assert func2(1) == func(1)

        func3 = deserialize(*serialize(func, serializers=("pickle",)))
        wr3 = weakref.ref(func3)
        assert func3(1) == func(1)

        # Drop every strong reference this frame holds and force a
        # collection; a surviving referent means something else (e.g. the
        # serialization path) is still holding on to it.
        del func, func2, func3
        gc.collect()
        assert wr() is None
        assert wr2() is None
        assert wr3() is None