Source code: distributed/protocol/cupy.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Efficient serialization of GPU arrays.
"""
import copyreg

import cupy

from .cuda import cuda_deserialize, cuda_serialize
from .serialize import dask_deserialize, dask_serialize, register_generic

try:
    from .rmm import dask_deserialize_rmm_device_buffer as dask_deserialize_cuda_buffer
except ImportError:
    from .numba import dask_deserialize_numba_array as dask_deserialize_cuda_buffer


@cuda_serialize.register(cupy.ndarray)
def cuda_serialize_cupy_ndarray(x):
    """Serialize a CuPy array into a header plus a single device frame.

    Parameters
    ----------
    x : cupy.ndarray
        Array to serialize. If it is neither C- nor Fortran-contiguous it is
        copied first (assumes ``cupy.array(..., copy=True)`` yields a compact
        layout whose bytes form one span -- confirm for exotic layouts).

    Returns
    -------
    header : dict
        A copy of the (possibly copied) array's ``__cuda_array_interface__``
        with ``"strides"`` overwritten by the array's byte strides and
        ``"lengths"`` set to a one-element list holding its ``nbytes``.
    frames : list
        One flat ``uint8`` CuPy array of length ``nbytes`` built on the
        array's own device pointer (``x.data``).
    """
    contiguous = x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]
    if not contiguous:
        # Re-materialize strided views so the exported bytes cover the array.
        x = cupy.array(x, copy=True)

    header = x.__cuda_array_interface__.copy()
    # Record explicit strides (rather than whatever the interface reports)
    # so the receiver can rebuild the exact memory layout.
    header["strides"] = tuple(x.strides)
    size_in_bytes = x.nbytes
    header["lengths"] = [size_in_bytes]

    # Reinterpret the array's memory as raw bytes, without a device copy.
    raw_bytes = cupy.ndarray(
        shape=(size_in_bytes,),
        dtype=cupy.dtype("u1"),
        memptr=x.data,
        strides=(1,),
    )
    return header, [raw_bytes]


@cuda_deserialize.register(cupy.ndarray)
def cuda_deserialize_cupy_ndarray(header, frames):
    """Rebuild a CuPy array from a header and exactly one device frame.

    Parameters
    ----------
    header : dict
        Must provide ``"shape"``, ``"typestr"`` (used as the dtype) and
        ``"strides"``, as produced by ``cuda_serialize_cupy_ndarray``.
    frames : sequence
        A sequence containing a single device buffer accepted by
        ``cupy.asarray``. Unpacking raises ``ValueError`` otherwise.

    Returns
    -------
    cupy.ndarray
        An array laid out per ``header`` over the frame's device pointer.
    """
    [device_frame] = frames
    shape = header["shape"]
    dtype = header["typestr"]
    # Take the device pointer of the incoming buffer. The new array is
    # constructed on that pointer rather than on a separate allocation.
    pointer = cupy.asarray(device_frame).data
    strides = header["strides"]
    return cupy.ndarray(shape=shape, dtype=dtype, memptr=pointer, strides=strides)


@dask_serialize.register(cupy.ndarray)
def dask_serialize_cupy_ndarray(x):
    """Serialize a CuPy array into host-memory frames for the "dask" family.

    Reuses ``cuda_serialize_cupy_ndarray`` for the header and device frame.
    It then copies each device frame to host via ``cupy.asnumpy`` and wraps
    the result in a ``memoryview``.

    Returns
    -------
    header : dict
        The CUDA header, extended with ``"writeable"``: a tuple holding one
        ``None`` per frame.
    frames : list of memoryview
        Host copies of the device frames.
    """
    header, device_frames = cuda_serialize_cupy_ndarray(x)
    header["writeable"] = (None,) * len(device_frames)

    host_frames = []
    for device_frame in device_frames:
        host_frames.append(memoryview(cupy.asnumpy(device_frame)))
    return header, host_frames


@dask_deserialize.register(cupy.ndarray)
def dask_deserialize_cupy_ndarray(header, frames):
    """Rebuild a CuPy array from frames produced by the "dask" serializer.

    ``dask_deserialize_cuda_buffer`` is the RMM DeviceBuffer deserializer
    when ``.rmm`` is importable, else the Numba array deserializer. Its
    result is handed, as the sole frame, to
    ``cuda_deserialize_cupy_ndarray`` together with the unchanged header.
    (Presumably that helper moves the host frames onto the device --
    confirm in ``.rmm`` / ``.numba``.)
    """
    device_buffer = dask_deserialize_cuda_buffer(header, frames)
    return cuda_deserialize_cupy_ndarray(header, [device_buffer])


try:
    from cupy.cusparse import MatDescriptor
    from cupyx.scipy.sparse import spmatrix
except ImportError:
    MatDescriptor = None
    spmatrix = None


if MatDescriptor is not None:

    def reduce_matdescriptor(other):
        """Reduce a ``MatDescriptor`` for pickling via ``copyreg``.

        Pickling MatDescriptor errors
        xref: https://github.com/cupy/cupy/issues/3061

        The object is rebuilt by calling ``MatDescriptor.create()`` with no
        arguments, so no state from ``other`` is carried across.
        """
        # Use the class bound by the import above instead of re-resolving it
        # as ``cupy.cusparse.MatDescriptor``. That attribute path only exists
        # as a side effect of the submodule import and ties this function to
        # one specific import location.
        return MatDescriptor.create, ()

    copyreg.pickle(MatDescriptor, reduce_matdescriptor)

    @cuda_serialize.register(MatDescriptor)
    @dask_serialize.register(MatDescriptor)
    def serialize_cupy_matdescriptor(x):
        """Serialize a ``MatDescriptor`` as an empty header with no frames."""
        header, frames = {}, []
        return header, frames

    @cuda_deserialize.register(MatDescriptor)
    @dask_deserialize.register(MatDescriptor)
    def deserialize_cupy_matdescriptor(header, frames):
        """Return a fresh ``MatDescriptor.create()``; inputs are ignored."""
        return MatDescriptor.create()


if spmatrix is not None:
    # Register CuPy sparse matrices with ``register_generic`` for both the
    # "cuda" and "dask" serializer families. The calls are written out
    # explicitly so no loop variables leak into the module namespace.
    register_generic(spmatrix, "cuda", cuda_serialize, cuda_deserialize)
    register_generic(spmatrix, "dask", dask_serialize, dask_deserialize)