Skip to content

Commit

Permalink
Reduce user inputs for training custom models (#460)
Browse files Browse the repository at this point in the history
* Reduce user inputs for training custom models

* Update nnp_training.py
  • Loading branch information
farhadrgh committed Apr 24, 2020
1 parent 5eb73cd commit b12e7e6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions examples/nnp_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
num_species = 4
species_order = ['H', 'C', 'N', 'O']
num_species = len(species_order)
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])

###############################################################################
# Now let's setup datasets. These paths assumes the user run this script under
Expand All @@ -82,7 +82,7 @@
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
batch_size = 2560

training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices(species_order).shuffle().split(0.8, None)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies)
Expand Down
6 changes: 3 additions & 3 deletions examples/nnp_training_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
num_species = 4
species_order = ['H', 'C', 'N', 'O']
num_species = len(species_order)
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])


try:
Expand All @@ -49,7 +49,7 @@

batch_size = 2560

training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices(species_order).shuffle().split(0.8, None)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()

Expand Down

0 comments on commit b12e7e6

Please sign in to comment.