diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py index 597b9b82..67a846c1 100644 --- a/mace/calculators/lammps_mace.py +++ b/mace/calculators/lammps_mace.py @@ -44,7 +44,7 @@ def forward( virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"]) stress: Optional[torch.Tensor] = torch.zeros_like(data["cell"]) if mask_ghost is not None and displacement is not None: - displacement.requires_grad_(True) # For some reason torchscript needs that. + #displacement.requires_grad_(True) # For some reason torchscript needs that. node_energy_ghost = node_energy * mask_ghost total_energy_ghost = scatter_sum( src=node_energy_ghost, index=data["batch"], dim=-1, dim_size=num_graphs