From e8d9b8153c7181adde512b8d4aae5c5989faa709 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 29 Oct 2024 15:29:39 -0500 Subject: [PATCH] Remove OMP_NUM_THREADS if set to empty Generally though, sometimes our use of OMP_NUM_THREADS makes other systems sad. It's unfortunately somewhat difficult to turn off. A common approach is that people set `OMP_NUM_THREADS=""` but this doesn't properly unset things. I'm curious if an approach like this would be helpful. There might be a cleaner way to do this, and this might be a bad idea. Please feel free to reject. It was just easy to put this up as a PR. --- distributed/nanny.py | 6 ++++++ distributed/tests/test_nanny.py | 13 +++++++++++++ 2 files changed, 19 insertions(+) diff --git a/distributed/nanny.py b/distributed/nanny.py index 859b9f22dc..99d9ff5d68 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -240,6 +240,7 @@ def __init__( # type: ignore[no-untyped-def] # https://github.com/dask/dask/issues/6640. self.pre_spawn_env.update({"PYTHONHASHSEED": "6640"}) + self.env = merge( self.pre_spawn_env, _get_env_variables("distributed.nanny.environ"), @@ -1031,4 +1032,9 @@ def _get_env_variables(config_key: str) -> dict[str, str]: # Override dask config with explicitly defined env variables from the OS # Allow unsetting a variable in a config override by setting its value to None. cfg = {k: os.environ.get(k, str(v)) for k, v in cfg.items() if v is not None} + + for k, v in list(cfg.items()): + if "_NUM_THREADS" in k and not v: + del cfg[k] + return cfg diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index b05b7dc90c..dc00c33503 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -312,6 +312,19 @@ async def test_environment_variable(c, s): await asyncio.gather(a.close(), b.close()) +@gen_cluster( + nthreads=[("", 1)], + client=True, + Worker=Nanny, + config={ + "distributed.nanny.pre-spawn-environ": {"OMP_NUM_THREADS": ""}, + }, +) +async def test_omp_num_threads_off(c, s, a): + results = await c.run(lambda: "OMP_NUM_THREADS" in os.environ) + assert results == {a.worker_address: False} + + @gen_cluster(nthreads=[], client=True) async def test_environment_variable_by_config(c, s, monkeypatch): with dask.config.set({"distributed.nanny.environ": "456"}):