# Source code: distributed/protocol/tests/test_netcdf4.py

import pytest

netCDF4 = pytest.importorskip("netCDF4")
np = pytest.importorskip("numpy")

from distributed.protocol import deserialize, serialize

from distributed.utils import tmpfile


def create_test_dataset(fn):
    """Write a small netCDF4 fixture file to ``fn``.

    Layout written:

    * dimension ``x`` of size 3
    * ``x`` -- int32, values ``[0, 1, 2]``
    * ``group/y`` -- int32, values ``[1, 2, 3]``
    * ``group/group1/z`` -- int32, values ``[2, 3, 4]``

    All three variables are defined over dimension ``x``.

    Parameters
    ----------
    fn
        Path of the file to create; it is opened with ``mode="w"``.
    """
    with netCDF4.Dataset(fn, mode="w") as ds:
        ds.createDimension("x", 3)
        v = ds.createVariable("x", np.int32, ("x",))
        v[:] = np.arange(3)

        # The returned group handles are not needed: the nested variables
        # below are created by path on the root dataset.
        ds.createGroup("group")
        ds.createGroup("group/group1")

        v2 = ds.createVariable("group/y", np.int32, ("x",))
        v2[:] = np.arange(3) + 1

        v3 = ds.createVariable("group/group1/z", np.int32, ("x",))
        v3[:] = np.arange(3) + 2


def test_serialize_deserialize_dataset():
    """Round-trip an open ``netCDF4.Dataset`` through serialize/deserialize.

    The restored object must report the same file path, be a
    ``netCDF4.Dataset``, and expose the root variable ``x`` with the
    dimensions, dtype and values written by ``create_test_dataset``.
    """
    with tmpfile() as path:
        create_test_dataset(path)
        with netCDF4.Dataset(path, mode="r") as original:
            payload = serialize(original)
            restored = deserialize(*payload)

            assert original.filepath() == restored.filepath()
            assert isinstance(restored, netCDF4.Dataset)

            # Look the variable up once and check it from every angle.
            x_var = restored.variables["x"]
            assert x_var.dimensions == ("x",)
            assert x_var.dtype == np.int32
            assert (x_var[:] == np.arange(3)).all()


def test_serialize_deserialize_variable():
    """Round-trip the root variable ``x`` through serialize/deserialize.

    The restored object must be a ``netCDF4.Variable`` whose dimensions,
    dtype and data match the variable read from the open dataset.
    """
    with tmpfile() as path:
        create_test_dataset(path)
        with netCDF4.Dataset(path, mode="r") as dataset:
            original = dataset.variables["x"]
            payload = serialize(original)
            restored = deserialize(*payload)

            assert isinstance(restored, netCDF4.Variable)
            assert restored.dimensions == ("x",)
            assert original.dtype == restored.dtype
            assert (original[:] == restored[:]).all()


def test_serialize_deserialize_group():
    """Round-trip groups and group-nested variables through serialization.

    First checks that ``group`` and ``group/group1`` come back as
    ``netCDF4.Group`` objects with the same name, child groups and
    variables. Then checks that variables at every nesting level (root,
    ``group``, ``group/group1``) come back as ``netCDF4.Variable`` objects
    with matching dimensions, dtype and data.
    """
    with tmpfile() as fn:
        create_test_dataset(fn)
        with netCDF4.Dataset(fn, mode="r") as f:
            for path in ["group", "group/group1"]:
                g = f[path]
                h = deserialize(*serialize(g))
                assert isinstance(h, netCDF4.Group)
                assert h.name == g.name
                assert list(g.groups) == list(h.groups)
                assert list(g.variables) == list(h.variables)

            # One variable per nesting depth. Named ``variables`` rather
            # than ``vars`` so the builtin ``vars()`` is not shadowed.
            variables = [
                f.variables["x"],
                f["group"].variables["y"],
                f["group/group1"].variables["z"],
            ]

            for x in variables:
                y = deserialize(*serialize(x))
                assert isinstance(y, netCDF4.Variable)
                assert y.dimensions == ("x",)
                assert x.dtype == y.dtype
                assert (x[:] == y[:]).all()


from distributed.utils_test import gen_cluster


import dask.array as da


@gen_cluster(client=True)
async def test_netcdf4_serialize(c, s, a, b):
    """Compute a dask array backed by a netCDF4 variable via ``c``.

    Wraps the root variable ``x`` in a dask array with chunks of size 2,
    computes it through ``c`` and compares the result with the data read
    directly from the variable.
    """
    with tmpfile() as path:
        create_test_dataset(path)
        with netCDF4.Dataset(path, mode="r") as dataset:
            source = dataset.variables["x"]
            lazy = da.from_array(source, chunks=2)
            result = await c.compute(lazy)
            assert (result[:] == source[:]).all()