Skip to content

Commit

Permalink
zluda device id #2
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Sep 26, 2024
1 parent 6af1524 commit bdcef26
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def install_rocm_zluda():
amd_gpus = rocm.get_agents()
if len(amd_gpus) == 0:
(log.info if sys.platform == "win32" else log.warning)('ROCm: no agent was found')
else:
elif args.device_id is None:
log.info(f'ROCm: agents={[gpu.name for gpu in amd_gpus]}')
hip_default_device = amd_gpus[0]
for idx, gpu in enumerate(amd_gpus):
Expand All @@ -494,22 +494,20 @@ def install_rocm_zluda():
except Exception as e:
log.warning(f'ROCm agent enumerator failed: {e}')

if args.device_id is not None:
if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None:
log.warning('Setting HIP_VISIBLE_DEVICES and --device-id at the same time may be mistake.')
device_id = int(args.device_id)
if device_id < len(amd_gpus):
hip_default_device = amd_gpus[device_id]
os.environ['HIP_VISIBLE_DEVICES'] = args.device_id
del args.device_id

if hip_default_device is not None:
os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', hip_default_device.get_gfx_version())

log.info(f'ROCm: version={rocm.version}')
torch_command = ''
if sys.platform == "win32":
# TODO after ROCm for Windows is released

if args.device_id is not None:
if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None:
log.warning('Setting HIP_VISIBLE_DEVICES and --device-id at the same time may be mistake.')
device_id = int(args.device_id)
if device_id < len(amd_gpus):
hip_default_device = amd_gpus[device_id]
os.environ['HIP_VISIBLE_DEVICES'] = args.device_id
del args.device_id

log.warning("ZLUDA support: experimental")
error = None
from modules import zluda_installer
Expand Down Expand Up @@ -559,6 +557,10 @@ def install_rocm_zluda():
if hip_default_device is not None and rocm.version != "6.2" and rocm.version == rocm.version_torch and rocm.get_blaslt_enabled():
log.debug(f'ROCm hipBLASLt: arch={hip_default_device.name} available={hip_default_device.blaslt_supported}')
rocm.set_blaslt_enabled(hip_default_device.blaslt_supported)

if hip_default_device is not None:
os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', hip_default_device.get_gfx_version())

return torch_command


Expand Down

0 comments on commit bdcef26

Please sign in to comment.