diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py index d2ce9d00..597b9b82 100644 --- a/mace/calculators/lammps_mace.py +++ b/mace/calculators/lammps_mace.py @@ -70,6 +70,8 @@ def forward( torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), ).unsqueeze(-1) stress = virials / volume.view(-1, 1, 1) + else: + virials = torch.zeros_like(displacement) total_energy = scatter_sum( src=node_energy, index=data["batch"], dim=-1, dim_size=num_graphs