Skip to content

Commit

Permalink
Merge pull request #73 from davkovacs/on_the_fly_dataloading
Browse files Browse the repository at this point in the history
on the fly data loading
  • Loading branch information
ilyes319 authored Feb 15, 2023
2 parents e8b5bac + 652a96e commit d619d9d
Show file tree
Hide file tree
Showing 9 changed files with 772 additions and 159 deletions.
49 changes: 48 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ python ./mace/scripts/run_train.py \

To give a specific validation set, use the argument `--valid_file`. To set a larger batch size for evaluating the validation set, specify `--valid_batch_size`.

To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`.
To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`. It is also possible to specify the model using the `--num_channels=128` and `--max_L=1`keys.

It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"` which estimates the atomic energies using least squares regression.

Expand Down Expand Up @@ -105,6 +105,53 @@ python3 ./mace/scripts/eval_configs.py \

You can run our [Colab tutorial](https://colab.research.google.com/drive/1D6EtMUjQPey_GkuxUAbPgld6_9ibIa-V?authuser=1#scrollTo=Z10787RE1N8T) to quickly get started with MACE.

## On-line data loading for large datasets

If you have a large dataset that might not fit into the GPU memory it is recommended to preprocess the data on a CPU and use on-line dataloading for training the model. To preprocess your dataset specified as an xyz file run the `preprocess_data.py` script. An example is given here:

```sh
mkdir processed_data
python ./mace/scripts/preprocess_data.py \
--train_file="/path/to/train_large.xyz" \
--valid_fraction=0.05 \
--test_file="/path/to/test_large.xyz" \
--atomic_numbers="[1, 6, 7, 8, 9, 15, 16, 17, 35, 53]" \
--r_max=4.5 \
--h5_prefix="processed_data/" \
--compute_statistics \
--E0s="average" \
--num_workers=32 \
--seed=123 \
```

To see all options and a little description of them run `python ./mace/scripts/preprocess_data.py --help` . The script will create a number of HDF5 files in the `processed_data` folder which can be used for training. There wiull be one file for trainin, one for validation and a separate one for each `config_type` in the test set. To train the model use the `run_train.py` script as follows:

```sh
python ./mace/scripts/run_train.py \
--name="MACE_on_big_data" \
--num_workers=16 \
--train_file="./processed_data/train.h5" \
--valid_file="./processed_data/valid.h5" \
--test_dir="./processed_data" \
--statistics_file="./processed_data/statistics.json" \
--model="ScaleShiftMACE" \
--num_interactions=2 \
--num_channels=128 \
--max_L=1 \
--correlation=3 \
--batch_size=32 \
--valid_batch_size=32 \
--max_num_epochs=100 \
--swa \
--start_swa=60 \
--ema \
--ema_decay=0.99 \
--amsgrad \
--error_table='PerAtomMAE' \
--device=cuda \
--seed=123 \
```

## Weights and Biases for experiment tracking

If you would like to use MACE with Weights and Biases to log your experiments simply install with
Expand Down
4 changes: 4 additions & 0 deletions mace/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
load_from_xyz,
random_train_valid_split,
test_config_types,
save_dataset_as_HDF5,
)
from .hdf5_dataset import HDF5Dataset

__all__ = [
"get_neighborhood",
Expand All @@ -22,4 +24,6 @@
"config_from_atoms_list",
"AtomicData",
"compute_average_E0s",
"save_dataset_as_HDF5",
"HDF5Dataset",
]
59 changes: 59 additions & 0 deletions mace/data/hdf5_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import h5py
import torch
from torch.utils.data import Dataset
from mace.data import AtomicData

class HDF5Dataset(Dataset):
def __init__(self, file, **kwargs):
super(HDF5Dataset, self).__init__()
# it might be dangerous to open the file here
# move opening of file to __getitem__?
self.file = h5py.File(file, 'r')
self.length = len(self.file.keys())
# self.file = file
# self.length = len(h5py.File(file, 'r').keys())

def __len__(self):
return self.length

def __getitem__(self, index):
# file = h5py.File(self.file, 'r')
# grp = file["config_" + str(index)]
grp = self.file["config_" + str(index)]
edge_index = grp['edge_index'][()]
positions = grp['positions'][()]
shifts = grp['shifts'][()]
unit_shifts = grp['unit_shifts'][()]
cell = grp['cell'][()]
node_attrs = grp['node_attrs'][()]
weight = grp['weight'][()]
energy_weight = grp['energy_weight'][()]
forces_weight = grp['forces_weight'][()]
stress_weight = grp['stress_weight'][()]
virials_weight = grp['virials_weight'][()]
forces = grp['forces'][()]
energy = grp['energy'][()]
stress = grp['stress'][()]
virials = grp['virials'][()]
dipole = grp['dipole'][()]
charges = grp['charges'][()]
return AtomicData(
edge_index = torch.tensor(edge_index, dtype=torch.long),
positions = torch.tensor(positions, dtype=torch.get_default_dtype()),
shifts = torch.tensor(shifts, dtype=torch.get_default_dtype()),
unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()),
cell=torch.tensor(cell, dtype=torch.get_default_dtype()),
node_attrs=torch.tensor(node_attrs, dtype=torch.get_default_dtype()),
weight=torch.tensor(weight, dtype=torch.get_default_dtype()),
energy_weight=torch.tensor(energy_weight, dtype=torch.get_default_dtype()),
forces_weight=torch.tensor(forces_weight, dtype=torch.get_default_dtype()),
stress_weight=torch.tensor(stress_weight, dtype=torch.get_default_dtype()),
virials_weight=torch.tensor(virials_weight, dtype=torch.get_default_dtype()),
forces=torch.tensor(forces, dtype=torch.get_default_dtype()),
energy=torch.tensor(energy, dtype=torch.get_default_dtype()),
stress=torch.tensor(stress, dtype=torch.get_default_dtype()),
virials=torch.tensor(virials, dtype=torch.get_default_dtype()),
dipole=torch.tensor(dipole, dtype=torch.get_default_dtype()),
charges=torch.tensor(charges, dtype=torch.get_default_dtype()),
)

30 changes: 30 additions & 0 deletions mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
###########################################################################################

import logging
import h5py
from multiprocessing import Pool
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

Expand Down Expand Up @@ -148,6 +150,9 @@ def config_from_atoms(
if virials is None:
virials = np.zeros((3, 3))
virials_weight = 0.0
if dipole is None:
dipole = np.zeros(3)
#dipoles_weight = 0.0

return Configuration(
atomic_numbers=atomic_numbers,
Expand Down Expand Up @@ -265,3 +270,28 @@ def compute_average_E0s(
for i, z in enumerate(z_table.zs):
atomic_energies_dict[z] = 0.0
return atomic_energies_dict

def save_dataset_as_HDF5(
dataset:List, out_name: str
) -> None:
with h5py.File(out_name, 'w') as f:
for i, data in enumerate(dataset):
grp = f.create_group(f'config_{i}')
grp["num_nodes"] = data.num_nodes
grp["edge_index"] = data.edge_index
grp["positions"] = data.positions
grp["shifts"] = data.shifts
grp["unit_shifts"] = data.unit_shifts
grp["cell"] = data.cell
grp["node_attrs"] = data.node_attrs
grp["weight"] = data.weight
grp["energy_weight"] = data.energy_weight
grp["forces_weight"] = data.forces_weight
grp["stress_weight"] = data.stress_weight
grp["virials_weight"] = data.virials_weight
grp["forces"] = data.forces
grp["energy"] = data.energy
grp["stress"] = data.stress
grp["virials"] = data.virials
grp["dipole"] = data.dipole
grp["charges"] = data.charges
3 changes: 2 additions & 1 deletion mace/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .arg_parser import build_default_arg_parser
from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser
from .cg import U_matrix_real
from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState
from .torch_tools import (
Expand Down Expand Up @@ -64,4 +64,5 @@
"cartesian_to_spherical",
"voigt_to_matrix",
"init_wandb",
"build_preprocess_arg_parser"
]
Loading

0 comments on commit d619d9d

Please sign in to comment.