Source code distributed/protocol/numba.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import weakref

import numba.cuda
import numpy as np

from .cuda import cuda_deserialize, cuda_serialize
from .serialize import dask_deserialize, dask_serialize

try:
    from .rmm import dask_deserialize_rmm_device_buffer
except ImportError:
    dask_deserialize_rmm_device_buffer = None


@cuda_serialize.register(numba.cuda.devicearray.DeviceNDArray)
def cuda_serialize_numba_ndarray(x):
    """Serialize a numba device array into a header and one device frame.

    If ``x`` is neither C- nor F-contiguous it is first copied into a new
    array created with ``numba.cuda.device_array`` (same shape and dtype),
    and that copy is what gets serialized.

    Returns
    -------
    header : dict
        A copy of ``x.__cuda_array_interface__`` with ``"strides"`` set to
        ``tuple(x.strides)`` and ``"lengths"`` set to ``[x.nbytes]``.
    frames : list
        A single ``DeviceNDArray`` of dtype ``u1`` and shape ``(x.nbytes,)``
        built on top of ``x.gpu_data``.
    """
    # Only contiguous layouts are exposed as a flat byte frame below.
    if not x.flags["C_CONTIGUOUS"] and not x.flags["F_CONTIGUOUS"]:
        compact = numba.cuda.device_array(x.shape, dtype=x.dtype)
        compact.copy_to_device(x)
        x = compact

    header = x.__cuda_array_interface__.copy()
    header["strides"] = tuple(x.strides)
    header["lengths"] = [x.nbytes]

    # Reinterpret the array's device memory as a 1-D array of raw bytes.
    byte_view = numba.cuda.cudadrv.devicearray.DeviceNDArray(
        shape=(x.nbytes,),
        strides=(1,),
        dtype=np.dtype("u1"),
        gpu_data=x.gpu_data,
    )

    return header, [byte_view]


@cuda_deserialize.register(numba.cuda.devicearray.DeviceNDArray)
def cuda_deserialize_numba_ndarray(header, frames):
    """Rebuild a ``DeviceNDArray`` from a header and exactly one frame.

    Parameters
    ----------
    header : mapping
        Must provide ``"shape"``, ``"strides"`` and ``"typestr"``.
    frames : iterable
        Must contain exactly one element; unpacking raises ``ValueError``
        otherwise. The element is passed to ``numba.cuda.as_cuda_array``.

    Returns
    -------
    numba.cuda.devicearray.DeviceNDArray
        An array with the header's shape, strides and dtype, built on the
        ``gpu_data`` of the frame.
    """
    (payload,) = frames

    # Argument expressions are evaluated in the same order as the original
    # assignments: shape, strides, dtype, then the device memory lookup.
    return numba.cuda.devicearray.DeviceNDArray(
        shape=header["shape"],
        strides=header["strides"],
        dtype=np.dtype(header["typestr"]),
        gpu_data=numba.cuda.as_cuda_array(payload).gpu_data,
    )


@dask_serialize.register(numba.cuda.devicearray.DeviceNDArray)
def dask_serialize_numba_ndarray(x):
    """Serialize a numba device array into a header and host-memory frames.

    Delegates to ``cuda_serialize_numba_ndarray`` and then copies every
    device frame to the host with ``copy_to_host``, wrapping each result in
    a ``memoryview``. The header additionally gets a ``"writeable"`` entry:
    a tuple holding one ``None`` per frame.

    Returns
    -------
    (header, frames) : tuple
        ``header`` is the dict from the CUDA serializer plus ``"writeable"``;
        ``frames`` is a list of ``memoryview`` objects.
    """
    header, device_frames = cuda_serialize_numba_ndarray(x)
    header["writeable"] = (None,) * len(device_frames)

    host_frames = []
    for device_frame in device_frames:
        host_frames.append(memoryview(device_frame.copy_to_host()))

    return header, host_frames


@dask_deserialize.register(numba.cuda.devicearray.DeviceNDArray)
def dask_deserialize_numba_array(header, frames):
    """Rebuild a ``DeviceNDArray`` from a header and host-memory frames.

    When ``dask_deserialize_rmm_device_buffer`` is available (truthy), it is
    called with ``header`` and ``frames`` and its result becomes the single
    device frame. Otherwise each frame is uploaded with
    ``numba.cuda.to_device`` and a ``weakref.finalize`` callback invoking
    ``numba.cuda.current_context`` is registered on every uploaded array.

    The resulting device frames are handed to
    ``cuda_deserialize_numba_ndarray`` together with ``header``.
    """
    if not dask_deserialize_rmm_device_buffer:
        # First pass: upload every host frame to the device.
        device_frames = []
        for host_frame in frames:
            host_array = np.asarray(memoryview(host_frame))
            device_frames.append(numba.cuda.to_device(host_array))

        # Second pass: register the finalizers only after all uploads succeeded.
        # NOTE(review): presumably this makes sure a CUDA context is touched
        # when the array is collected -- confirm the intent.
        for device_frame in device_frames:
            weakref.finalize(device_frame, numba.cuda.current_context)
    else:
        device_frames = [dask_deserialize_rmm_device_buffer(header, frames)]

    return cuda_deserialize_numba_ndarray(header, device_frames)