diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 54828888..b4751c3b 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -30,7 +30,7 @@ def __init__( self.results = {} self.model = torch.load(f=model_path, map_location=device) - self.r_max = self.model.r_max + self.r_max = self.model.r_max.item() self.device = torch_tools.init_device(device) self.energy_units_to_eV = energy_units_to_eV self.length_units_to_A = length_units_to_A diff --git a/mace/modules/models.py b/mace/modules/models.py index e16373d5..3b310af1 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -53,8 +53,8 @@ def __init__( gate: Optional[Callable], ): super().__init__() - self.r_max = r_max - self.atomic_numbers = atomic_numbers + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer("atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)) # Embedding node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])