Skip to content

Commit

Permalink
Fix import error when distributed.rmm.pool-size is setted (#6482)
Browse files Browse the repository at this point in the history
Fix import error when distributed.rmm.pool-size is set
  • Loading branch information
KoyamaSohei authored Jun 6, 2022
1 parent d360b63 commit 7d280fd
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,19 @@ def init_once():

ucp.init(options=ucx_config, env_takes_precedence=True)

pool_size_str = dask.config.get("distributed.rmm.pool-size")

# Find the function, `cuda_array()`, to use when allocating new CUDA arrays
try:
import rmm

device_array = lambda n: rmm.DeviceBuffer(size=n)

if pool_size_str is not None:
pool_size = parse_bytes(pool_size_str)
rmm.reinitialize(
pool_allocator=True, managed_memory=False, initial_pool_size=pool_size
)
except ImportError:
try:
import numba.cuda
Expand All @@ -140,19 +148,19 @@ def numba_device_array(n):
return a

device_array = numba_device_array

except ImportError:

def device_array(n):
raise RuntimeError(
"In order to send/recv CUDA arrays, Numba or RMM is required"
)

pool_size_str = dask.config.get("distributed.rmm.pool-size")
if pool_size_str is not None:
pool_size = parse_bytes(pool_size_str)
rmm.reinitialize(
pool_allocator=True, managed_memory=False, initial_pool_size=pool_size
)
if pool_size_str is not None:
warnings.warn(
"Initial RMM pool size defined, but RMM is not available. "
"Please consider installing RMM or removing the pool size option."
)


def _close_comm(ref):
Expand Down

0 comments on commit 7d280fd

Please sign in to comment.