Skip to content

Commit

Permalink
zluda device id #4
Browse files (browse the repository at this point in the history)
  • Loading branch information
lshqqytiger committed Sep 26, 2024
1 parent 993a3d7 commit 220878c
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions installer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -474,19 +474,22 @@ def install_rocm_zluda():
# os.environ.setdefault('TENSORFLOW_PACKAGE', 'tensorflow-rocm')

amd_gpus = []
hip_default_device = None
device = None
try:
amd_gpus = rocm.get_agents()
if len(amd_gpus) == 0:
(log.info if sys.platform == "win32" else log.warning)('ROCm: no agent was found')
if sys.platform == "win32":
log.warning('You do not have perl or any AMDGPUs. The installer may select a wrong device as compute device.')
log.info('ROCm: no agent was found')
else:
log.warning('ROCm: no agent was found')
else:
log.info(f'ROCm: agents={[gpu.name for gpu in amd_gpus]}')
if args.device_id is None:
hip_default_device = amd_gpus[0]
device = amd_gpus[0]
for idx, gpu in enumerate(amd_gpus):
if gpu.arch == rocm.MicroArchitecture.RDNA:
hip_default_device = gpu
log.debug(f'ROCm default agent: idx={idx} gpu={gpu.name}')
device = gpu
os.environ.setdefault('HIP_VISIBLE_DEVICES', str(idx))
# if os.environ.get('TENSORFLOW_PACKAGE') == 'tensorflow-rocm': # do not use tensorflow-rocm for navi 3x
# os.environ['TENSORFLOW_PACKAGE'] = 'tensorflow==2.13.0'
Expand All @@ -495,12 +498,14 @@ def install_rocm_zluda():
else:
device_id = int(args.device_id)
if device_id < len(amd_gpus):
hip_default_device = amd_gpus[device_id]
log.debug(f'ROCm agent: id={device_id} gpu={hip_default_device.name}')
device = amd_gpus[device_id]
except Exception as e:
log.warning(f'ROCm agent enumerator failed: {e}')

log.info(f'ROCm: version={rocm.version}')
msg = f'ROCm: version={rocm.version}'
if device is not None:
msg += f', using agent {device.name}'
log.info(msg)
torch_command = ''
if sys.platform == "win32":
# TODO after ROCm for Windows is released
Expand All @@ -526,7 +531,7 @@ def install_rocm_zluda():
if error is None:
try:
zluda_installer.load(zluda_path)
torch_command = os.environ.get('TORCH_COMMAND', f'torch=={zluda_installer.get_default_torch_version(hip_default_device)} torchvision --index-url https://download.pytorch.org/whl/cu118')
torch_command = os.environ.get('TORCH_COMMAND', f'torch=={zluda_installer.get_default_torch_version(device)} torchvision --index-url https://download.pytorch.org/whl/cu118')
log.info(f'Using ZLUDA in {zluda_path}')
except Exception as e:
error = e
Expand All @@ -553,16 +558,16 @@ def install_rocm_zluda():
install(ort_package, 'onnxruntime-training')

if 'Flash attention' in opts.get('sdp_options'):
install(rocm.get_flash_attention_command(hip_default_device))
install(rocm.get_flash_attention_command(device))
elif not args.experimental:
uninstall('flash-attn')

if hip_default_device is not None and rocm.version != "6.2" and rocm.version == rocm.version_torch and rocm.get_blaslt_enabled():
log.debug(f'ROCm hipBLASLt: arch={hip_default_device.name} available={hip_default_device.blaslt_supported}')
rocm.set_blaslt_enabled(hip_default_device.blaslt_supported)
if device is not None and rocm.version != "6.2" and rocm.version == rocm.version_torch and rocm.get_blaslt_enabled():
log.debug(f'ROCm hipBLASLt: arch={device.name} available={device.blaslt_supported}')
rocm.set_blaslt_enabled(device.blaslt_supported)

if hip_default_device is not None:
os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', hip_default_device.get_gfx_version())
if device is not None:
os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', device.get_gfx_version())

return torch_command

Expand Down

0 comments on commit 220878c

Please sign in to comment.