Source code distributed/protocol/rmm.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numba
import numba.cuda
import numpy
import rmm

from .cuda import cuda_deserialize, cuda_serialize
from .serialize import dask_deserialize, dask_serialize

# Used for RMM 0.11.0+; otherwise the Numba serializers are used.
if hasattr(rmm, "DeviceBuffer"):

    @cuda_serialize.register(rmm.DeviceBuffer)
    def cuda_serialize_rmm_device_buffer(x):
        """Serialize an RMM ``DeviceBuffer`` for CUDA-aware transport.

        The buffer object itself is passed through unchanged as the only frame.
        The header is a shallow copy of the buffer's
        ``__cuda_array_interface__``, with ``strides`` overwritten to ``(1,)``
        and ``lengths`` set to a one-element list holding ``x.nbytes``.
        """
        header = dict(x.__cuda_array_interface__)
        header.update(strides=(1,), lengths=[x.nbytes])
        return header, [x]

    @cuda_deserialize.register(rmm.DeviceBuffer)
    def cuda_deserialize_rmm_device_buffer(header, frames):
        """Return the single ``DeviceBuffer`` frame received over CUDA transport.

        ``header`` is not consulted. ``frames`` must contain exactly one item;
        unpacking raises ``ValueError`` otherwise.
        """
        [device_buffer] = frames

        # RMM is preferred for allocations whenever it is available (as it is
        # in this branch), so the incoming frame should already be a
        # ``DeviceBuffer`` rather than some other device-array type.
        assert isinstance(device_buffer, rmm.DeviceBuffer)

        return device_buffer

    @dask_serialize.register(rmm.DeviceBuffer)
    def dask_serialize_rmm_device_buffer(x):
        """Serialize a ``DeviceBuffer`` into host memory for non-CUDA transport.

        The header produced by ``cuda_serialize_rmm_device_buffer`` is reused,
        with a ``writeable`` entry of one ``None`` per frame added. Each device
        frame is then copied to the host with Numba. The host array's ``.data``
        buffer is returned in its place.
        """
        header, device_frames = cuda_serialize_rmm_device_buffer(x)
        header["writeable"] = (None,) * len(device_frames)

        host_frames = []
        for device_frame in device_frames:
            # View the device memory as a Numba device array, then copy to host.
            host_array = numba.cuda.as_cuda_array(device_frame).copy_to_host()
            host_frames.append(host_array.data)

        return header, host_frames

    @dask_deserialize.register(rmm.DeviceBuffer)
    def dask_deserialize_rmm_device_buffer(header, frames):
        """Rebuild a ``DeviceBuffer`` from a single host-memory frame.

        Parameters
        ----------
        header : dict
            Serialization header; not consulted here.
        frames : sequence
            Exactly one bytes-like object holding the buffer contents.
            Unpacking raises ``ValueError`` if there is not exactly one.

        Returns
        -------
        rmm.DeviceBuffer
            A new ``DeviceBuffer`` constructed from the frame's host pointer
            and byte size.
        """
        (frame,) = frames

        # ``ptr`` and ``size`` below describe a single flat span of memory.
        # That is only correct for a C-contiguous buffer. ``ascontiguousarray``
        # returns the input without copying when it is already contiguous (the
        # usual case). Otherwise it makes a packed copy, so a strided view
        # cannot cause the wrong bytes, or memory outside the buffer, to be read.
        arr = numpy.ascontiguousarray(memoryview(frame))
        ptr = arr.__array_interface__["data"][0]
        size = arr.nbytes

        # ``arr`` stays referenced until after construction, so the host memory
        # behind ``ptr`` is alive while RMM (presumably) copies from it.
        buf = rmm.DeviceBuffer(ptr=ptr, size=size)

        return buf