import weakref
import numba.cuda
import numpy as np
from .cuda import cuda_deserialize, cuda_serialize
from .serialize import dask_deserialize, dask_serialize
try:
from .rmm import dask_deserialize_rmm_device_buffer
except ImportError:
dask_deserialize_rmm_device_buffer = None
@cuda_serialize.register(numba.cuda.devicearray.DeviceNDArray)
def cuda_serialize_numba_ndarray(x):
    """Serialize a Numba device array into a header and a single device frame.

    The frame is a flat 1-D ``uint8`` ``DeviceNDArray`` that shares ``x``'s
    device allocation (via ``gpu_data``). Arrays that are already C- or
    F-contiguous are therefore serialized without copying device data.

    Parameters
    ----------
    x : numba.cuda.devicearray.DeviceNDArray
        The array to serialize.

    Returns
    -------
    header : dict
        A copy of ``x.__cuda_array_interface__``. ``"strides"`` is always set
        explicitly to ``tuple(x.strides)``, and ``"lengths"`` is a one-element
        list holding the frame's size in bytes.
    frames : list
        A single 1-D ``uint8`` ``DeviceNDArray`` viewing ``x``'s bytes.
    """
    # The frame built below is a raw byte view of x's memory. That view is
    # only meaningful when the elements form one dense block.
    if not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"]):
        # DeviceNDArray.copy_to_device rejects a non-contiguous *device*
        # source, so a device-side compaction via a fresh device_array does
        # not work. Compact through host memory instead: copy_to_host honours
        # x's (positive) strides, and ascontiguousarray makes the result
        # dense before it goes back to the device.
        x = numba.cuda.to_device(np.ascontiguousarray(x.copy_to_host()))

    header = x.__cuda_array_interface__.copy()
    # The CUDA Array Interface may report strides as None for C-contiguous
    # data. Record the concrete strides so the deserializer can restore
    # either layout.
    header["strides"] = tuple(x.strides)
    header["lengths"] = [x.nbytes]
    frames = [
        numba.cuda.cudadrv.devicearray.DeviceNDArray(
            shape=(x.nbytes,), strides=(1,), dtype=np.dtype("u1"), gpu_data=x.gpu_data
        )
    ]
    return header, frames
@cuda_deserialize.register(numba.cuda.devicearray.DeviceNDArray)
def cuda_deserialize_numba_ndarray(header, frames):
    """Rebuild a Numba device array from a header and exactly one device frame.

    ``header`` supplies ``"shape"``, ``"strides"`` and ``"typestr"``. The
    single element of ``frames`` is wrapped with ``numba.cuda.as_cuda_array``,
    and its ``gpu_data`` becomes the backing memory of the returned array, so
    no device data is copied here.

    Raises ``ValueError`` if ``frames`` does not contain exactly one item.
    """
    [buffer] = frames  # unpacking enforces the single-frame contract
    return numba.cuda.devicearray.DeviceNDArray(
        shape=header["shape"],
        strides=header["strides"],
        dtype=np.dtype(header["typestr"]),
        gpu_data=numba.cuda.as_cuda_array(buffer).gpu_data,
    )
@dask_serialize.register(numba.cuda.devicearray.DeviceNDArray)
def dask_serialize_numba_ndarray(x):
    """Serialize a Numba device array into host-memory frames.

    Delegates to ``cuda_serialize_numba_ndarray`` for the header and device
    frames. Each device frame is then copied to the host with ``copy_to_host``
    and returned as a ``memoryview``.
    """
    header, device_frames = cuda_serialize_numba_ndarray(x)
    # One ``None`` entry per frame. Presumably this leaves writability
    # unspecified for the receiving side; TODO confirm against the protocol.
    header["writeable"] = (None,) * len(device_frames)
    host_frames = []
    for device_frame in device_frames:
        host_frames.append(memoryview(device_frame.copy_to_host()))
    return header, host_frames
@dask_deserialize.register(numba.cuda.devicearray.DeviceNDArray)
def dask_deserialize_numba_array(header, frames):
    """Rebuild a Numba device array from host-memory frames.

    How the device frame(s) are produced depends on what was imported:

    - If the optional ``dask_deserialize_rmm_device_buffer`` helper was
      importable, it turns ``header``/``frames`` into a single device frame.
    - Otherwise, each host frame is copied to the device with
      ``numba.cuda.to_device``.

    The resulting device frame(s) are passed to
    ``cuda_deserialize_numba_ndarray``.
    """
    if dask_deserialize_rmm_device_buffer is None:
        device_frames = [
            numba.cuda.to_device(np.asarray(memoryview(host_frame)))
            for host_frame in frames
        ]
        for device_frame in device_frames:
            # Registers numba.cuda.current_context to be called when the
            # frame is garbage-collected. NOTE(review): presumably this
            # ensures a CUDA context is available at that point. Confirm
            # the intent before changing it.
            weakref.finalize(device_frame, numba.cuda.current_context)
    else:
        device_frames = [dask_deserialize_rmm_device_buffer(header, frames)]
    return cuda_deserialize_numba_ndarray(header, device_frames)