# Patch summary (commentary only; text before the "diff --git" header is not part of the patch).
#
# Target file: tools/training-benchmark.py
#
# Hunk 1 (@@ -8,18 +8,45 @@):
#   Removes the atomic() factory, which returned a torch.nn.Sequential of
#   Linear layers 384 -> 128 -> 128 -> 64 -> 1 with CELU(0.1) between them.
#   Adds four module-level torch.nn.Sequential networks that keep the same
#   Linear/CELU(0.1) layout but use different layer widths:
#     H_network: 384 -> 160 -> 128 -> 96 -> 1
#     C_network: 384 -> 144 -> 112 -> 96 -> 1
#     N_network: 384 -> 128 -> 112 -> 96 -> 1
#     O_network: 384 -> 128 -> 112 -> 96 -> 1   (same widths as N_network)
#   NOTE(review): the H/C/N/O prefixes presumably stand for element symbols
#   (num_species = 4 appears in the second hunk) - confirm against torchani usage.
#
# Hunk 2 (@@ -71,7 +98,7 @@):
#   torchani.ANIModel is now built from
#   [H_network, C_network, N_network, O_network] instead of
#   [atomic() for _ in range(4)]. The networks are therefore created once at
#   module level rather than by four factory calls at this point.
#
# NOTE(review): in this copy the patch's line breaks and leading indentation
# appear to have been collapsed into a single line, so the hunk line counts
# no longer match the physical layout. Presumably it will not apply as-is;
# restore the original patch from version control before applying - TODO confirm.
diff --git a/tools/training-benchmark.py b/tools/training-benchmark.py index 7e3f697b0..3b1ceb902 100644 --- a/tools/training-benchmark.py +++ b/tools/training-benchmark.py @@ -8,18 +8,45 @@ synchronize = False - -def atomic(): - model = torch.nn.Sequential( - torch.nn.Linear(384, 128), - torch.nn.CELU(0.1), - torch.nn.Linear(128, 128), - torch.nn.CELU(0.1), - torch.nn.Linear(128, 64), - torch.nn.CELU(0.1), - torch.nn.Linear(64, 1) - ) - return model +H_network = torch.nn.Sequential( + torch.nn.Linear(384, 160), + torch.nn.CELU(0.1), + torch.nn.Linear(160, 128), + torch.nn.CELU(0.1), + torch.nn.Linear(128, 96), + torch.nn.CELU(0.1), + torch.nn.Linear(96, 1) +) + +C_network = torch.nn.Sequential( + torch.nn.Linear(384, 144), + torch.nn.CELU(0.1), + torch.nn.Linear(144, 112), + torch.nn.CELU(0.1), + torch.nn.Linear(112, 96), + torch.nn.CELU(0.1), + torch.nn.Linear(96, 1) +) + +N_network = torch.nn.Sequential( + torch.nn.Linear(384, 128), + torch.nn.CELU(0.1), + torch.nn.Linear(128, 112), + torch.nn.CELU(0.1), + torch.nn.Linear(112, 96), + torch.nn.CELU(0.1), + torch.nn.Linear(96, 1) +) + +O_network = torch.nn.Sequential( + torch.nn.Linear(384, 128), + torch.nn.CELU(0.1), + torch.nn.Linear(128, 112), + torch.nn.CELU(0.1), + torch.nn.Linear(112, 96), + torch.nn.CELU(0.1), + torch.nn.Linear(96, 1) +) def time_func(key, func): @@ -71,7 +98,7 @@ def wrapper(*args, **kwargs): num_species = 4 aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species) - nn = torchani.ANIModel([atomic() for _ in range(4)]) + nn = torchani.ANIModel([H_network, C_network, N_network, O_network]) model = torch.nn.Sequential(aev_computer, nn).to(parser.device) optimizer = torch.optim.Adam(model.parameters(), lr=0.000001) mse = torch.nn.MSELoss(reduction='none')