
Commit

add multi-GPU DataParallel
ilyes319 committed Feb 9, 2023
1 parent f2405d0 commit e8b5bac
Showing 2 changed files with 18 additions and 0 deletions.
14 changes: 14 additions & 0 deletions mace/tools/torch_tools.py
@@ -119,3 +119,17 @@ def init_wandb(project: str, entity: str, name: str, config: dict):
    import wandb

    wandb.init(project=project, entity=entity, name=name, config=config)

class DataParallelModel(torch.nn.Module):
    """Thin wrapper that runs a model on multiple GPUs via torch.nn.DataParallel."""

    def __init__(self, model):
        super().__init__()
        self.model = torch.nn.DataParallel(model).cuda()

    def forward(self, batch, training):
        return self.model(batch, training=training)

    def __getattr__(self, name):
        # Attributes not found on the wrapper fall through to the wrapped
        # (unreplicated) module, so callers can still reach model attributes.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.model.module, name)
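A minimal usage sketch of the new wrapper (not part of the commit): ToyModel below is a hypothetical stand-in for a real MACE model, and a CUDA device is assumed because the wrapper calls .cuda().

import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)
        self.r_max = 5.0  # plain attribute, reachable through the wrapper's __getattr__

    def forward(self, batch, training=False):
        return self.linear(batch)

if torch.cuda.is_available():
    wrapped = DataParallelModel(ToyModel())
    out = wrapped(torch.randn(8, 4).cuda(), training=False)  # forward() delegates to nn.DataParallel
    print(wrapped.r_max)  # not defined on the wrapper, so it falls through to the wrapped module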
4 changes: 4 additions & 0 deletions scripts/run_train.py
@@ -318,6 +318,10 @@ def main() -> None:
    else:
        raise RuntimeError(f"Unknown model: '{args.model}'")

    if torch.cuda.device_count() > 1:
        logging.info(f"Multi-GPU training on {torch.cuda.device_count()} GPUs.")
        model = tools.DataParallelModel(model)
    model.to(device)

    # Optimizer
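For reference, torch.nn.DataParallel replicates the module on each visible GPU and scatters the input along the batch dimension, so the wrapper is only worthwhile when more than one device is present. Below is a sketch of the same gating logic factored into a helper; maybe_wrap_for_multi_gpu is hypothetical and assumes tools.DataParallelModel is exported as used in the script above.

import logging

import torch
from mace import tools

def maybe_wrap_for_multi_gpu(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    # Mirrors the logic added to run_train.py: wrap only when several GPUs are visible.
    if torch.cuda.device_count() > 1:
        logging.info("Multi-GPU training on %d GPUs.", torch.cuda.device_count())
        model = tools.DataParallelModel(model)
    model.to(device)
    return model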
