Skip to content

Commit

Permalink
zluda device id #1
Browse files (browse the repository at this point in the history)
  • Loading branch information
lshqqytiger committed Sep 26, 2024
1 parent 82ce34c commit 6af1524
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions installer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -472,28 +472,29 @@ def install_rocm_zluda():
os.environ.setdefault('PYTORCH_HIP_ALLOC_CONF', 'garbage_collection_threshold:0.8,max_split_size_mb:512')
# if not is_windows:
# os.environ.setdefault('TENSORFLOW_PACKAGE', 'tensorflow-rocm')

amd_gpus = []
hip_default_device = None
try:
amd_gpus = rocm.get_agents()
log.info(f'ROCm: agents={[gpu.name for gpu in amd_gpus]}')
if len(amd_gpus) == 0:
(log.info if sys.platform == "win32" else log.warning)('ROCm: no agent was found')
else:
log.info(f'ROCm: agents={[gpu.name for gpu in amd_gpus]}')
hip_default_device = amd_gpus[0]
for idx, gpu in enumerate(amd_gpus):
if gpu.arch == rocm.MicroArchitecture.RDNA:
hip_default_device = gpu
log.debug(f'ROCm default agent: idx={idx} gpu={gpu.name}')
os.environ.setdefault('HIP_VISIBLE_DEVICES', str(idx))
# if os.environ.get('TENSORFLOW_PACKAGE') == 'tensorflow-rocm': # do not use tensorflow-rocm for navi 3x
# os.environ['TENSORFLOW_PACKAGE'] = 'tensorflow==2.13.0'
break
log.debug(f'ROCm: HSA_OVERRIDE_GFX_VERSION auto config skipped for {gpu.name}')
except Exception as e:
log.warning(f'ROCm agent enumerator failed: {e}')
amd_gpus = []

hip_default_device = None
if args.device_id is None:
for idx, gpu in enumerate(amd_gpus):
gfx_version = gpu.get_gfx_version()
if gfx_version is None:
log.debug(f'ROCm: HSA_OVERRIDE_GFX_VERSION auto config skipped for {gpu.name}')
else:
hip_default_device = gpu
log.debug(f'ROCm default agent: idx={idx} gpu={gpu.name}')
os.environ.setdefault('HIP_VISIBLE_DEVICES', str(idx))
# if os.environ.get('TENSORFLOW_PACKAGE') == 'tensorflow-rocm': # do not use tensorflow-rocm for navi 3x
# os.environ['TENSORFLOW_PACKAGE'] = 'tensorflow==2.13.0'
os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', gfx_version)
break
else:
if args.device_id is not None:
if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None:
log.warning('Setting HIP_VISIBLE_DEVICES and --device-id at the same time may be mistake.')
device_id = int(args.device_id)
Expand All @@ -502,6 +503,9 @@ def install_rocm_zluda():
os.environ['HIP_VISIBLE_DEVICES'] = args.device_id
del args.device_id

if hip_default_device is not None:
os.environ.setdefault('HSA_OVERRIDE_GFX_VERSION', hip_default_device.get_gfx_version())

log.info(f'ROCm: version={rocm.version}')
torch_command = ''
if sys.platform == "win32":
Expand Down

0 comments on commit 6af1524

Please sign in to comment.