# Source: distributed/protocol/tests/test_scipy.py

import pytest
from distributed.protocol import deserialize, serialize

numpy = pytest.importorskip("numpy")
scipy = pytest.importorskip("scipy")
scipy_sparse = pytest.importorskip("scipy.sparse")


@pytest.mark.parametrize(
    "sparse_type",
    [
        scipy_sparse.bsr_matrix,
        scipy_sparse.coo_matrix,
        scipy_sparse.csc_matrix,
        scipy_sparse.csr_matrix,
        scipy_sparse.dia_matrix,
        scipy_sparse.dok_matrix,
        scipy_sparse.lil_matrix,
    ],
)
@pytest.mark.parametrize(
    "dtype",
    # Both byte orders of 4- and 8-byte floats.
    [numpy.dtype(code) for code in ("<f4", ">f4", "<f8", ">f8")],
)
def test_serialize_scipy_sparse(sparse_type, dtype):
    """Round-trip a sparse matrix through the "dask" serializer.

    A small dense array is converted to ``sparse_type`` (via a COO matrix
    built from its non-zero entries), serialized, deserialized, and the
    densified result is compared element-wise against the starting array.
    """
    dense = numpy.array([[0, 1, 0], [2, 0, 3], [0, 4, 0]], dtype=dtype)

    # Build the COO form from the (row, col) indices of the non-zero entries
    # and their values, then convert to the sparse format under test.
    nonzero_idx = dense.nonzero()
    values = dense[nonzero_idx]
    original = sparse_type(scipy_sparse.coo_matrix((values, nonzero_idx)))

    # serialize() returns (header, frames), which deserialize() takes in
    # that same order.
    restored = deserialize(*serialize(original, serializers=["dask"]))

    restored_dense = restored.todense()
    assert (dense == restored_dense).all()