Skip to content

Commit

Permalink
pt: support list format batch size (#3614)
Browse files: browse the repository at this point in the history
#3475

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
CaRoLZhangxy and pre-commit-ci[bot] authored Mar 28, 2024
1 parent c2371cd commit 7933c5e
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions deepmd/pt/utils/dataloader.py
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
)

import h5py
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing
@@ -106,29 +107,34 @@ def construct_dataset(system):

self.dataloaders = []
self.batch_sizes = []
for system in self.systems:
if isinstance(batch_size, str):
if batch_size == "auto":
rule = 32
elif batch_size.startswith("auto:"):
rule = int(batch_size.split(":")[1])
else:
rule = None
log.error("Unsupported batch size type")
for ii in self.systems:
ni = ii._natoms
bsi = rule // ni
if bsi * ni < rule:
bsi += 1
self.batch_sizes.append(bsi)
elif isinstance(batch_size, list):
self.batch_sizes = batch_size
else:
self.batch_sizes = batch_size * np.ones(len(systems), dtype=int)
assert len(self.systems) == len(self.batch_sizes)
for system, batch_size in zip(self.systems, self.batch_sizes):
if dist.is_initialized():
system_sampler = DistributedSampler(system)
self.sampler_list.append(system_sampler)
else:
system_sampler = None
if isinstance(batch_size, str):
if batch_size == "auto":
rule = 32
elif batch_size.startswith("auto:"):
rule = int(batch_size.split(":")[1])
else:
rule = None
log.error("Unsupported batch size type")
self.batch_size = rule // system._natoms
if self.batch_size * system._natoms < rule:
self.batch_size += 1
else:
self.batch_size = batch_size
self.batch_sizes.append(self.batch_size)
system_dataloader = DataLoader(
dataset=system,
batch_size=self.batch_size,
batch_size=int(batch_size),
num_workers=0, # Should be 0 to avoid too many threads forked
sampler=system_sampler,
collate_fn=collate_batch,
Expand Down

0 comments on commit 7933c5e

Please sign in to comment.