Skip to content

Commit

Permalink
Bug fix: set gpu_ids properly (#2071)
Browse files Browse the repository at this point in the history
use first GPU of CUDA_VISIBLE_DEVICES
  • Loading branch information
eunwoosh authored Apr 26, 2023
1 parent 6e2c710 commit 845e880
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 18 deletions.
7 changes: 1 addition & 6 deletions otx/algorithms/classification/adapters/mmcls/configurer.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,7 @@ def configure_device(self, cfg, training):
cfg.distributed = True
self.configure_distributed(cfg)
elif "gpu_ids" not in cfg:
gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES")
logger.info(f"CUDA_VISIBLE_DEVICES = {gpu_ids}")
if gpu_ids is not None:
cfg.gpu_ids = range(len(gpu_ids.split(",")))
else:
cfg.gpu_ids = range(1)
cfg.gpu_ids = range(1)

# consider "cuda" and "cpu" device only
if not torch.cuda.is_available():
Expand Down
7 changes: 1 addition & 6 deletions otx/algorithms/detection/adapters/mmdet/configurer.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,12 +549,7 @@ def configure_device(self, cfg, training):
cfg.distributed = True
self.configure_distributed(cfg)
elif "gpu_ids" not in cfg:
gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES")
logger.info(f"CUDA_VISIBLE_DEVICES = {gpu_ids}")
if gpu_ids is not None:
cfg.gpu_ids = range(len(gpu_ids.split(",")))
else:
cfg.gpu_ids = range(1)
cfg.gpu_ids = range(1)

# consider "cuda" and "cpu" device only
if not torch.cuda.is_available():
Expand Down
7 changes: 1 addition & 6 deletions otx/algorithms/segmentation/adapters/mmseg/configurer.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,7 @@ def configure_device(self, cfg: Config, training: bool) -> None:
cfg.distributed = True
self.configure_distributed(cfg)
elif "gpu_ids" not in cfg:
gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES")
logger.info(f"CUDA_VISIBLE_DEVICES = {gpu_ids}")
if gpu_ids is not None:
cfg.gpu_ids = range(len(gpu_ids.split(",")))
else:
cfg.gpu_ids = range(1)
cfg.gpu_ids = range(1)

# consider "cuda" and "cpu" device only
if not torch.cuda.is_available():
Expand Down

0 comments on commit 845e880

Please sign in to comment.