diff --git a/deepmd/dpmodel/model/pair_tab_model.py b/deepmd/dpmodel/model/pair_tab_model.py new file mode 100644 index 0000000000..d62ac5c859 --- /dev/null +++ b/deepmd/dpmodel/model/pair_tab_model.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Union, +) + +import numpy as np + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.utils.pair_tab import ( + PairTab, +) + +from .base_atomic_model import ( + BaseAtomicModel, +) + + +class PairTabModel(BaseAtomicModel): + """Pairwise tabulation energy model. + + This model can be used to tabulate the pairwise energy between atoms for either + short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not + be used alone, but rather as one submodel of a linear (sum) model, such as + DP+D3. + + Do not put the model on the first model of a linear model, since the linear + model fetches the type map from the first model. + + At this moment, the model does not smooth the energy at the cutoff radius, so + one needs to make sure the energy has been smoothed to zero. + + Parameters + ---------- + tab_file : str + The path to the tabulation file. + rcut : float + The cutoff radius. + sel : int or list[int] + The maxmum number of atoms in the cut-off radius. + """ + + def __init__( + self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs + ): + super().__init__() + self.tab_file = tab_file + self.rcut = rcut + + self.tab = PairTab(self.tab_file, rcut=rcut) + + if self.tab_file is not None: + self.tab_info, self.tab_data = self.tab.get() + else: + self.tab_info, self.tab_data = None, None + + if isinstance(sel, int): + self.sel = sel + elif isinstance(sel, list): + self.sel = sum(sel) + else: + raise TypeError("sel must be int or list[int]") + + def fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", shape=[1], reduciable=True, differentiable=True + ) + ] + ) + + def get_rcut(self) -> float: + return self.rcut + + def get_sel(self) -> int: + return self.sel + + def distinguish_types(self) -> bool: + # to match DPA1 and DPA2. + return False + + def serialize(self) -> dict: + return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel} + + @classmethod + def deserialize(cls, data) -> "PairTabModel": + rcut = data["rcut"] + sel = data["sel"] + tab = PairTab.deserialize(data["tab"]) + tab_model = cls(None, rcut, sel) + tab_model.tab = tab + tab_model.tab_info = tab_model.tab.tab_info + tab_model.tab_data = tab_model.tab.tab_data + return tab_model + + def forward_atomic( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[np.ndarray] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, np.ndarray]: + self.nframes, self.nloc, self.nnei = nlist.shape + extended_coord = extended_coord.reshape(self.nframes, -1, 3) + + # this will mask all -1 in the nlist + masked_nlist = np.clip(nlist, 0, None) + + atype = extended_atype[:, : self.nloc] # (nframes, nloc) + pairwise_dr = self._get_pairwise_dist( + extended_coord + ) # (nframes, nall, nall, 3) + pairwise_rr = np.sqrt( + np.sum(np.power(pairwise_dr, 2), axis=-1) + ) # (nframes, nall, nall) + self.tab_data = self.tab_data.reshape( + self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 + ) + + # (nframes, nloc, nnei) + j_type = extended_atype[ + np.arange(extended_atype.shape[0])[:, None, None], masked_nlist + ] + + # slice rr to get (nframes, nloc, nnei) + rr = np.take_along_axis(pairwise_rr[:, : self.nloc, :], masked_nlist, 2) + raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr) + atomic_energy = 0.5 * np.sum( + np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)), + axis=-1, + ).reshape(self.nframes, self.nloc, 1) + + return {"energy": atomic_energy} + + def _pair_tabulated_inter( + self, + nlist: np.ndarray, + i_type: np.ndarray, + j_type: np.ndarray, + rr: np.ndarray, + ) -> np.ndarray: + """Pairwise tabulated energy. + + Parameters + ---------- + nlist : np.ndarray + The unmasked neighbour list. (nframes, nloc) + i_type : np.ndarray + The integer representation of atom type for all local atoms for all frames. (nframes, nloc) + j_type : np.ndarray + The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei) + rr : np.ndarray + The salar distance vector between two atoms. (nframes, nloc, nnei) + + Returns + ------- + np.ndarray + The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei) + + Raises + ------ + Exception + If the distance is beyond the table. + + Notes + ----- + This function is used to calculate the pairwise energy between two atoms. + It uses a table containing cubic spline coefficients calculated in PairTab. + """ + rmin = self.tab_info[0] + hh = self.tab_info[1] + hi = 1.0 / hh + + self.nspline = int(self.tab_info[2] + 0.1) + + uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei) + + # if nnei of atom 0 has -1 in the nlist, uu would be 0. + # this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms. + uu = np.where(nlist != -1, uu, self.nspline + 1) + + if np.any(uu < 0): + raise Exception("coord go beyond table lower boundary") + + idx = uu.astype(int) + + uu -= idx + table_coef = self._extract_spline_coefficient( + i_type, j_type, idx, self.tab_data, self.nspline + ) + table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4) + ener = self._calcualte_ener(table_coef, uu) + # here we need to overwrite energy to zero at rcut and beyond. + mask_beyond_rcut = rr >= self.rcut + # also overwrite values beyond extrapolation to zero + extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh + ener[mask_beyond_rcut] = 0 + ener[extrapolation_mask] = 0 + + return ener + + @staticmethod + def _get_pairwise_dist(coords: np.ndarray) -> np.ndarray: + """Get pairwise distance `dr`. + + Parameters + ---------- + coords : np.ndarray + The coordinate of the atoms shape of (nframes, nall, 3). + + Returns + ------- + np.ndarray + The pairwise distance between the atoms (nframes, nall, nall, 3). + """ + return np.expand_dims(coords, 2) - np.expand_dims(coords, 1) + + @staticmethod + def _extract_spline_coefficient( + i_type: np.ndarray, + j_type: np.ndarray, + idx: np.ndarray, + tab_data: np.ndarray, + nspline: int, + ) -> np.ndarray: + """Extract the spline coefficient from the table. + + Parameters + ---------- + i_type : np.ndarray + The integer representation of atom type for all local atoms for all frames. (nframes, nloc) + j_type : np.ndarray + The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei) + idx : np.ndarray + The index of the spline coefficient. (nframes, nloc, nnei) + tab_data : np.ndarray + The table storing all the spline coefficient. (ntype, ntype, nspline, 4) + nspline : int + The number of splines in the table. + + Returns + ------- + np.ndarray + The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed. + """ + # (nframes, nloc, nnei) + expanded_i_type = np.broadcast_to( + i_type[:, :, np.newaxis], + (i_type.shape[0], i_type.shape[1], j_type.shape[-1]), + ) + + # (nframes, nloc, nnei, nspline, 4) + expanded_tab_data = tab_data[expanded_i_type, j_type] + + # (nframes, nloc, nnei, 1, 4) + expanded_idx = np.broadcast_to( + idx[..., np.newaxis, np.newaxis], (*idx.shape, 1, 4) + ) + clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int) + + # (nframes, nloc, nnei, 4) + final_coef = np.squeeze( + np.take_along_axis(expanded_tab_data, clipped_indices, 3) + ) + + # when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`. + final_coef[expanded_idx.squeeze() > nspline] = 0 + return final_coef + + @staticmethod + def _calcualte_ener(coef: np.ndarray, uu: np.ndarray) -> np.ndarray: + """Calculate energy using spline coeeficients. + + Parameters + ---------- + coef : np.ndarray + The spline coefficients. (nframes, nloc, nnei, 4) + uu : np.ndarray + The atom displancemnt used in interpolation and extrapolation (nframes, nloc, nnei) + + Returns + ------- + np.ndarray + The atomic energy for all local atoms for all frames. (nframes, nloc, nnei) + """ + a3, a2, a1, a0 = coef[..., 0], coef[..., 1], coef[..., 2], coef[..., 3] + etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations. + ener = etmp * uu + a0 # this energy has the extrapolated value when rcut > rmax + return ener diff --git a/deepmd/pt/model/model/pair_tab.py b/deepmd/pt/model/model/pair_tab_model.py similarity index 88% rename from deepmd/pt/model/model/pair_tab.py rename to deepmd/pt/model/model/pair_tab_model.py index 430d090eb0..1a415d633d 100644 --- a/deepmd/pt/model/model/pair_tab.py +++ b/deepmd/pt/model/model/pair_tab_model.py @@ -54,13 +54,19 @@ def __init__( super().__init__() self.tab_file = tab_file self.rcut = rcut - self.tab = PairTab(self.tab_file, rcut=rcut) - self.ntypes = self.tab.ntypes - tab_info, tab_data = self.tab.get() # this returns -> Tuple[np.array, np.array] - self.tab_info = torch.from_numpy(tab_info) - self.tab_data = torch.from_numpy(tab_data) + # handle deserialization with no input file + if self.tab_file is not None: + ( + tab_info, + tab_data, + ) = self.tab.get() # this returns -> Tuple[np.array, np.array] + self.tab_info = torch.from_numpy(tab_info) + self.tab_data = torch.from_numpy(tab_data) + else: + self.tab_info = None + self.tab_data = None # self.model_type = "ener" # self.model_version = MODEL_VERSION ## this shoud be in the parent class @@ -92,12 +98,18 @@ def distinguish_types(self) -> bool: return False def serialize(self) -> dict: - # place holder, implemantated in future PR - raise NotImplementedError - - def deserialize(cls): - # place holder, implemantated in future PR - raise NotImplementedError + return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel} + + @classmethod + def deserialize(cls, data) -> "PairTabModel": + rcut = data["rcut"] + sel = data["sel"] + tab = PairTab.deserialize(data["tab"]) + tab_model = cls(None, rcut, sel) + tab_model.tab = tab + tab_model.tab_info = torch.from_numpy(tab_model.tab.tab_info) + tab_model.tab_data = torch.from_numpy(tab_model.tab.tab_data) + return tab_model def forward_atomic( self, @@ -108,6 +120,7 @@ def forward_atomic( do_atomic_virial: bool = False, ) -> Dict[str, torch.Tensor]: self.nframes, self.nloc, self.nnei = nlist.shape + extended_coord = extended_coord.view(self.nframes, -1, 3) # this will mask all -1 in the nlist masked_nlist = torch.clamp(nlist, 0) @@ -118,7 +131,7 @@ def forward_atomic( ) # (nframes, nall, nall, 3) pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall) - self.tab_data = self.tab_data.reshape( + self.tab_data = self.tab_data.view( self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 ) @@ -139,7 +152,7 @@ def forward_atomic( nlist != -1, raw_atomic_energy, torch.zeros_like(raw_atomic_energy) ), dim=-1, - ) + ).unsqueeze(-1) return {"energy": atomic_energy} @@ -200,7 +213,7 @@ def _pair_tabulated_inter( table_coef = self._extract_spline_coefficient( i_type, j_type, idx, self.tab_data, self.nspline ) - table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4) + table_coef = table_coef.view(self.nframes, self.nloc, self.nnei, 4) ener = self._calcualte_ener(table_coef, uu) # here we need to overwrite energy to zero at rcut and beyond. @@ -219,12 +232,12 @@ def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor: Parameters ---------- coords : torch.Tensor - The coordinate of the atoms shape of (nframes * nall * 3). + The coordinate of the atoms shape of (nframes, nall, 3). Returns ------- torch.Tensor - The pairwise distance between the atoms (nframes * nall * nall * 3). + The pairwise distance between the atoms (nframes, nall, nall, 3). Examples -------- diff --git a/deepmd/utils/pair_tab.py b/deepmd/utils/pair_tab.py index 56f8e618df..c97aefc108 100644 --- a/deepmd/utils/pair_tab.py +++ b/deepmd/utils/pair_tab.py @@ -44,6 +44,9 @@ def reinit(self, filename: str, rcut: Optional[float] = None) -> None: For example we have two atom types, 0 and 1. The columes from 2nd to 4th are for 0-0, 0-1 and 1-1 correspondingly. """ + if filename is None: + self.tab_info, self.tab_data = None, None + return self.vdata = np.loadtxt(filename) self.rmin = self.vdata[0][0] self.rmax = self.vdata[-1][0] @@ -65,6 +68,36 @@ def reinit(self, filename: str, rcut: Optional[float] = None) -> None: self.tab_info = np.array([self.rmin, self.hh, self.nspline, self.ntypes]) self.tab_data = self._make_data() + def serialize(self) -> dict: + return { + "rmin": self.rmin, + "rmax": self.rmax, + "hh": self.hh, + "ntypes": self.ntypes, + "rcut": self.rcut, + "nspline": self.nspline, + "@variables": { + "vdata": self.vdata, + "tab_info": self.tab_info, + "tab_data": self.tab_data, + }, + } + + @classmethod + def deserialize(cls, data) -> "PairTab": + variables = data.pop("@variables") + tab = PairTab(None, None) + tab.vdata = variables["vdata"] + tab.rmin = data["rmin"] + tab.rmax = data["rmax"] + tab.hh = data["hh"] + tab.ntypes = data["ntypes"] + tab.rcut = data["rcut"] + tab.nspline = data["nspline"] + tab.tab_info = variables["tab_info"] + tab.tab_data = variables["tab_data"] + return tab + def _check_table_upper_boundary(self) -> None: """Update User Provided Table Based on `rcut`. diff --git a/source/tests/common/test_pairtab_preprocess.py b/source/tests/common/test_pairtab_preprocess.py index a866c42236..26f96a3ca4 100644 --- a/source/tests/common/test_pairtab_preprocess.py +++ b/source/tests/common/test_pairtab_preprocess.py @@ -30,6 +30,18 @@ def setUp(self, mock_loadtxt) -> None: self.tab4 = PairTab(filename=file_path, rcut=0.03) self.tab5 = PairTab(filename=file_path, rcut=0.032) + def test_deserialize(self): + deserialized_tab = PairTab.deserialize(self.tab1.serialize()) + np.testing.assert_allclose(self.tab1.vdata, deserialized_tab.vdata) + np.testing.assert_allclose(self.tab1.rmin, deserialized_tab.rmin) + np.testing.assert_allclose(self.tab1.rmax, deserialized_tab.rmax) + np.testing.assert_allclose(self.tab1.hh, deserialized_tab.hh) + np.testing.assert_allclose(self.tab1.ntypes, deserialized_tab.ntypes) + np.testing.assert_allclose(self.tab1.rcut, deserialized_tab.rcut) + np.testing.assert_allclose(self.tab1.nspline, deserialized_tab.nspline) + np.testing.assert_allclose(self.tab1.tab_info, deserialized_tab.tab_info) + np.testing.assert_allclose(self.tab1.tab_data, deserialized_tab.tab_data) + def test_preprocess(self): np.testing.assert_allclose( self.tab1.vdata, diff --git a/source/tests/dpmodel/__init__.py b/source/tests/dpmodel/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/dpmodel/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/dpmodel/test_pairtab.py b/source/tests/dpmodel/test_pairtab.py new file mode 100644 index 0000000000..3713d33510 --- /dev/null +++ b/source/tests/dpmodel/test_pairtab.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from unittest.mock import ( + patch, +) + +import numpy as np + +from deepmd.dpmodel.model.pair_tab_model import ( + PairTabModel, +) + + +class TestPairTab(unittest.TestCase): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + ] + ) + + self.model = PairTabModel(tab_file=file_path, rcut=0.02, sel=2) + + self.extended_coord = np.array( + [ + [ + [0.01, 0.01, 0.01], + [0.01, 0.02, 0.01], + [0.01, 0.01, 0.02], + [0.02, 0.01, 0.01], + ], + [ + [0.01, 0.01, 0.01], + [0.01, 0.02, 0.01], + [0.01, 0.01, 0.02], + [0.05, 0.01, 0.01], + ], + ] + ) + + # nframes=2, nall=4 + self.extended_atype = np.array([[0, 1, 0, 1], [0, 0, 1, 1]]) + + # nframes=2, nloc=2, nnei=2 + self.nlist = np.array([[[1, 2], [0, 2]], [[1, 2], [0, 3]]]) + + def test_without_mask(self): + result = self.model.forward_atomic( + self.extended_coord, self.extended_atype, self.nlist + ) + expected_result = np.array([[[1.2000], [1.3614]], [[1.2000], [0.4000]]]) + + np.testing.assert_allclose(result["energy"], expected_result, 0.0001, 0.0001) + + def test_with_mask(self): + self.nlist = np.array([[[1, -1], [0, 2]], [[1, 2], [0, 3]]]) + + result = self.model.forward_atomic( + self.extended_coord, self.extended_atype, self.nlist + ) + expected_result = np.array([[[0.8000], [1.3614]], [[1.2000], [0.4000]]]) + + np.testing.assert_allclose(result["energy"], expected_result, 0.0001, 0.0001) + + def test_deserialize(self): + model1 = PairTabModel.deserialize(self.model.serialize()) + np.testing.assert_allclose(self.model.tab_data, model1.tab_data) + np.testing.assert_allclose(self.model.tab_info, model1.tab_info) + + self.nlist = np.array([[[1, -1], [0, 2]], [[1, 2], [0, 3]]]) + result = model1.forward_atomic( + self.extended_coord, self.extended_atype, self.nlist + ) + expected_result = self.model.forward_atomic( + self.extended_coord, self.extended_atype, self.nlist + ) + + np.testing.assert_allclose( + result["energy"], expected_result["energy"], 0.0001, 0.0001 + ) + + +class TestPairTabTwoAtoms(unittest.TestCase): + @patch("numpy.loadtxt") + def test_extrapolation_nonzero_rmax(self, mock_loadtxt) -> None: + """Scenarios to test. + + rcut < rmax: + rr < rcut: use table values, or interpolate. + rr == rcut: use table values, or interpolate. + rr > rcut: should be 0 + rcut == rmax: + rr < rcut: use table values, or interpolate. + rr == rcut: use table values, or interpolate. + rr > rcut: should be 0 + rcut > rmax: + rr < rmax: use table values, or interpolate. + rr == rmax: use table values, or interpolate. + rmax < rr < rcut: extrapolate + rr >= rcut: should be 0 + + """ + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0], + [0.01, 0.8], + [0.015, 0.5], + [0.02, 0.25], + ] + ) + + # nframes=1, nall=2 + extended_atype = np.array([[0, 0]]) + + # nframes=1, nloc=2, nnei=1 + nlist = np.array([[[1], [-1]]]) + + results = [] + + for dist, rcut in zip( + [ + 0.01, + 0.015, + 0.020, + 0.015, + 0.02, + 0.021, + 0.015, + 0.02, + 0.021, + 0.025, + 0.026, + 0.025, + 0.025, + 0.0216161, + ], + [ + 0.015, + 0.015, + 0.015, + 0.02, + 0.02, + 0.02, + 0.022, + 0.022, + 0.022, + 0.025, + 0.025, + 0.03, + 0.035, + 0.025, + ], + ): + extended_coord = np.array( + [ + [ + [0.0, 0.0, 0.0], + [0.0, dist, 0.0], + ], + ] + ) + + model = PairTabModel(tab_file=file_path, rcut=rcut, sel=2) + results.append( + model.forward_atomic(extended_coord, extended_atype, nlist)["energy"] + ) + + expected_result = np.stack( + [ + np.array( + [ + [ + [0.4, 0], + [0.0, 0], + [0.0, 0], + [0.25, 0], + [0, 0], + [0, 0], + [0.25, 0], + [0.125, 0], + [0.0922, 0], + [0, 0], + [0, 0], + [0, 0], + [0.0923, 0], + [0.0713, 0], + ] + ] + ) + ] + ).reshape(14, 2) + results = np.stack(results).reshape(14, 2) + + np.testing.assert_allclose(results, expected_result, 0.0001, 0.0001) + + if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_pairtab.py b/source/tests/pt/test_pairtab.py index b4dbda6702..e27e2cf2a1 100644 --- a/source/tests/pt/test_pairtab.py +++ b/source/tests/pt/test_pairtab.py @@ -7,7 +7,8 @@ import numpy as np import torch -from deepmd.pt.model.model.pair_tab import ( +from deepmd.dpmodel.model.pair_tab_model import PairTabModel as DPPairTabModel +from deepmd.pt.model.model.pair_tab_model import ( PairTabModel, ) @@ -54,9 +55,13 @@ def test_without_mask(self): result = self.model.forward_atomic( self.extended_coord, self.extended_atype, self.nlist ) - expected_result = torch.tensor([[1.2000, 1.3614], [1.2000, 0.4000]]) + expected_result = torch.tensor( + [[[1.2000], [1.3614]], [[1.2000], [0.4000]]], dtype=torch.float64 + ) - torch.testing.assert_allclose(result["energy"], expected_result, 0.0001, 0.0001) + torch.testing.assert_close( + result["energy"], expected_result, rtol=0.0001, atol=0.0001 + ) def test_with_mask(self): self.nlist = torch.tensor([[[1, -1], [0, 2]], [[1, 2], [0, 3]]]) @@ -64,13 +69,56 @@ def test_with_mask(self): result = self.model.forward_atomic( self.extended_coord, self.extended_atype, self.nlist ) - expected_result = torch.tensor([[0.8000, 1.3614], [1.2000, 0.4000]]) + expected_result = torch.tensor( + [[[0.8000], [1.3614]], [[1.2000], [0.4000]]], dtype=torch.float64 + ) - torch.testing.assert_allclose(result["energy"], expected_result, 0.0001, 0.0001) + torch.testing.assert_close( + result["energy"], expected_result, rtol=0.0001, atol=0.0001 + ) def test_jit(self): model = torch.jit.script(self.model) + def test_deserialize(self): + model1 = PairTabModel.deserialize(self.model.serialize()) + torch.testing.assert_close(self.model.tab_data, model1.tab_data) + torch.testing.assert_close(self.model.tab_info, model1.tab_info) + + self.nlist = torch.tensor([[[1, -1], [0, 2]], [[1, 2], [0, 3]]]) + result = model1.forward_atomic( + self.extended_coord, self.extended_atype, self.nlist + ) + expected_result = self.model.forward_atomic( + self.extended_coord, self.extended_atype, self.nlist + ) + + torch.testing.assert_close( + result["energy"], expected_result["energy"], rtol=0.0001, atol=0.0001 + ) + + model1 = torch.jit.script(model1) + + def test_cross_deserialize(self): + model_dict = self.model.serialize() # pytorch model to dict + model1 = DPPairTabModel.deserialize(model_dict) # dict to numpy model + np.testing.assert_allclose(self.model.tab_data, model1.tab_data) + np.testing.assert_allclose(self.model.tab_info, model1.tab_info) + + self.nlist = np.array([[[1, -1], [0, 2]], [[1, 2], [0, 3]]]) + result = model1.forward_atomic( + self.extended_coord.numpy(), + self.extended_atype.numpy(), + self.nlist, + ) + expected_result = self.model.forward_atomic( + self.extended_coord, self.extended_atype, torch.from_numpy(self.nlist) + ) + + np.testing.assert_allclose( + result["energy"], expected_result["energy"], 0.0001, 0.0001 + ) + class TestPairTabTwoAtoms(unittest.TestCase): @patch("numpy.loadtxt") @@ -178,13 +226,14 @@ def test_extrapolation_nonzero_rmax(self, mock_loadtxt) -> None: [0.0923, 0], [0.0713, 0], ] - ] + ], + dtype=torch.float64, ) ] ).reshape(14, 2) results = torch.stack(results).reshape(14, 2) - torch.testing.assert_allclose(results, expected_result, 0.0001, 0.0001) + torch.testing.assert_close(results, expected_result, rtol=0.0001, atol=0.0001) if __name__ == "__main__": unittest.main()