From bdcef261c5c59cd64ee1bddeb062a2c50a0da19a Mon Sep 17 00:00:00 2001 From: Seunghoon Lee Date: Thu, 26 Sep 2024 13:01:55 +0900 Subject: [PATCH] zluda device id #2 --- installer.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/installer.py b/installer.py index 6a7453f87..7a03a6a84 100644 --- a/installer.py +++ b/installer.py @@ -479,7 +479,7 @@ def install_rocm_zluda(): amd_gpus = rocm.get_agents() if len(amd_gpus) == 0: (log.info if sys.platform == "win32" else log.warning)('ROCm: no agent was found') - else: + elif args.device_id is None: log.info(f'ROCm: agents={[gpu.name for gpu in amd_gpus]}') hip_default_device = amd_gpus[0] for idx, gpu in enumerate(amd_gpus): @@ -494,22 +494,20 @@ def install_rocm_zluda(): except Exception as e: log.warning(f'ROCm agent enumerator failed: {e}') - if args.device_id is not None: - if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None: - log.warning('Setting HIP_VISIBLE_DEVICES and --device-id at the same time may be mistake.') - device_id = int(args.device_id) - if device_id < len(amd_gpus): - hip_default_device = amd_gpus[device_id] - os.environ['HIP_VISIBLE_DEVICES'] = args.device_id - del args.device_id - - if hip_default_device is not None: - os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', hip_default_device.get_gfx_version()) - log.info(f'ROCm: version={rocm.version}') torch_command = '' if sys.platform == "win32": # TODO after ROCm for Windows is released + + if args.device_id is not None: + if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None: + log.warning('Setting HIP_VISIBLE_DEVICES and --device-id at the same time may be mistake.') + device_id = int(args.device_id) + if device_id < len(amd_gpus): + hip_default_device = amd_gpus[device_id] + os.environ['HIP_VISIBLE_DEVICES'] = args.device_id + del args.device_id + log.warning("ZLUDA support: experimental") error = None from modules import zluda_installer @@ -559,6 +557,10 @@ def install_rocm_zluda(): if hip_default_device is not None and rocm.version != "6.2" and rocm.version == rocm.version_torch and rocm.get_blaslt_enabled(): log.debug(f'ROCm hipBLASLt: arch={hip_default_device.name} available={hip_default_device.blaslt_supported}') rocm.set_blaslt_enabled(hip_default_device.blaslt_supported) + + if hip_default_device is not None: + os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', hip_default_device.get_gfx_version()) + return torch_command