Multi gpu #99

Merged: 3 commits, Apr 6, 2023
6 changes: 4 additions & 2 deletions mace/data/__init__.py
@@ -8,9 +8,10 @@
config_from_atoms_list,
load_from_xyz,
random_train_valid_split,
save_configurations_as_HDF5,
test_config_types,
save_dataset_as_HDF5,
save_AtomicData_to_HDF5
save_AtomicData_to_HDF5,
)
from .hdf5_dataset import HDF5Dataset

@@ -27,5 +28,6 @@
"compute_average_E0s",
"save_dataset_as_HDF5",
"HDF5Dataset",
"save_AtomicData_to_HDF5"
"save_AtomicData_to_HDF5",
"save_configurations_as_HDF5",
]
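
For reference, a minimal import sketch (assuming the package layout above) showing that the new batch writer is re-exported at package level:

```python
# The new helper is listed in __all__, so it can be imported directly
# from mace.data alongside the existing HDF5 utilities.
from mace.data import HDF5Dataset, save_configurations_as_HDF5
```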
157 changes: 113 additions & 44 deletions mace/data/hdf5_dataset.py
@@ -1,59 +1,128 @@
import h5py
import torch
from torch.utils.data import Dataset
from mace import data
from mace.data import AtomicData
from mace.data.utils import Configuration

class HDF5Dataset(Dataset):
def __init__(self, file, **kwargs):
super(HDF5Dataset, self).__init__()
import h5py
import torch
from torch.utils.data import Dataset, IterableDataset, ChainDataset
from mace import data
from mace.data import AtomicData
from mace.data.utils import Configuration


class HDF5ChainDataset(ChainDataset):
def __init__(self, file, r_max, z_table, **kwargs):
super(HDF5ChainDataset, self).__init__()
self.file = file
self.length = len(h5py.File(file, "r").keys())
self.r_max = r_max
self.z_table = z_table

def __call__(self):
self.file = h5py.File(self.file, "r")
datasets = []
for i in range(self.length):
grp = self.file["config_" + str(i)]
datasets.append(
HDF5IterDataset(iter_group=grp, r_max=self.r_max, z_table=self.z_table,)
)
return ChainDataset(datasets)


class HDF5IterDataset(IterableDataset):
def __init__(self, iter_group, r_max, z_table, **kwargs):
super(HDF5IterDataset, self).__init__()
# it might be dangerous to open the file here
# move opening of file to __getitem__?
self.file = h5py.File(file, 'r')
self.length = len(self.file.keys())
self.iter_group = iter_group
self.length = len(self.iter_group.keys())
self.r_max = r_max
self.z_table = z_table
# self.file = file
# self.length = len(h5py.File(file, 'r').keys())

def __len__(self):
return self.length

def __iter__(self):
# file = h5py.File(self.file, 'r')
# grp = file["config_" + str(index)]
grp = self.iter_group
len_subgrp = len(grp.keys())
grp_list = []
for i in range(len_subgrp):
subgrp = grp["config_" + str(i)]
config = Configuration(
atomic_numbers=subgrp["atomic_numbers"][()],
positions=subgrp["positions"][()],
energy=subgrp["energy"][()],
forces=subgrp["forces"][()],
stress=subgrp["stress"][()],
virials=subgrp["virials"][()],
dipole=subgrp["dipole"][()],
charges=subgrp["charges"][()],
weight=subgrp["weight"][()],
energy_weight=subgrp["energy_weight"][()],
forces_weight=subgrp["forces_weight"][()],
stress_weight=subgrp["stress_weight"][()],
virials_weight=subgrp["virials_weight"][()],
config_type=subgrp["config_type"][()],
pbc=subgrp["pbc"][()],
cell=subgrp["cell"][()],
)
atomic_data = data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max
)
grp_list.append(atomic_data)

return iter(grp_list)
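
# Editor's usage sketch (not part of the diff): how the chained iterable
# path above appears intended to be used. Names mirror the classes in this
# file; the z_table value is a hypothetical example.
#
#   chain = HDF5ChainDataset("train.h5", r_max=5.0, z_table=z_table)
#   dataset = chain()            # torch ChainDataset over HDF5IterDataset
#   for atomic_data in dataset:  # AtomicData graphs, streamed group by group
#       ...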


class HDF5Dataset(Dataset):
def __init__(self, file, r_max, z_table, **kwargs):
super(HDF5Dataset, self).__init__()
self.file = h5py.File(file, "r") # this is dangerous to open the file here
self.batch_size = len(self.file["config_batch_0"].keys())
self.length = len(self.file.keys()) * self.batch_size
self.r_max = r_max
self.z_table = z_table

def __len__(self):
return self.length

def __getitem__(self, index):
# file = h5py.File(self.file, 'r')
# grp = file["config_" + str(index)]
grp = self.file["config_" + str(index)]
edge_index = grp['edge_index'][()]
positions = grp['positions'][()]
shifts = grp['shifts'][()]
unit_shifts = grp['unit_shifts'][()]
cell = grp['cell'][()]
node_attrs = grp['node_attrs'][()]
weight = grp['weight'][()]
energy_weight = grp['energy_weight'][()]
forces_weight = grp['forces_weight'][()]
stress_weight = grp['stress_weight'][()]
virials_weight = grp['virials_weight'][()]
forces = grp['forces'][()]
energy = grp['energy'][()]
stress = grp['stress'][()]
virials = grp['virials'][()]
dipole = grp['dipole'][()]
charges = grp['charges'][()]
return AtomicData(
edge_index = torch.tensor(edge_index, dtype=torch.long),
positions = torch.tensor(positions, dtype=torch.get_default_dtype()),
shifts = torch.tensor(shifts, dtype=torch.get_default_dtype()),
unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()),
cell=torch.tensor(cell, dtype=torch.get_default_dtype()),
node_attrs=torch.tensor(node_attrs, dtype=torch.get_default_dtype()),
weight=torch.tensor(weight, dtype=torch.get_default_dtype()),
energy_weight=torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
forces_weight=torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
stress_weight=torch.tensor(stress_weight, dtype=torch.get_default_dtype()),
virials_weight=torch.tensor(virials_weight, dtype=torch.get_default_dtype()),
forces=torch.tensor(forces, dtype=torch.get_default_dtype()),
energy=torch.tensor(energy, dtype=torch.get_default_dtype()),
stress=torch.tensor(stress, dtype=torch.get_default_dtype()),
virials=torch.tensor(virials, dtype=torch.get_default_dtype()),
dipole=torch.tensor(dipole, dtype=torch.get_default_dtype()),
charges=torch.tensor(charges, dtype=torch.get_default_dtype()),
# compute the index of the batch
batch_index = index // self.batch_size
config_index = index % self.batch_size
grp = self.file["config_batch_" + str(batch_index)]
subgrp = grp["config_" + str(config_index)]
config = Configuration(
atomic_numbers=subgrp["atomic_numbers"][()],
positions=subgrp["positions"][()],
energy=unpack_value(subgrp["energy"][()]),
forces=unpack_value(subgrp["forces"][()]),
stress=unpack_value(subgrp["stress"][()]),
virials=unpack_value(subgrp["virials"][()]),
dipole=unpack_value(subgrp["dipole"][()]),
charges=unpack_value(subgrp["charges"][()]),
weight=unpack_value(subgrp["weight"][()]),
energy_weight=unpack_value(subgrp["energy_weight"][()]),
forces_weight=unpack_value(subgrp["forces_weight"][()]),
stress_weight=unpack_value(subgrp["stress_weight"][()]),
virials_weight=unpack_value(subgrp["virials_weight"][()]),
config_type=unpack_value(subgrp["config_type"][()]),
pbc=unpack_value(subgrp["pbc"][()]),
cell=unpack_value(subgrp["cell"][()]),
)
atomic_data = data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max
)

return atomic_data


def unpack_value(value):
value = value.decode("utf-8") if isinstance(value, bytes) else value
return None if str(value) == "None" else value
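
As a quick orientation for reviewers, here is a minimal consumption sketch for the reworked map-style dataset. It assumes a train.h5 produced by save_configurations_as_HDF5 (groups config_batch_<i> containing subgroups config_<j>) and a hypothetical two-element z_table:

```python
# Read one configuration back as an AtomicData graph.
# __getitem__ maps a flat index to (batch, config) as:
#   batch_index  = index // batch_size   -> group    config_batch_<i>
#   config_index = index %  batch_size   -> subgroup config_<j>
from mace import data, tools

z_table = tools.AtomicNumberTable([1, 8])  # hypothetical element set
dataset = data.HDF5Dataset("train.h5", r_max=5.0, z_table=z_table)
first = dataset[0]  # Configuration -> AtomicData via from_config

# Note: len(dataset) multiplies the group count by the size of
# config_batch_0, so it assumes every batch group is full; the last
# group written by np.array_split may be smaller.
```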
86 changes: 56 additions & 30 deletions mace/data/utils.py
@@ -152,7 +152,7 @@ def config_from_atoms(
virials_weight = 0.0
if dipole is None:
dipole = np.zeros(3)
#dipoles_weight = 0.0
# dipoles_weight = 0.0

return Configuration(
atomic_numbers=atomic_numbers,
@@ -271,35 +271,33 @@ def compute_average_E0s(
atomic_energies_dict[z] = 0.0
return atomic_energies_dict

def save_dataset_as_HDF5(
dataset:List, out_name: str
) -> None:
with h5py.File(out_name, 'w') as f:
for i, data in enumerate(dataset):
grp = f.create_group(f'config_{i}')
grp["num_nodes"] = data.num_nodes
grp["edge_index"] = data.edge_index
grp["positions"] = data.positions
grp["shifts"] = data.shifts
grp["unit_shifts"] = data.unit_shifts
grp["cell"] = data.cell
grp["node_attrs"] = data.node_attrs
grp["weight"] = data.weight
grp["energy_weight"] = data.energy_weight
grp["forces_weight"] = data.forces_weight
grp["stress_weight"] = data.stress_weight
grp["virials_weight"] = data.virials_weight
grp["forces"] = data.forces
grp["energy"] = data.energy
grp["stress"] = data.stress
grp["virials"] = data.virials
grp["dipole"] = data.dipole
grp["charges"] = data.charges

def save_AtomicData_to_HDF5(
data, i, h5_file
) -> None:
grp = h5_file.create_group(f'config_{i}')

def save_dataset_as_HDF5(dataset: List, out_name: str) -> None:
with h5py.File(out_name, "w") as f:
for i, data in enumerate(dataset):
grp = f.create_group(f"config_{i}")
grp["num_nodes"] = data.num_nodes
grp["edge_index"] = data.edge_index
grp["positions"] = data.positions
grp["shifts"] = data.shifts
grp["unit_shifts"] = data.unit_shifts
grp["cell"] = data.cell
grp["node_attrs"] = data.node_attrs
grp["weight"] = data.weight
grp["energy_weight"] = data.energy_weight
grp["forces_weight"] = data.forces_weight
grp["stress_weight"] = data.stress_weight
grp["virials_weight"] = data.virials_weight
grp["forces"] = data.forces
grp["energy"] = data.energy
grp["stress"] = data.stress
grp["virials"] = data.virials
grp["dipole"] = data.dipole
grp["charges"] = data.charges


def save_AtomicData_to_HDF5(data, i, h5_file) -> None:
grp = h5_file.create_group(f"config_{i}")
grp["num_nodes"] = data.num_nodes
grp["edge_index"] = data.edge_index
grp["positions"] = data.positions
@@ -318,3 +316,31 @@ def save_AtomicData_to_HDF5(
grp["virials"] = data.virials
grp["dipole"] = data.dipole
grp["charges"] = data.charges


def save_configurations_as_HDF5(configurations: Configurations, i, h5_file) -> None:
grp = h5_file.create_group(f"config_batch_{i}")
for i, config in enumerate(configurations):
subgroup_name = f"config_{i}"
subgroup = grp.create_group(subgroup_name)
subgroup["atomic_numbers"] = write_value(config.atomic_numbers)
subgroup["positions"] = write_value(config.positions)
subgroup["energy"] = write_value(config.energy)
subgroup["forces"] = write_value(config.forces)
subgroup["stress"] = write_value(config.stress)
subgroup["virials"] = write_value(config.virials)
subgroup["dipole"] = write_value(config.dipole)
subgroup["charges"] = write_value(config.charges)
subgroup["cell"] = write_value(config.cell)
subgroup["pbc"] = write_value(config.pbc)
subgroup["weight"] = write_value(config.weight)
subgroup["energy_weight"] = write_value(config.energy_weight)
subgroup["forces_weight"] = write_value(config.forces_weight)
subgroup["stress_weight"] = write_value(config.stress_weight)
subgroup["virials_weight"] = write_value(config.virials_weight)
subgroup["config_type"] = write_value(config.config_type)


def write_value(value):
return value if value is not None else "None"
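
One detail worth noting: HDF5 has no native null type, so write_value serializes missing fields as the string "None", and unpack_value (in hdf5_dataset.py above) reverses this on read. A small round-trip sketch, assuming both helpers are importable from their modules as defined in this diff:

```python
import h5py

from mace.data.hdf5_dataset import unpack_value
from mace.data.utils import write_value

with h5py.File("roundtrip.h5", "w") as f:
    grp = f.create_group("config_0")
    grp["energy"] = write_value(None)  # stored as the string "None"
    grp["weight"] = write_value(1.0)   # stored as a float scalar

with h5py.File("roundtrip.h5", "r") as f:
    grp = f["config_0"]
    energy = unpack_value(grp["energy"][()])  # b"None" -> None
    weight = unpack_value(grp["weight"][()])  # 1.0 unchanged

assert energy is None and weight == 1.0
```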

33 changes: 17 additions & 16 deletions scripts/preprocess_data.py
@@ -12,7 +12,7 @@
import torch

from mace import tools, data
from mace.data.utils import save_AtomicData_to_HDF5 #, save_dataset_as_HDF5
from mace.data.utils import save_AtomicData_to_HDF5, save_configurations_as_HDF5 #, save_dataset_as_HDF5
from mace.tools.scripts_utils import (get_dataset_from_xyz,
get_atomic_energies)
from mace.tools import torch_geometric
@@ -91,10 +91,10 @@ def main():
random.shuffle(collections.train)

with h5py.File(args.h5_prefix + "train.h5", "w") as f:
for i, config in enumerate(collections.train):
atomic_data = data.AtomicData.from_config(
config, z_table=z_table, cutoff=args.r_max)
save_AtomicData_to_HDF5(atomic_data, i, f)
# split collections.train into batches and save them to hdf5
split_train = np.array_split(collections.train, args.batch_size)
for i, batch in enumerate(split_train):
save_configurations_as_HDF5(batch, i, f)


if args.compute_statistics:
@@ -106,8 +106,9 @@
[atomic_energies_dict[z] for z in z_table.zs]
)
logging.info(f"Atomic energies: {atomic_energies.tolist()}")
train_dataset = data.HDF5Dataset(args.h5_prefix + "train.h5", z_table=z_table, r_max=args.r_max)
train_loader = torch_geometric.dataloader.DataLoader(
data.HDF5Dataset(args.h5_prefix + "train.h5"),
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
@@ -128,6 +129,8 @@
"atomic_numbers": str(z_table.zs),
"r_max": args.r_max,
}
del train_dataset
del train_loader
with open(args.h5_prefix + "statistics.json", "w") as f:
json.dump(statistics, f)

@@ -136,19 +139,17 @@
random.shuffle(collections.valid)

with h5py.File(args.h5_prefix + "valid.h5", "w") as f:
for i, config in enumerate(collections.valid):
atomic_data = data.AtomicData.from_config(
config, z_table=z_table, cutoff=args.r_max)
save_AtomicData_to_HDF5(atomic_data, i, f)
split_valid = np.array_split(collections.valid, args.batch_size)
for i, batch in enumerate(split_valid):
save_configurations_as_HDF5(batch, i, f)

if args.test_file is not None:
logging.info("Preparing test sets")
for name, subset in collections.tests:
with h5py.File(args.h5_prefix + name + "_test.h5", "w") as f:
for i, config in enumerate(subset):
atomic_data = data.AtomicData.from_config(
config, z_table=z_table, cutoff=args.r_max)
save_AtomicData_to_HDF5(atomic_data, i, f)
with h5py.File(args.h5_prefix + name + "_test.h5", "w") as f:
for name, subset in collections.tests:
split_test = np.array_split(subset, args.batch_size)
for i, batch in enumerate(split_test):
save_configurations_as_HDF5(batch, i, f)


if __name__ == "__main__":
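
Finally, a note on the batching semantics in preprocess_data.py: np.array_split(seq, n) splits a sequence into n nearly equal parts, so args.batch_size here controls the number of config_batch_<i> groups written to each h5 file, not the size of each group. A small sketch of that behavior:

```python
import numpy as np

configs = list(range(10))             # stand-in for collections.train
batches = np.array_split(configs, 3)  # 3 groups of sizes 4, 3, 3
for i, batch in enumerate(batches):
    print(f"config_batch_{i}: {len(batch)} configs")
```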