Merge pull request #29 from wcwitt/torchscript_develop
Changes required for LAMMPS integration
ilyes319 authored Oct 11, 2022
2 parents 0caeced + 72b815a commit 7db50e5
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mace/calculators/mace.py

@@ -30,7 +30,7 @@ def __init__(
         self.results = {}

         self.model = torch.load(f=model_path, map_location=device)
-        self.r_max = self.model.r_max
+        self.r_max = self.model.r_max.item()
         self.device = torch_tools.init_device(device)
         self.energy_units_to_eV = energy_units_to_eV
         self.length_units_to_A = length_units_to_A
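
Why the .item() call: after the companion change in mace/modules/models.py below,
r_max is stored as a zero-dimensional tensor buffer, so self.model.r_max is a
torch.Tensor rather than a Python float, and the calculator needs a plain scalar.
A minimal sketch of the behavior, using a hypothetical Model class (not MACE's
actual code):

    import torch

    class Model(torch.nn.Module):
        def __init__(self, r_max: float):
            super().__init__()
            # Stored as a 0-dim float64 tensor, mirroring the change below
            self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64))

    model = Model(5.0)
    print(type(model.r_max))         # <class 'torch.Tensor'>
    print(type(model.r_max.item()))  # <class 'float'> -- what the calculator needs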
4 changes: 2 additions & 2 deletions mace/modules/models.py

@@ -53,8 +53,8 @@ def __init__(
         gate: Optional[Callable],
     ):
         super().__init__()
-        self.r_max = r_max
-        self.atomic_numbers = atomic_numbers
+        self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64))
+        self.register_buffer("atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64))
         # Embedding
         node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
         node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
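
Registering r_max and atomic_numbers as buffers stores them as tensors in the
module's state_dict, so they travel with a TorchScript-compiled model and can be
read back as ordinary tensors (e.g. from LAMMPS's C++ side). A minimal sketch of
the effect, using a hypothetical Buffered module (not the MACE model itself):

    import torch

    class Buffered(torch.nn.Module):
        def __init__(self, r_max: float):
            super().__init__()
            # Buffers live in state_dict and travel with the saved module
            self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64))

        def forward(self) -> torch.Tensor:
            return self.r_max

    scripted = torch.jit.script(Buffered(5.0))
    torch.jit.save(scripted, "model.pt")  # r_max is stored inside the file
    loaded = torch.jit.load("model.pt")
    print(loaded.r_max)                   # tensor(5., dtype=torch.float64)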
