From 7933c5e8f709872c0bb157f1b01bb267d18b0acb Mon Sep 17 00:00:00 2001
From: Lysithea <52808607+CaRoLZhangxy@users.noreply.github.com>
Date: Thu, 28 Mar 2024 20:20:42 +0800
Subject: [PATCH] pt: support list format batch size (#3614)

https://github.com/deepmodeling/deepmd-kit/issues/3475

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 deepmd/pt/utils/dataloader.py | 38 ++++++++++++++++++++---------------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py
index 0359071d71..361bc4b0b6 100644
--- a/deepmd/pt/utils/dataloader.py
+++ b/deepmd/pt/utils/dataloader.py
@@ -14,6 +14,7 @@
 )
 
 import h5py
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.multiprocessing
@@ -106,29 +107,34 @@ def construct_dataset(system):
         self.dataloaders = []
         self.batch_sizes = []
-        for system in self.systems:
+        if isinstance(batch_size, str):
+            if batch_size == "auto":
+                rule = 32
+            elif batch_size.startswith("auto:"):
+                rule = int(batch_size.split(":")[1])
+            else:
+                rule = None
+                log.error("Unsupported batch size type")
+            for ii in self.systems:
+                ni = ii._natoms
+                bsi = rule // ni
+                if bsi * ni < rule:
+                    bsi += 1
+                self.batch_sizes.append(bsi)
+        elif isinstance(batch_size, list):
+            self.batch_sizes = batch_size
+        else:
+            self.batch_sizes = batch_size * np.ones(len(systems), dtype=int)
+        assert len(self.systems) == len(self.batch_sizes)
+        for system, batch_size in zip(self.systems, self.batch_sizes):
             if dist.is_initialized():
                 system_sampler = DistributedSampler(system)
                 self.sampler_list.append(system_sampler)
             else:
                 system_sampler = None
-            if isinstance(batch_size, str):
-                if batch_size == "auto":
-                    rule = 32
-                elif batch_size.startswith("auto:"):
-                    rule = int(batch_size.split(":")[1])
-                else:
-                    rule = None
-                    log.error("Unsupported batch size type")
-                self.batch_size = rule // system._natoms
-                if self.batch_size * system._natoms < rule:
-                    self.batch_size += 1
-            else:
-                self.batch_size = batch_size
-            self.batch_sizes.append(self.batch_size)
             system_dataloader = DataLoader(
                 dataset=system,
-                batch_size=self.batch_size,
+                batch_size=int(batch_size),
                 num_workers=0,  # Should be 0 to avoid too many threads forked
                 sampler=system_sampler,
                 collate_fn=collate_batch,
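
For readers of this patch, the resolution of the user-facing batch_size setting into one batch size per system can be summarized with the standalone sketch below. It is not code from the repository: resolve_batch_sizes and natoms_per_system are illustrative names, and only the resolution rules mirror the diff above (an integer is broadcast to every system, "auto"/"auto:N" picks the smallest batch size whose atom count reaches the target of 32/N, and a list, new in this patch, gives one explicit batch size per system).

import numpy as np


def resolve_batch_sizes(batch_size, natoms_per_system):
    """Sketch: map an int, "auto", "auto:N", or list setting to per-system batch sizes."""
    if isinstance(batch_size, str):
        # "auto" targets 32 atoms per batch; "auto:N" targets N atoms per batch.
        rule = 32 if batch_size == "auto" else int(batch_size.split(":")[1])
        # Smallest batch size whose atom count reaches the target (ceiling division).
        return [-(-rule // ni) for ni in natoms_per_system]
    if isinstance(batch_size, list):
        # New in this patch: one explicit batch size per system, in system order.
        return batch_size
    # A single integer applies to every system.
    return (batch_size * np.ones(len(natoms_per_system), dtype=int)).tolist()


print(resolve_batch_sizes("auto:64", [12, 40, 100]))  # [6, 2, 1]
print(resolve_batch_sizes([2, 4, 8], [12, 40, 100]))  # [2, 4, 8]
print(resolve_batch_sizes(1, [12, 40, 100]))          # [1, 1, 1]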