Source code distributed/comm/tests/test_ucx_config.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
import pytest
from time import sleep

import dask
from dask.utils import format_bytes
from distributed import Client
from distributed.utils_test import gen_test, loop, inc, cleanup, popen  # noqa: 401
from distributed.utils import get_ip
from distributed.comm.ucx import _scrub_ucx_config

# Address used to build the scheduler URL in the subprocess test below; fall
# back to the loopback address if get_ip() raises.
try:
    HOST = get_ip()
except Exception:
    HOST = "127.0.0.1"

# Both modules are obtained through pytest.importorskip, so the tests in this
# module are skipped when either one cannot be imported.
ucp = pytest.importorskip("ucp")
rmm = pytest.importorskip("rmm")


@pytest.mark.asyncio
async def test_ucx_config(cleanup):
    """Check what ``_scrub_ucx_config`` returns for several ``ucx`` settings.

    Each scenario pairs a ``ucx`` dask-config mapping with the values the
    scrubbed configuration is expected to report for selected keys.
    """
    scenarios = [
        # NVLink + InfiniBand without rdmacm and with an empty device list.
        (
            {
                "nvlink": True,
                "infiniband": True,
                "rdmacm": False,
                "net-devices": "",
                "tcp": True,
                "cuda_copy": True,
            },
            [
                ("TLS", "rc,tcp,sockcm,cuda_copy,cuda_ipc"),
                ("SOCKADDR_TLS_PRIORITY", "sockcm"),
                ("NET_DEVICES", None),
            ],
        ),
        # InfiniBand only, with an explicit network device.
        (
            {
                "nvlink": False,
                "infiniband": True,
                "rdmacm": False,
                "net-devices": "mlx5_0:1",
                "tcp": True,
                "cuda_copy": False,
            },
            [
                ("TLS", "rc,tcp,sockcm"),
                ("SOCKADDR_TLS_PRIORITY", "sockcm"),
                ("NET_DEVICES", "mlx5_0:1"),
            ],
        ),
        # rdmacm enabled, plus an extra MEMTYPE_CACHE entry in the mapping.
        (
            {
                "nvlink": False,
                "infiniband": True,
                "rdmacm": True,
                "net-devices": "all",
                "MEMTYPE_CACHE": "y",
                "tcp": True,
                "cuda_copy": True,
            },
            [
                ("TLS", "rc,tcp,rdmacm,cuda_copy"),
                ("SOCKADDR_TLS_PRIORITY", "rdmacm"),
                ("MEMTYPE_CACHE", "y"),
            ],
        ),
    ]

    for settings, expectations in scenarios:
        with dask.config.set(ucx=settings):
            scrubbed = _scrub_ucx_config()
            for option, wanted in expectations:
                found = scrubbed.get(option)
                if wanted is None:
                    # Absent options are checked by identity, not equality.
                    assert found is None
                else:
                    assert found == wanted


def test_ucx_config_w_env_var(cleanup, loop, monkeypatch):
    """Check that an RMM pool size given via the environment is applied.

    ``DASK_RMM__POOL_SIZE`` is set before a scheduler and a worker are started
    through ``popen`` with the UCX protocol. Both are then asked for
    ``rmm.get_info()`` and the reported free memory, formatted with
    ``format_bytes``, must equal the configured size.
    """
    from time import monotonic

    size = "1000.00 MB"
    monkeypatch.setenv("DASK_RMM__POOL_SIZE", size)

    # Re-read the environment so the variable is reflected in dask.config.
    dask.config.refresh()

    port = "13339"
    sched_addr = "ucx://%s:%s" % (HOST, port)

    try:
        with popen(
            [
                "dask-scheduler",
                "--no-dashboard",
                "--protocol",
                "ucx",
                "--port",
                port,
            ]
        ):
            with popen(
                [
                    "dask-worker",
                    sched_addr,
                    "--no-dashboard",
                    "--protocol",
                    "ucx",
                    "--no-nanny",
                ]
            ):
                with Client(sched_addr, loop=loop, timeout=10) as c:
                    # Bound the wait so a worker that never registers fails
                    # the test instead of hanging the whole run.
                    deadline = monotonic() + 60
                    while not c.scheduler_info()["workers"]:
                        assert (
                            monotonic() < deadline
                        ), "no worker registered with the scheduler within 60s"
                        sleep(0.1)

                    # Scheduler side: free memory must match the pool size.
                    rmm_usage = c.run_on_scheduler(rmm.get_info)
                    assert size == format_bytes(rmm_usage.free)

                    # Worker side: same check on the first registered worker.
                    worker_addr = list(c.scheduler_info()["workers"])[0]
                    worker_rmm_usage = c.run(rmm.get_info)
                    rmm_usage = worker_rmm_usage[worker_addr]
                    assert size == format_bytes(rmm_usage.free)
    finally:
        # monkeypatch only restores the environment; the value loaded into the
        # process-global dask config by refresh() would otherwise leak into
        # later tests, so drop the variable and reload the configuration.
        monkeypatch.delenv("DASK_RMM__POOL_SIZE", raising=False)
        dask.config.refresh()