Allow easily loading individual models (#476)
* Allow easily loading individual models

* Hopefully make the tests a tiny bit faster by not loading the whole network each time
IgnacioJPickering committed Jun 3, 2020
1 parent 3a043d4 commit a5bad5c
Showing 5 changed files with 63 additions and 35 deletions.
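In short, the builtin model constructors gain an optional model_index argument: leaving it unset still returns the full 8-network ensemble, while passing an index loads only that ensemble member. A minimal sketch of the call patterns, based on the updated tests below (the variable names are illustrative):

    import torchani

    # Old pattern from the tests: build the full ensemble, then index one member.
    single_from_ensemble = torchani.models.ANI1x()[0]

    # New in this commit: load only the requested member in the first place.
    single = torchani.models.ANI1x(model_index=0)

    # Works the same for the ASE calculator used in the structure optimization test.
    calculator = torchani.models.ANI1x(model_index=0).ase()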
tests/test_ase.py (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@ def get_numeric_force(atoms, eps):
 class TestASE(unittest.TestCase):
 
     def setUp(self):
-        self.model = torchani.models.ANI1x().double()[0]
+        self.model = torchani.models.ANI1x(model_index=0).double()
 
     def testWithNumericalForceWithPBCEnabled(self):
         atoms = Diamond(symbol="C", pbc=True)
tests/test_energies.py (8 changes: 4 additions & 4 deletions)
@@ -13,10 +13,10 @@ class TestEnergies(unittest.TestCase):
 
     def setUp(self):
         self.tolerance = 5e-5
-        ani1x = torchani.models.ANI1x()
-        self.aev_computer = ani1x.aev_computer
-        self.nnp = ani1x.neural_networks[0]
-        self.energy_shifter = ani1x.energy_shifter
+        model = torchani.models.ANI1x(model_index=0)
+        self.aev_computer = model.aev_computer
+        self.nnp = model.neural_networks
+        self.energy_shifter = model.energy_shifter
         self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
         self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
 
tests/test_forces.py (6 changes: 3 additions & 3 deletions)
@@ -12,9 +12,9 @@ class TestForce(unittest.TestCase):
 
     def setUp(self):
         self.tolerance = 1e-5
-        ani1x = torchani.models.ANI1x()
-        self.aev_computer = ani1x.aev_computer
-        self.nnp = ani1x.neural_networks[0]
+        model = torchani.models.ANI1x(model_index=0)
+        self.aev_computer = model.aev_computer
+        self.nnp = model.neural_networks
         self.model = torchani.nn.Sequential(self.aev_computer, self.nnp)
 
     def random_skip(self):
tests/test_structure_optim.py (3 changes: 1 addition & 2 deletions)
@@ -15,8 +15,7 @@ class TestStructureOptimization(unittest.TestCase):
 
     def setUp(self):
         self.tolerance = 1e-6
-        self.ani1x = torchani.models.ANI1x()
-        self.calculator = self.ani1x[0].ase()
+        self.calculator = torchani.models.ANI1x(model_index=0).ase()
 
     def testRMSE(self):
         datafile = os.path.join(path, 'test_data/NeuroChemOptimized/all')
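As the updated tests above show, a model loaded with model_index exposes its aev_computer, neural_networks and energy_shifter submodules directly, so a pipeline can be rebuilt with torchani.nn.Sequential without first constructing and indexing an ensemble. A small sketch of that pattern, assuming the usual TorchANI (species, coordinates) calling convention; the molecule and species indices are illustrative and not taken from the tests:

    import torch
    import torchani

    model = torchani.models.ANI1x(model_index=0)

    # Recombine the single model's components, as tests/test_energies.py now does.
    pipeline = torchani.nn.Sequential(
        model.aev_computer,
        model.neural_networks,
        model.energy_shifter,
    )

    # Illustrative methane-like input: internal species indices (H=0, C=1 for ANI-1x)
    # and random coordinates; a real calculation would use an actual geometry.
    species = torch.tensor([[1, 0, 0, 0, 0]])
    coordinates = torch.randn(1, 5, 3)
    _, energies = pipeline((species, coordinates))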
torchani/models.py (79 changes: 54 additions & 25 deletions)
@@ -26,7 +26,7 @@
 Note that the class BuiltinModels can be accessed but it is deprecated and
 shouldn't be used anymore.
 """
-
+import os
 import torch
 from torch import Tensor
 from typing import Tuple, Optional
@@ -53,6 +53,45 @@ def __init__(self, species_converter, aev_computer, neural_networks, energy_shif
         self.consts = consts
         self.sae_dict = sae_dict
 
+    @classmethod
+    def _from_neurochem_resources(cls, info_file_path, periodic_table_index=False, model_index=0):
+        # this is used to load only 1 model (by default model 0)
+        consts, sae_file, ensemble_prefix, ensemble_size = cls._parse_neurochem_resources(info_file_path)
+        if (model_index >= ensemble_size):
+            raise ValueError("The ensemble size is only {}, model {} can't be loaded".format(ensemble_size, model_index))
+
+        species_converter = SpeciesConverter(consts.species)
+        aev_computer = AEVComputer(**consts)
+        energy_shifter, sae_dict = neurochem.load_sae(sae_file, return_dict=True)
+        species_to_tensor = consts.species_to_tensor
+
+        network_dir = os.path.join('{}{}'.format(ensemble_prefix, model_index), 'networks')
+        neural_networks = neurochem.load_model(consts.species, network_dir)
+
+        return cls(species_converter, aev_computer, neural_networks,
+                   energy_shifter, species_to_tensor, consts, sae_dict, periodic_table_index)
+
+    @staticmethod
+    def _parse_neurochem_resources(info_file_path):
+        def get_resource(file_path):
+            package_name = '.'.join(__name__.split('.')[:-1])
+            return resource_filename(package_name, 'resources/' + file_path)
+
+        info_file = get_resource(info_file_path)
+
+        with open(info_file) as f:
+            # const_file: Path to the file with the builtin constants.
+            # sae_file: Path to the file with the Self Atomic Energies.
+            # ensemble_prefix: Prefix of the neurochem resource directories.
+            lines = [x.strip() for x in f.readlines()][:4]
+        const_file_path, sae_file_path, ensemble_prefix_path, ensemble_size = lines
+        const_file = get_resource(const_file_path)
+        sae_file = get_resource(sae_file_path)
+        ensemble_prefix = get_resource(ensemble_prefix_path)
+        ensemble_size = int(ensemble_size)
+        consts = neurochem.Constants(const_file)
+        return consts, sae_file, ensemble_prefix, ensemble_size
+
     def forward(self, species_coordinates: Tuple[Tensor, Tensor],
                 cell: Optional[Tensor] = None,
                 pbc: Optional[Tensor] = None) -> SpeciesEnergies:
@@ -159,31 +198,15 @@ def __init__(self, species_converter, aev_computer, neural_networks,
 
     @classmethod
     def _from_neurochem_resources(cls, info_file_path, periodic_table_index=False):
-
-        def get_resource(file_path):
-            package_name = '.'.join(__name__.split('.')[:-1])
-            return resource_filename(package_name, 'resources/' + file_path)
-
-        info_file = get_resource(info_file_path)
-
-        with open(info_file) as f:
-            # const_file: Path to the file with the builtin constants.
-            # sae_file: Path to the file with the Self Atomic Energies.
-            # ensemble_prefix: Prefix of the neurochem resource directories.
-            lines = [x.strip() for x in f.readlines()][:4]
-        const_file_path, sae_file_path, ensemble_prefix_path, ensemble_size = lines
-        const_file = get_resource(const_file_path)
-        sae_file = get_resource(sae_file_path)
-        ensemble_prefix = get_resource(ensemble_prefix_path)
-        ensemble_size = int(ensemble_size)
-        consts = neurochem.Constants(const_file)
+        # this is used to load only 1 model (by default model 0)
+        consts, sae_file, ensemble_prefix, ensemble_size = cls._parse_neurochem_resources(info_file_path)
 
         species_converter = SpeciesConverter(consts.species)
         aev_computer = AEVComputer(**consts)
-        neural_networks = neurochem.load_model_ensemble(consts.species,
-                                                        ensemble_prefix, ensemble_size)
         energy_shifter, sae_dict = neurochem.load_sae(sae_file, return_dict=True)
         species_to_tensor = consts.species_to_tensor
+        neural_networks = neurochem.load_model_ensemble(consts.species,
+                                                        ensemble_prefix, ensemble_size)
 
         return cls(species_converter, aev_computer, neural_networks,
                    energy_shifter, species_to_tensor, consts, sae_dict, periodic_table_index)
@@ -220,7 +243,7 @@ def __len__(self):
         return len(self.neural_networks)
 
 
-def ANI1x(periodic_table_index=False):
+def ANI1x(periodic_table_index=False, model_index=None):
     """The ANI-1x model as in `ani-1x_8x on GitHub`_ and `Active Learning Paper`_.
 
     The ANI-1x model is an ensemble of 8 networks that was trained using
@@ -234,10 +257,13 @@ def ANI1x(periodic_table_index=False):
     .. _Active Learning Paper:
         https://aip.scitation.org/doi/abs/10.1063/1.5023802
     """
-    return BuiltinEnsemble._from_neurochem_resources('ani-1x_8x.info', periodic_table_index)
+    info_file = 'ani-1x_8x.info'
+    if model_index is None:
+        return BuiltinEnsemble._from_neurochem_resources(info_file, periodic_table_index)
+    return BuiltinModel._from_neurochem_resources(info_file, periodic_table_index, model_index)
 
 
-def ANI1ccx(periodic_table_index=False):
+def ANI1ccx(periodic_table_index=False, model_index=None):
     """The ANI-1ccx model as in `ani-1ccx_8x on GitHub`_ and `Transfer Learning Paper`_.
 
     The ANI-1ccx model is an ensemble of 8 networks that was trained
@@ -252,4 +278,7 @@ def ANI1ccx(periodic_table_index=False):
     .. _Transfer Learning Paper:
         https://doi.org/10.26434/chemrxiv.6744440.v1
     """
-    return BuiltinEnsemble._from_neurochem_resources('ani-1ccx_8x.info', periodic_table_index)
+    info_file = 'ani-1ccx_8x.info'
+    if model_index is None:
+        return BuiltinEnsemble._from_neurochem_resources(info_file, periodic_table_index)
+    return BuiltinModel._from_neurochem_resources(info_file, periodic_table_index, model_index)
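With these changes, ANI1x() and ANI1ccx() are unchanged when model_index is omitted, and the new bounds check in BuiltinModel._from_neurochem_resources rejects indices outside the 8-member ensemble. A short sketch of the expected behaviour (the try/except wrapper is illustrative):

    import torchani

    ensemble = torchani.models.ANI1ccx()              # full ensemble of 8 networks, as before
    member = torchani.models.ANI1ccx(model_index=3)   # a single ensemble member

    try:
        torchani.models.ANI1ccx(model_index=8)        # only indices 0-7 exist
    except ValueError as err:
        print(err)  # "The ensemble size is only 8, model 8 can't be loaded"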
