Skip to content

Commit

Permalink
Refactor for distributed.rmm.pool-size warning
Browse files Browse the repository at this point in the history
  • Loading branch information
KoyamaSohei committed Jun 1, 2022
1 parent 2467b8c commit 6575251
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,14 @@ def init_once():

ucp.init(options=ucx_config, env_takes_precedence=True)

pool_size_str = dask.config.get("distributed.rmm.pool-size")

# Find the function, `cuda_array()`, to use when allocating new CUDA arrays
try:
import rmm

device_array = lambda n: rmm.DeviceBuffer(size=n)

pool_size_str = dask.config.get("distributed.rmm.pool-size")
if pool_size_str is not None:
pool_size = parse_bytes(pool_size_str)
rmm.reinitialize(
Expand All @@ -148,19 +149,19 @@ def numba_device_array(n):

device_array = numba_device_array

pool_size_str = dask.config.get("distributed.rmm.pool-size")
if pool_size_str is not None:
warnings.warn(
"Initial RMM pool size defined, but RMM is not available. "
"Please consider installing RMM or removing the pool size option."
)
except ImportError:

def device_array(n):
raise RuntimeError(
"In order to send/recv CUDA arrays, Numba or RMM is required"
)

if pool_size_str is not None:
warnings.warn(
"Initial RMM pool size defined, but RMM is not available. "
"Please consider installing RMM or removing the pool size option."
)


def _close_comm(ref):
"""Callback to close Dask Comm when UCX Endpoint closes or errors
Expand Down

0 comments on commit 6575251

Please sign in to comment.