Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve more on new dataset API #434

Merged
Merged 11 commits into the base branch on Mar 22, 2020
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ dist
*.qdstrm
*.zip
Untitled.ipynb
/nnp_training.py
/test*.py
4 changes: 1 addition & 3 deletions examples/nnp_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560

dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
size = len(dataset)
training, validation = dataset.split(int(0.8 * size), None)
training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies)
Expand Down
4 changes: 1 addition & 3 deletions examples/nnp_training_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@

batch_size = 2560

dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
size = len(dataset)
training, validation = dataset.split(int(0.8 * size), None)
training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()

Expand Down
84 changes: 83 additions & 1 deletion tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import torch
import torchani
import unittest

path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, 'dataset/ani-1x/sample.h5')
dataset_path = os.path.join(path, '../dataset/ani-1x/sample.h5')
batch_size = 256
ani1x = torchani.models.ANI1x()
consts = ani1x.consts
Expand Down Expand Up @@ -34,6 +35,87 @@ def testNoUnnecessaryPadding(self):
non_padding = (species >= 0)[:, -1].nonzero()
self.assertGreater(non_padding.numel(), 0)

def testReEnter(self):
    """Every dataset transform stage must be iterable more than once.

    Regression test: if a stage returned a one-shot iterator, a second
    epoch over the data would be silently empty. Each stage of the
    pipeline (load -> subtract_self_energies -> species_to_indices ->
    shuffle -> collate -> cache) is checked in turn.
    """
    def assert_reenterable(ds):
        # Two full passes over `ds`; both must yield at least one item.
        first = False
        for _ in ds:
            first = True
        self.assertTrue(first)
        second = False
        for _ in ds:
            second = True
        self.assertTrue(second)

    ds = torchani.data.load(dataset_path)
    assert_reenterable(ds)

    ds = ds.subtract_self_energies(sae_dict)
    assert_reenterable(ds)

    ds = ds.species_to_indices()
    assert_reenterable(ds)

    ds = ds.shuffle()
    assert_reenterable(ds)

    ds = ds.collate(batch_size)
    assert_reenterable(ds)

    ds = ds.cache()
    assert_reenterable(ds)

def testShapeInference(self):
    """len() must be computable at every transform stage.

    The original test called ``len(ds)`` and discarded the result, so it
    only verified that no exception was raised; asserting a positive
    length also verifies the inferred size is sensible for a non-empty
    sample dataset.
    """
    shifter = torchani.EnergyShifter(None)
    ds = torchani.data.load(dataset_path).subtract_self_energies(shifter)
    self.assertGreater(len(ds), 0)
    ds = ds.species_to_indices()
    self.assertGreater(len(ds), 0)
    ds = ds.shuffle()
    self.assertGreater(len(ds), 0)
    ds = ds.collate(batch_size)
    self.assertGreater(len(ds), 0)

def testDataloader(self):
    """A materialized conformer list must work with torch DataLoader.

    Exercises ``torchani.data.collate_fn`` as the DataLoader's
    ``collate_fn`` and asserts at least one batch is produced.
    """
    shifter = torchani.EnergyShifter(None)
    dataset = list(
        torchani.data.load(dataset_path)
        .subtract_self_energies(shifter)
        .species_to_indices()
        .shuffle()
    )
    # Cap workers at the machine's CPU count: a hard-coded 64 workers
    # oversubscribes small CI machines and can exhaust OS limits on
    # open file descriptors / shared memory.
    num_workers = min(64, os.cpu_count() or 1)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=torchani.data.collate_fn,
        num_workers=num_workers,
    )
    entered = False
    for _ in loader:
        entered = True
    self.assertTrue(entered)


# Allow running this test module directly: `python tests/test_data.py`.
if __name__ == '__main__':
    unittest.main()
Loading