Update NN in training benchmark (#453)
ANI-1x NNs are updated to be consistent with those used in the paper
farhadrgh committed Apr 13, 2020
1 parent fec59ac commit f967877
Showing 1 changed file with 40 additions and 13 deletions.
53 changes: 40 additions & 13 deletions tools/training-benchmark.py
@@ -8,18 +8,45 @@

synchronize = False


-def atomic():
-    model = torch.nn.Sequential(
-        torch.nn.Linear(384, 128),
-        torch.nn.CELU(0.1),
-        torch.nn.Linear(128, 128),
-        torch.nn.CELU(0.1),
-        torch.nn.Linear(128, 64),
-        torch.nn.CELU(0.1),
-        torch.nn.Linear(64, 1)
-    )
-    return model
+H_network = torch.nn.Sequential(
+    torch.nn.Linear(384, 160),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(160, 128),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(128, 96),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(96, 1)
+)
+
+C_network = torch.nn.Sequential(
+    torch.nn.Linear(384, 144),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(144, 112),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(112, 96),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(96, 1)
+)
+
+N_network = torch.nn.Sequential(
+    torch.nn.Linear(384, 128),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(128, 112),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(112, 96),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(96, 1)
+)
+
+O_network = torch.nn.Sequential(
+    torch.nn.Linear(384, 128),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(128, 112),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(112, 96),
+    torch.nn.CELU(0.1),
+    torch.nn.Linear(96, 1)
+)


def time_func(key, func):
@@ -71,7 +71,7 @@ def wrapper(*args, **kwargs):
num_species = 4
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)

-nn = torchani.ANIModel([atomic() for _ in range(4)])
+nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
model = torch.nn.Sequential(aev_computer, nn).to(parser.device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
mse = torch.nn.MSELoss(reduction='none')
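For context, a minimal sketch (not part of the commit) of how one training step is typically driven with the objects defined in the diff above (model, optimizer, mse). The batch tensors species, coordinates, and true_energies are hypothetical placeholders for illustration, not names from this file.

device = next(model.parameters()).device  # model was moved to parser.device above

# Hypothetical mini-batch: species holds element indices in [0, num_species),
# here all 0; coordinates is (molecules, atoms, 3) Cartesian positions.
species = torch.zeros((1, 8), dtype=torch.long, device=device)
coordinates = torch.randn((1, 8, 3), device=device)
true_energies = torch.zeros(1, device=device)

# The Sequential pipeline: AEVComputer maps (species, coordinates) to AEVs,
# then ANIModel runs each atom's AEV through its element's network and sums
# the atomic contributions into one energy per molecule.
_, predicted_energies = model((species, coordinates))

loss = mse(predicted_energies, true_energies).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()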
