diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index b4742a7c65..38be5a27f9 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -125,11 +125,19 @@ def init_once(): ucp.init(options=ucx_config, env_takes_precedence=True) + pool_size_str = dask.config.get("distributed.rmm.pool-size") + # Find the function, `cuda_array()`, to use when allocating new CUDA arrays try: import rmm device_array = lambda n: rmm.DeviceBuffer(size=n) + + if pool_size_str is not None: + pool_size = parse_bytes(pool_size_str) + rmm.reinitialize( + pool_allocator=True, managed_memory=False, initial_pool_size=pool_size + ) except ImportError: try: import numba.cuda @@ -140,6 +148,7 @@ def numba_device_array(n): return a device_array = numba_device_array + except ImportError: def device_array(n): @@ -147,12 +156,11 @@ def device_array(n): "In order to send/recv CUDA arrays, Numba or RMM is required" ) - pool_size_str = dask.config.get("distributed.rmm.pool-size") - if pool_size_str is not None: - pool_size = parse_bytes(pool_size_str) - rmm.reinitialize( - pool_allocator=True, managed_memory=False, initial_pool_size=pool_size - ) + if pool_size_str is not None: + warnings.warn( + "Initial RMM pool size defined, but RMM is not available. " + "Please consider installing RMM or removing the pool size option." + ) def _close_comm(ref):