From e552cc51d2af0f95de579e00a561d23a6b13e319 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 6 Apr 2023 12:56:58 +0100 Subject: [PATCH 1/3] implement on the fly graph creation --- mace/data/__init__.py | 6 +- mace/data/hdf5_dataset.py | 153 ++++++++++++++++++++++++++----------- mace/data/utils.py | 83 ++++++++++++-------- scripts/preprocess_data.py | 31 ++++---- scripts/run_train.py | 66 ++++++++++------ 5 files changed, 222 insertions(+), 117 deletions(-) diff --git a/mace/data/__init__.py b/mace/data/__init__.py index 59b05a8b..e2ce912c 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -8,9 +8,10 @@ config_from_atoms_list, load_from_xyz, random_train_valid_split, + save_configurations_as_HDF5, test_config_types, save_dataset_as_HDF5, - save_AtomicData_to_HDF5 + save_AtomicData_to_HDF5, ) from .hdf5_dataset import HDF5Dataset @@ -27,5 +28,6 @@ "compute_average_E0s", "save_dataset_as_HDF5", "HDF5Dataset", - "save_AtomicData_to_HDF5" + "save_AtomicData_to_HDF5", + "save_configurations_as_HDF5", ] diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index 107c5ddb..53500bf7 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -1,59 +1,124 @@ import h5py import torch from torch.utils.data import Dataset +from mace import data from mace.data import AtomicData +from mace.data.utils import Configuration -class HDF5Dataset(Dataset): - def __init__(self, file, **kwargs): - super(HDF5Dataset, self).__init__() +import h5py +import torch +from torch.utils.data import Dataset, IterableDataset, ChainDataset +from mace import data +from mace.data import AtomicData +from mace.data.utils import Configuration + + +class HDF5ChainDataset(ChainDataset): + def __init__(self, file, r_max, z_table, **kwargs): + super(HDF5ChainDataset, self).__init__() + self.file = file + self.length = len(h5py.File(file, "r").keys()) + self.r_max = r_max + self.z_table = z_table + + def __call__(self): + self.file = h5py.File(self.file, "r") + datasets = [] + for i in range(self.length): + grp = self.file["config_" + str(i)] + datasets.append( + HDF5IterDataset(iter_group=grp, r_max=self.r_max, z_table=self.z_table,) + ) + return ChainDataset(datasets) + + +class HDF5IterDataset(IterableDataset): + def __init__(self, iter_group, r_max, z_table, **kwargs): + super(HDF5IterDataset, self).__init__() # it might be dangerous to open the file here # move opening of file to __getitem__? 
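        # One common pattern (a sketch only; `file_name` and `_fh` are
        # hypothetical names, not part of this patch): keep just the path here
        # and open the handle lazily, so that each DataLoader worker process
        # creates its own h5py.File rather than pickling a shared handle:
        #
        #     def _get_file(self):
        #         if getattr(self, "_fh", None) is None:
        #             self._fh = h5py.File(self.file_name, "r")
        #         return self._fh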
- self.file = h5py.File(file, 'r') - self.length = len(self.file.keys()) + self.iter_group = iter_group + self.length = len(self.iter_group.keys()) + self.r_max = r_max + self.z_table = z_table # self.file = file # self.length = len(h5py.File(file, 'r').keys()) def __len__(self): return self.length + def __iter__(self): + # file = h5py.File(self.file, 'r') + # grp = file["config_" + str(index)] + grp = self.iter_group + len_subgrp = len(grp.keys()) + grp_list = [] + for i in range(len_subgrp): + subgrp = grp["config_" + str(i)] + config = Configuration( + atomic_numbers=subgrp["atomic_numbers"][()], + positions=subgrp["positions"][()], + energy=subgrp["energy"][()], + forces=subgrp["forces"][()], + stress=subgrp["stress"][()], + virials=subgrp["virials"][()], + dipole=subgrp["dipole"][()], + charges=subgrp["charges"][()], + weight=subgrp["weight"][()], + energy_weight=subgrp["energy_weight"][()], + forces_weight=subgrp["forces_weight"][()], + stress_weight=subgrp["stress_weight"][()], + virials_weight=subgrp["virials_weight"][()], + config_type=subgrp["config_type"][()], + pbc=subgrp["pbc"][()], + cell=subgrp["cell"][()], + ) + atomic_data = data.AtomicData.from_config( + config, z_table=self.z_table, cutoff=self.r_max + ) + grp_list.append(atomic_data) + + return iter(grp_list) + + +class HDF5Dataset(Dataset): + def __init__(self, file, r_max, z_table, **kwargs): + super(HDF5Dataset, self).__init__() + self.file = h5py.File(file, "r") # this is dangerous to open the file here + self.batch_size = len(self.file["config_0"].keys()) + self.length = len(self.file.keys()) * len(self.file["config_0"].keys()) + self.r_max = r_max + self.z_table = z_table + + def __len__(self): + return self.length + def __getitem__(self, index): - # file = h5py.File(self.file, 'r') - # grp = file["config_" + str(index)] - grp = self.file["config_" + str(index)] - edge_index = grp['edge_index'][()] - positions = grp['positions'][()] - shifts = grp['shifts'][()] - unit_shifts = grp['unit_shifts'][()] - cell = grp['cell'][()] - node_attrs = grp['node_attrs'][()] - weight = grp['weight'][()] - energy_weight = grp['energy_weight'][()] - forces_weight = grp['forces_weight'][()] - stress_weight = grp['stress_weight'][()] - virials_weight = grp['virials_weight'][()] - forces = grp['forces'][()] - energy = grp['energy'][()] - stress = grp['stress'][()] - virials = grp['virials'][()] - dipole = grp['dipole'][()] - charges = grp['charges'][()] - return AtomicData( - edge_index = torch.tensor(edge_index, dtype=torch.long), - positions = torch.tensor(positions, dtype=torch.get_default_dtype()), - shifts = torch.tensor(shifts, dtype=torch.get_default_dtype()), - unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), - cell=torch.tensor(cell, dtype=torch.get_default_dtype()), - node_attrs=torch.tensor(node_attrs, dtype=torch.get_default_dtype()), - weight=torch.tensor(weight, dtype=torch.get_default_dtype()), - energy_weight=torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - forces_weight=torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - stress_weight=torch.tensor(stress_weight, dtype=torch.get_default_dtype()), - virials_weight=torch.tensor(virials_weight, dtype=torch.get_default_dtype()), - forces=torch.tensor(forces, dtype=torch.get_default_dtype()), - energy=torch.tensor(energy, dtype=torch.get_default_dtype()), - stress=torch.tensor(stress, dtype=torch.get_default_dtype()), - virials=torch.tensor(virials, dtype=torch.get_default_dtype()), - dipole=torch.tensor(dipole, 
dtype=torch.get_default_dtype()), - charges=torch.tensor(charges, dtype=torch.get_default_dtype()), + # compute the index of the batch + batch_index = index // self.batch_size + config_index = index % self.batch_size + grp = self.file["config_batch" + str(batch_index)] + subgrp = grp["config_" + str(config_index)] + config = Configuration( + atomic_numbers=subgrp["atomic_numbers"][()], + positions=subgrp["positions"][()], + energy=subgrp["energy"][()], + forces=subgrp["forces"][()], + stress=subgrp["stress"][()], + virials=subgrp["virials"][()], + dipole=subgrp["dipole"][()], + charges=subgrp["charges"][()], + weight=subgrp["weight"][()], + energy_weight=subgrp["energy_weight"][()], + forces_weight=subgrp["forces_weight"][()], + stress_weight=subgrp["stress_weight"][()], + virials_weight=subgrp["virials_weight"][()], + config_type=subgrp["config_type"][()], + pbc=subgrp["pbc"][()], + cell=subgrp["cell"][()], ) - \ No newline at end of file + atomic_data = data.AtomicData.from_config( + config, z_table=self.z_table, cutoff=self.r_max + ) + return atomic_data + diff --git a/mace/data/utils.py b/mace/data/utils.py index 75fac945..7e16aaf8 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -152,7 +152,7 @@ def config_from_atoms( virials_weight = 0.0 if dipole is None: dipole = np.zeros(3) - #dipoles_weight = 0.0 + # dipoles_weight = 0.0 return Configuration( atomic_numbers=atomic_numbers, @@ -271,35 +271,33 @@ def compute_average_E0s( atomic_energies_dict[z] = 0.0 return atomic_energies_dict -def save_dataset_as_HDF5( - dataset:List, out_name: str - ) -> None: - with h5py.File(out_name, 'w') as f: - for i, data in enumerate(dataset): - grp = f.create_group(f'config_{i}') - grp["num_nodes"] = data.num_nodes - grp["edge_index"] = data.edge_index - grp["positions"] = data.positions - grp["shifts"] = data.shifts - grp["unit_shifts"] = data.unit_shifts - grp["cell"] = data.cell - grp["node_attrs"] = data.node_attrs - grp["weight"] = data.weight - grp["energy_weight"] = data.energy_weight - grp["forces_weight"] = data.forces_weight - grp["stress_weight"] = data.stress_weight - grp["virials_weight"] = data.virials_weight - grp["forces"] = data.forces - grp["energy"] = data.energy - grp["stress"] = data.stress - grp["virials"] = data.virials - grp["dipole"] = data.dipole - grp["charges"] = data.charges - -def save_AtomicData_to_HDF5( - data, i, h5_file - ) -> None: - grp = h5_file.create_group(f'config_{i}') + +def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: + with h5py.File(out_name, "w") as f: + for i, data in enumerate(dataset): + grp = f.create_group(f"config_{i}") + grp["num_nodes"] = data.num_nodes + grp["edge_index"] = data.edge_index + grp["positions"] = data.positions + grp["shifts"] = data.shifts + grp["unit_shifts"] = data.unit_shifts + grp["cell"] = data.cell + grp["node_attrs"] = data.node_attrs + grp["weight"] = data.weight + grp["energy_weight"] = data.energy_weight + grp["forces_weight"] = data.forces_weight + grp["stress_weight"] = data.stress_weight + grp["virials_weight"] = data.virials_weight + grp["forces"] = data.forces + grp["energy"] = data.energy + grp["stress"] = data.stress + grp["virials"] = data.virials + grp["dipole"] = data.dipole + grp["charges"] = data.charges + + +def save_AtomicData_to_HDF5(data, i, h5_file) -> None: + grp = h5_file.create_group(f"config_{i}") grp["num_nodes"] = data.num_nodes grp["edge_index"] = data.edge_index grp["positions"] = data.positions @@ -318,3 +316,28 @@ def save_AtomicData_to_HDF5( grp["virials"] = data.virials 
grp["dipole"] = data.dipole grp["charges"] = data.charges + + +def save_configurations_as_HDF5(configurations: Configurations, out_name: str) -> None: + with h5py.File(out_name, "w") as f: + grp = f.create_group(f"config_batch{i}") + for i, config in enumerate(configurations): + subgroup_name = f"config_{i}" + subgroup = grp.create_group(subgroup_name) + subgroup["atomic_numbers"] = config.atomic_numbers + subgroup["positions"] = config.positions + subgroup["energy"] = config.energy + subgroup["forces"] = config.forces + subgroup["stress"] = config.stress + subgroup["virials"] = config.virials + subgroup["dipole"] = config.dipole + subgroup["charges"] = config.charges + subgroup["cell"] = config.cell + subgroup["pbc"] = config.pbc + subgroup["weight"] = config.weight + subgroup["energy_weight"] = config.energy_weight + subgroup["forces_weight"] = config.forces_weight + subgroup["stress_weight"] = config.stress_weight + subgroup["virials_weight"] = config.virials_weight + subgroup["config_type"] = config.config_type + diff --git a/scripts/preprocess_data.py b/scripts/preprocess_data.py index eed3db81..000afd54 100644 --- a/scripts/preprocess_data.py +++ b/scripts/preprocess_data.py @@ -12,7 +12,7 @@ import torch from mace import tools, data -from mace.data.utils import save_AtomicData_to_HDF5 #, save_dataset_as_HDF5 +from mace.data.utils import save_AtomicData_to_HDF5, save_configurations_as_HDF5 #, save_dataset_as_HDF5 from mace.tools.scripts_utils import (get_dataset_from_xyz, get_atomic_energies) from mace.tools import torch_geometric @@ -91,10 +91,10 @@ def main(): random.shuffle(collections.train) with h5py.File(args.h5_prefix + "train.h5", "w") as f: - for i, config in enumerate(collections.train): - atomic_data = data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max) - save_AtomicData_to_HDF5(atomic_data, i, f) + # split collections.train into batches and save them to hdf5 + split_train = np.array_split(collections.train, args.batch_size) + for i, batch in enumerate(split_train): + save_configurations_as_HDF5(batch, f, f"batch_{i}") if args.compute_statistics: @@ -106,8 +106,9 @@ def main(): [atomic_energies_dict[z] for z in z_table.zs] ) logging.info(f"Atomic energies: {atomic_energies.tolist()}") + train_dataset = data.HDF5Dataset(args.h5_prefix + "train.h5", z_table=z_table, r_max=args.r_max) train_loader = torch_geometric.dataloader.DataLoader( - data.HDF5Dataset(args.h5_prefix + "train.h5"), + dataset=train_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, @@ -136,19 +137,17 @@ def main(): random.shuffle(collections.valid) with h5py.File(args.h5_prefix + "valid.h5", "w") as f: - for i, config in enumerate(collections.valid): - atomic_data = data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max) - save_AtomicData_to_HDF5(atomic_data, i, f) + split_valid = np.array_split(collections.valid, args.batch_size) + for i, batch in enumerate(split_valid): + save_configurations_as_HDF5(batch, f, f"batch_{i}") if args.test_file is not None: logging.info("Preparing test sets") - for name, subset in collections.tests: - with h5py.File(args.h5_prefix + name + "_test.h5", "w") as f: - for i, config in enumerate(subset): - atomic_data = data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max) - save_AtomicData_to_HDF5(atomic_data, i, f) + with h5py.File(args.h5_prefix + name + "_test.h5", "w") as f: + for name, subset in collections.tests: + split_test = np.array_split(subset, args.batch_size) + for i, batch in 
enumerate(split_test): + save_configurations_as_HDF5(batch, f, f"batch_{i}") if __name__ == "__main__": diff --git a/scripts/run_train.py b/scripts/run_train.py index d9bb6f86..a80b7e0f 100644 --- a/scripts/run_train.py +++ b/scripts/run_train.py @@ -20,14 +20,17 @@ import mace from mace import data, modules, tools from mace.tools import torch_geometric -from mace.tools.scripts_utils import (create_error_table, - get_dataset_from_xyz, - get_atomic_energies, - get_config_type_weights, - get_loss_fn, - get_files_with_suffix) +from mace.tools.scripts_utils import ( + create_error_table, + get_dataset_from_xyz, + get_atomic_energies, + get_config_type_weights, + get_loss_fn, + get_files_with_suffix, +) from mace.data import HDF5Dataset + def main() -> None: args = tools.build_default_arg_parser().parse_args() tag = tools.get_tag(name=args.name, seed=args.seed) @@ -60,7 +63,9 @@ def main() -> None: # Data preparation if args.train_file.endswith(".xyz"): if args.valid_file is not None: - assert args.valid_file.endswith(".xyz"), "valid_file if given must be same format as train_file" + assert args.valid_file.endswith( + ".xyz" + ), "valid_file if given must be same format as train_file" collections, atomic_energies_dict = get_dataset_from_xyz( train_path=args.train_file, valid_path=args.valid_file, @@ -86,7 +91,7 @@ def main() -> None: raise RuntimeError( f"train_file must be either .xyz or .h5, got {args.train_file}" ) - + # Atomic number table # yapf: disable if args.atomic_numbers is None: @@ -107,7 +112,9 @@ def main() -> None: if atomic_energies_dict is None or len(atomic_energies_dict) == 0: if args.train_file.endswith(".xyz"): - atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) + atomic_energies_dict = get_atomic_energies( + args.E0s, collections.train, z_table + ) else: atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table) @@ -159,23 +166,29 @@ def main() -> None: num_workers=args.num_workers, ) else: - training_set_processed = HDF5Dataset(args.train_file) + training_set_processed = HDF5Dataset( + args.train_file, r_max=args.r_max, z_table=z_table + ) train_loader = torch_geometric.dataloader.DataLoader( training_set_processed, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers, - pin_memory=args.pin_memory) + pin_memory=args.pin_memory, + ) - validation_set_processed = HDF5Dataset(args.valid_file) + validation_set_processed = HDF5Dataset( + args.valid_file, r_max=args.r_max, z_table=z_table + ) valid_loader = torch_geometric.dataloader.DataLoader( validation_set_processed, batch_size=args.valid_batch_size, shuffle=False, drop_last=False, num_workers=args.num_workers, - pin_memory=args.pin_memory) + pin_memory=args.pin_memory, + ) loss_fn: torch.nn.Module = get_loss_fn( args.loss, @@ -245,7 +258,9 @@ def main() -> None: args.std = 1.0 logging.info("No scaling selected") elif args.mean is None or args.std is None: - args.mean, args.std = modules.scaling_classes[args.scaling](train_loader, atomic_energies) + args.mean, args.std = modules.scaling_classes[args.scaling]( + train_loader, atomic_energies + ) if args.model == "MACE": model = modules.ScaleShiftMACE( @@ -322,7 +337,6 @@ def main() -> None: else: raise RuntimeError(f"Unknown model: '{args.model}'") - if torch.cuda.device_count() > 1: logging.info(f"Multi-GPUs training on {torch.cuda.device_count()} GPUs.") model = tools.DataParallelModel(model) @@ -477,6 +491,7 @@ def main() -> None: if args.wandb: logging.info("Using Weights and Biases for logging") 
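        # note: wandb is imported inside the `if args.wandb:` branch so that
        # Weights and Biases remains an optional dependency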
import wandb
+
         wandb_config = {}
         args_dict = vars(args)
         args_dict_json = json.dumps(args_dict)
@@ -514,14 +529,15 @@ def main() -> None:
     logging.info("Computing metrics for training, validation, and test sets")
     all_data_loaders = {
-        "train": train_loader,
-        "valid": valid_loader,
+        "train": train_loader,
+        "valid": valid_loader,
     }
     if args.train_file.endswith(".xyz"):
         for name, subset in collections.tests:
-            test_set = [data.AtomicData.from_config(
-                config, z_table=z_table, cutoff=args.r_max)
-                for config in subset]
+            test_set = [
+                data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max)
+                for config in subset
+            ]
             test_loader = torch_geometric.dataloader.DataLoader(
                 test_set,
                 batch_size=args.valid_batch_size,
@@ -532,18 +548,17 @@ def main() -> None:
             all_data_loaders[name] = test_loader
     else:
         # get all test paths
-        test_files = get_files_with_suffix(
-            args.test_dir, "_test.h5"
-        )
+        test_files = get_files_with_suffix(args.test_dir, "_test.h5")
         for test_file in test_files:
-            test_set = HDF5Dataset(test_file)
+            test_set = HDF5Dataset(test_file, z_table=z_table, r_max=args.r_max)
             test_loader = torch_geometric.dataloader.DataLoader(
                 test_set,
                 batch_size=args.valid_batch_size,
                 shuffle=False,
                 drop_last=False,
                 num_workers=args.num_workers,
-                pin_memory=args.pin_memory)
+                pin_memory=args.pin_memory,
+            )
             test_file_name = os.path.splitext(os.path.basename(test_file))[0]
             all_data_loaders[test_file_name] = test_loader
@@ -584,5 +599,6 @@ def main() -> None:
     logging.info("Done")
+
 if __name__ == "__main__":
     main()

From 5d9003020440b683498054a5750c61c80772f8b0 Mon Sep 17 00:00:00 2001
From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com>
Date: Thu, 6 Apr 2023 13:05:35 +0100
Subject: [PATCH 2/3] add hdf5 test

---
 tests/test_data.py | 82 +++++++++++++++++++++------------------------
 1 file changed, 38 insertions(+), 44 deletions(-)

diff --git a/tests/test_data.py b/tests/test_data.py
index 17a6c9de..d254a3d8 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -1,27 +1,25 @@
 import ase.build
 import numpy as np

-from mace.data import AtomicData, Configuration, config_from_atoms, get_neighborhood
+from mace.data import (
+    AtomicData,
+    Configuration,
+    config_from_atoms,
+    get_neighborhood,
+    save_configurations_as_HDF5,
+    HDF5Dataset,
+)
+from pathlib import Path
 from mace.tools import AtomicNumberTable, torch_geometric

+mace_path = Path(__file__).parent.parent
+

 class TestAtomicData:
     config = Configuration(
         atomic_numbers=np.array([8, 1, 1]),
-        positions=np.array(
-            [
-                [0.0, -2.0, 0.0],
-                [1.0, 0.0, 0.0],
-                [0.0, 1.0, 0.0],
-            ]
-        ),
-        forces=np.array(
-            [
-                [0.0, -1.3, 0.0],
-                [1.0, 0.2, 0.0],
-                [0.0, 1.1, 0.3],
-            ]
-        ),
+        positions=np.array([[0.0, -2.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0],]),
+        forces=np.array([[0.0, -1.3, 0.0], [1.0, 0.2, 0.0], [0.0, 1.1, 0.3],]),
         energy=-1.5,
     )

@@ -39,10 +37,7 @@ def test_data_loader(self):
         data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0)

         data_loader = torch_geometric.dataloader.DataLoader(
-            dataset=[data1, data2],
-            batch_size=2,
-            shuffle=True,
-            drop_last=False,
+            dataset=[data1, data2], batch_size=2, shuffle=True, drop_last=False,
         )

         for batch in data_loader:
@@ -59,10 +54,7 @@ def test_to_atomic_data_dict(self):
         data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0)

         data_loader = torch_geometric.dataloader.DataLoader(
-            dataset=[data1, data2],
-            batch_size=2,
-            shuffle=True,
-            drop_last=False,
+            dataset=[data1, data2], batch_size=2, shuffle=True, drop_last=False,
         )

         for batch
in data_loader: batch_dict = batch.to_dict() @@ -74,16 +66,29 @@ def test_to_atomic_data_dict(self): assert batch_dict["energy"].shape == (2,) assert batch_dict["forces"].shape == (6, 3) + def test_hdf5_dataloader(self): + datasets = [self.config] * 10 + # get path of the mace package + save_configurations_as_HDF5(datasets, mace_path + "test.h5") + train_dataset = HDF5Dataset( + mace_path + "test.h5", z_table=self.table, r_max=3.0 + ) + train_loader = torch_geometric.dataloader.DataLoader( + dataset=train_dataset, batch_size=2, shuffle=False, drop_last=False, + ) + for batch in train_loader: + assert batch.batch.shape == (6,) + assert batch.edge_index.shape == (2, 8) + assert batch.shifts.shape == (8, 3) + assert batch.positions.shape == (6, 3) + assert batch.node_attrs.shape == (6, 2) + assert batch.energy.shape == (2,) + assert batch.forces.shape == (6, 3) + class TestNeighborhood: def test_basic(self): - positions = np.array( - [ - [-1.0, 0.0, 0.0], - [+0.0, 0.0, 0.0], - [+1.0, 0.0, 0.0], - ] - ) + positions = np.array([[-1.0, 0.0, 0.0], [+0.0, 0.0, 0.0], [+1.0, 0.0, 0.0],]) indices, shifts, unit_shifts = get_neighborhood(positions, cutoff=1.5) assert indices.shape == (2, 4) @@ -91,12 +96,7 @@ def test_basic(self): assert unit_shifts.shape == (4, 3) def test_signs(self): - positions = np.array( - [ - [+0.5, 0.5, 0.0], - [+1.0, 1.0, 0.0], - ] - ) + positions = np.array([[+0.5, 0.5, 0.0], [+1.0, 1.0, 0.0],]) cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) edge_index, shifts, unit_shifts = get_neighborhood( @@ -121,10 +121,7 @@ def test_periodic_edge(): config.positions[receiver] - config.positions[sender] + shifts ) # [n_edges, 3] assert vectors.shape == (12, 3) # 12 neighbors in close-packed bulk - assert np.allclose( - np.linalg.norm(vectors, axis=-1), - dist, - ) + assert np.allclose(np.linalg.norm(vectors, axis=-1), dist,) def test_half_periodic(): @@ -142,7 +139,4 @@ def test_half_periodic(): _, neighbor_count = np.unique(edge_index[0], return_counts=True) assert (neighbor_count == 6).all() # 6 neighbors # Check not periodic in z - assert np.allclose( - vectors[:, 2], - np.zeros(vectors.shape[0]), - ) + assert np.allclose(vectors[:, 2], np.zeros(vectors.shape[0]),) From 511f591afea885a6e433d780d5cd0bab351840d3 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 6 Apr 2023 15:09:50 +0100 Subject: [PATCH 3/3] fix stuff --- mace/data/hdf5_dataset.py | 38 ++++++++++++++++-------------- mace/data/utils.py | 47 ++++++++++++++++++++------------------ scripts/preprocess_data.py | 8 ++++--- tests/test_data.py | 32 +++++++++++++++++++++++--- 4 files changed, 80 insertions(+), 45 deletions(-) diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index 53500bf7..6ceeb7ed 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -85,8 +85,8 @@ class HDF5Dataset(Dataset): def __init__(self, file, r_max, z_table, **kwargs): super(HDF5Dataset, self).__init__() self.file = h5py.File(file, "r") # this is dangerous to open the file here - self.batch_size = len(self.file["config_0"].keys()) - self.length = len(self.file.keys()) * len(self.file["config_0"].keys()) + self.batch_size = len(self.file["config_batch_0"].keys()) + self.length = len(self.file.keys()) * self.batch_size self.r_max = r_max self.z_table = z_table @@ -97,28 +97,32 @@ def __getitem__(self, index): # compute the index of the batch batch_index = index // self.batch_size config_index = index % self.batch_size - grp = 
self.file["config_batch" + str(batch_index)] + grp = self.file["config_batch_" + str(batch_index)] subgrp = grp["config_" + str(config_index)] config = Configuration( atomic_numbers=subgrp["atomic_numbers"][()], positions=subgrp["positions"][()], - energy=subgrp["energy"][()], - forces=subgrp["forces"][()], - stress=subgrp["stress"][()], - virials=subgrp["virials"][()], - dipole=subgrp["dipole"][()], - charges=subgrp["charges"][()], - weight=subgrp["weight"][()], - energy_weight=subgrp["energy_weight"][()], - forces_weight=subgrp["forces_weight"][()], - stress_weight=subgrp["stress_weight"][()], - virials_weight=subgrp["virials_weight"][()], - config_type=subgrp["config_type"][()], - pbc=subgrp["pbc"][()], - cell=subgrp["cell"][()], + energy=unpack_value(subgrp["energy"][()]), + forces=unpack_value(subgrp["forces"][()]), + stress=unpack_value(subgrp["stress"][()]), + virials=unpack_value(subgrp["virials"][()]), + dipole=unpack_value(subgrp["dipole"][()]), + charges=unpack_value(subgrp["charges"][()]), + weight=unpack_value(subgrp["weight"][()]), + energy_weight=unpack_value(subgrp["energy_weight"][()]), + forces_weight=unpack_value(subgrp["forces_weight"][()]), + stress_weight=unpack_value(subgrp["stress_weight"][()]), + virials_weight=unpack_value(subgrp["virials_weight"][()]), + config_type=unpack_value(subgrp["config_type"][()]), + pbc=unpack_value(subgrp["pbc"][()]), + cell=unpack_value(subgrp["cell"][()]), ) atomic_data = data.AtomicData.from_config( config, z_table=self.z_table, cutoff=self.r_max ) return atomic_data + +def unpack_value(value): + value = value.decode("utf-8") if isinstance(value, bytes) else value + return None if str(value) == "None" else value diff --git a/mace/data/utils.py b/mace/data/utils.py index 7e16aaf8..5a8a0c12 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -318,26 +318,29 @@ def save_AtomicData_to_HDF5(data, i, h5_file) -> None: grp["charges"] = data.charges -def save_configurations_as_HDF5(configurations: Configurations, out_name: str) -> None: - with h5py.File(out_name, "w") as f: - grp = f.create_group(f"config_batch{i}") - for i, config in enumerate(configurations): - subgroup_name = f"config_{i}" - subgroup = grp.create_group(subgroup_name) - subgroup["atomic_numbers"] = config.atomic_numbers - subgroup["positions"] = config.positions - subgroup["energy"] = config.energy - subgroup["forces"] = config.forces - subgroup["stress"] = config.stress - subgroup["virials"] = config.virials - subgroup["dipole"] = config.dipole - subgroup["charges"] = config.charges - subgroup["cell"] = config.cell - subgroup["pbc"] = config.pbc - subgroup["weight"] = config.weight - subgroup["energy_weight"] = config.energy_weight - subgroup["forces_weight"] = config.forces_weight - subgroup["stress_weight"] = config.stress_weight - subgroup["virials_weight"] = config.virials_weight - subgroup["config_type"] = config.config_type +def save_configurations_as_HDF5(configurations: Configurations, i, h5_file) -> None: + grp = h5_file.create_group(f"config_batch_{i}") + for i, config in enumerate(configurations): + subgroup_name = f"config_{i}" + subgroup = grp.create_group(subgroup_name) + subgroup["atomic_numbers"] = write_value(config.atomic_numbers) + subgroup["positions"] = write_value(config.positions) + subgroup["energy"] = write_value(config.energy) + subgroup["forces"] = write_value(config.forces) + subgroup["stress"] = write_value(config.stress) + subgroup["virials"] = write_value(config.virials) + subgroup["dipole"] = write_value(config.dipole) + 
subgroup["charges"] = write_value(config.charges) + subgroup["cell"] = write_value(config.cell) + subgroup["pbc"] = write_value(config.pbc) + subgroup["weight"] = write_value(config.weight) + subgroup["energy_weight"] = write_value(config.energy_weight) + subgroup["forces_weight"] = write_value(config.forces_weight) + subgroup["stress_weight"] = write_value(config.stress_weight) + subgroup["virials_weight"] = write_value(config.virials_weight) + subgroup["config_type"] = write_value(config.config_type) + + +def write_value(value): + return value if value is not None else "None" diff --git a/scripts/preprocess_data.py b/scripts/preprocess_data.py index 000afd54..ff14df4b 100644 --- a/scripts/preprocess_data.py +++ b/scripts/preprocess_data.py @@ -94,7 +94,7 @@ def main(): # split collections.train into batches and save them to hdf5 split_train = np.array_split(collections.train, args.batch_size) for i, batch in enumerate(split_train): - save_configurations_as_HDF5(batch, f, f"batch_{i}") + save_configurations_as_HDF5(batch, i, f) if args.compute_statistics: @@ -129,6 +129,8 @@ def main(): "atomic_numbers": str(z_table.zs), "r_max": args.r_max, } + del train_dataset + del train_loader with open(args.h5_prefix + "statistics.json", "w") as f: json.dump(statistics, f) @@ -139,7 +141,7 @@ def main(): with h5py.File(args.h5_prefix + "valid.h5", "w") as f: split_valid = np.array_split(collections.valid, args.batch_size) for i, batch in enumerate(split_valid): - save_configurations_as_HDF5(batch, f, f"batch_{i}") + save_configurations_as_HDF5(batch, i, f) if args.test_file is not None: logging.info("Preparing test sets") @@ -147,7 +149,7 @@ def main(): for name, subset in collections.tests: split_test = np.array_split(subset, args.batch_size) for i, batch in enumerate(split_test): - save_configurations_as_HDF5(batch, f, f"batch_{i}") + save_configurations_as_HDF5(batch, i, f) if __name__ == "__main__": diff --git a/tests/test_data.py b/tests/test_data.py index d254a3d8..fc8a2325 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,5 +1,7 @@ +from copy import deepcopy import ase.build import numpy as np +import torch from mace.data import ( AtomicData, @@ -9,6 +11,7 @@ save_configurations_as_HDF5, HDF5Dataset, ) +import h5py from pathlib import Path from mace.tools import AtomicNumberTable, torch_geometric @@ -22,6 +25,8 @@ class TestAtomicData: forces=np.array([[0.0, -1.3, 0.0], [1.0, 0.2, 0.0], [0.0, 1.1, 0.3],]), energy=-1.5, ) + config_2 = deepcopy(config) + config_2.positions = config.positions + 0.01 table = AtomicNumberTable([1, 8]) @@ -67,16 +72,19 @@ def test_to_atomic_data_dict(self): assert batch_dict["forces"].shape == (6, 3) def test_hdf5_dataloader(self): - datasets = [self.config] * 10 + datasets = [self.config, self.config_2] * 5 # get path of the mace package - save_configurations_as_HDF5(datasets, mace_path + "test.h5") + with h5py.File(str(mace_path) + "test.h5", "w") as f: + save_configurations_as_HDF5(datasets, 0, f) train_dataset = HDF5Dataset( - mace_path + "test.h5", z_table=self.table, r_max=3.0 + str(mace_path) + "test.h5", z_table=self.table, r_max=3.0 ) train_loader = torch_geometric.dataloader.DataLoader( dataset=train_dataset, batch_size=2, shuffle=False, drop_last=False, ) + batch_count = 0 for batch in train_loader: + batch_count += 1 assert batch.batch.shape == (6,) assert batch.edge_index.shape == (2, 8) assert batch.shifts.shape == (8, 3) @@ -84,6 +92,24 @@ def test_hdf5_dataloader(self): assert batch.node_attrs.shape == (6, 2) assert batch.energy.shape 
== (2,) assert batch.forces.shape == (6, 3) + print(batch_count, len(train_loader), len(train_dataset)) + assert batch_count == len(train_loader) == len(train_dataset) / 2 + train_loader_direct = torch_geometric.dataloader.DataLoader( + dataset=[ + AtomicData.from_config(config, z_table=self.table, cutoff=3.0) + for config in datasets + ], + batch_size=2, + shuffle=False, + drop_last=False, + ) + for batch_direct, batch in zip(train_loader_direct, train_loader): + assert torch.all(batch_direct.edge_index == batch.edge_index) + assert torch.all(batch_direct.shifts == batch.shifts) + assert torch.all(batch_direct.positions == batch.positions) + assert torch.all(batch_direct.node_attrs == batch.node_attrs) + assert torch.all(batch_direct.energy == batch.energy) + assert torch.all(batch_direct.forces == batch.forces) class TestNeighborhood:
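
Usage sketch for the new pipeline (illustrative only: the file name, element
list, and argument values below are assumptions, not part of the patches
above). Once preprocess_data.py has written the batched HDF5 file, training
code can build graphs on the fly through HDF5Dataset:

    from mace.data import HDF5Dataset
    from mace.tools import AtomicNumberTable, torch_geometric

    z_table = AtomicNumberTable([1, 8])  # elements present in the dataset
    # each __getitem__ reads one stored Configuration and constructs the
    # neighbour-list graph with the requested cutoff
    dataset = HDF5Dataset("processed_train.h5", r_max=5.0, z_table=z_table)
    loader = torch_geometric.dataloader.DataLoader(
        dataset=dataset, batch_size=16, shuffle=True, drop_last=True
    )
    for batch in loader:
        pass  # each batch is a collated AtomicData graph

The design trade-off: save_AtomicData_to_HDF5 stores precomputed edge_index
and shifts for one fixed cutoff, while save_configurations_as_HDF5 stores raw
Configurations, so the same .h5 file can be reused with a different r_max at
the cost of rebuilding neighbour lists during loading.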