Skip to content

Commit

Permalink
solve conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
Chengqian-Zhang committed Sep 11, 2024
2 parents ba348fd + 4136afa commit c2b8ae1
Showing 1 changed file with 48 additions and 14 deletions.
62 changes: 48 additions & 14 deletions source/tests/pt/model/test_property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,30 @@ def setUp(self) -> None:
self.nf = 1
self.nt = 3
self.rng = np.random.default_rng()
self.coord = torch.tensor([[1.1042, 0.6852, 1.3582],
[1.8812, 1.6277, 0.3153],
[1.5655, 1.0383, 0.4152],
[0.9594, 1.2298, 0.8124],
[0.7905, 0.5014, 0.6654]], dtype=dtype, device=env.DEVICE)
self.coord = torch.tensor(
[
[1.1042, 0.6852, 1.3582],
[1.8812, 1.6277, 0.3153],
[1.5655, 1.0383, 0.4152],
[0.9594, 1.2298, 0.8124],
[0.7905, 0.5014, 0.6654],
],
dtype=dtype,
device=env.DEVICE,
)
self.shift = torch.tensor([1000, 1000, 1000], dtype=dtype, device=env.DEVICE)
self.atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=env.DEVICE)
self.dd0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel).to(env.DEVICE)
self.cell = torch.tensor([[0.7333, 0.9166, 0.6533],
[0.1151, 0.9078, 0.2058],
[0.6907, 0.0370, 0.4863]], dtype=dtype, device=env.DEVICE)
self.cell = (self.cell + self.cell.T) + 5.0 * torch.eye(3, device=env.DEVICE)
self.cell = torch.tensor(
[
[0.7333, 0.9166, 0.6533],
[0.1151, 0.9078, 0.2058],
[0.6907, 0.0370, 0.4863],
],
dtype=dtype,
device=env.DEVICE,
)
self.cell = (self.cell + self.cell.T) + 5.0 * torch.eye(3, device=env.DEVICE)

def test_trans(self):
atype = self.atype.reshape(1, 5)
Expand All @@ -196,7 +208,12 @@ def test_trans(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz, atype, self.rcut, self.sel, self.dd0.mixed_types(), box=self.cell,
xyz,
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=self.cell,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -210,6 +227,7 @@ def test_trans(self):

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))


class TestInvarianceRandomShift(unittest.TestCase):
def setUp(self) -> None:
self.natoms = 5
Expand Down Expand Up @@ -272,7 +290,12 @@ def test_rot(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz + self.shift, atype, self.rcut, self.sel, self.dd0.mixed_types(), box=cell_rot,
xyz + self.shift,
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=cell_rot,
)

rd0, gr0, _, _, _ = self.dd0(
Expand Down Expand Up @@ -307,7 +330,12 @@ def test_permu(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
coord[idx_perm], atype, self.rcut, self.sel, self.dd0.mixed_types(), box=self.cell,
coord[idx_perm],
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=self.cell,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -323,7 +351,7 @@ def test_permu(self):
to_numpy_array(res[0][:, idx_perm]),
to_numpy_array(res[1]),
)

def test_trans(self):
atype = self.atype.reshape(1, 5)
coord_s = torch.matmul(
Expand All @@ -348,7 +376,12 @@ def test_trans(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz, atype, self.rcut, self.sel, self.dd0.mixed_types(), box=self.cell,
xyz,
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=self.cell,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -362,6 +395,7 @@ def test_trans(self):

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))


class TestPropertyModel(unittest.TestCase):
def setUp(self):
self.natoms = 5
Expand Down

0 comments on commit c2b8ae1

Please sign in to comment.