Skip to content

Commit

Permalink
Fix GPU UTs (#3203)
Browse files Browse the repository at this point in the history
This PR fixes the GPU unit tests (UTs).
It deletes PREPROCESS_DEVICE from the torch data preprocessing and uses the training
DEVICE instead, which will be removed after the dataset is reformatted.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
5 people authored Jan 31, 2024
1 parent b800043 commit 7f069cc
Show file tree
Hide file tree
Showing 15 changed files with 131 additions and 153 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:
DP_BUILD_TESTING: 1
DP_VARIANT: cuda
CUDA_PATH: /usr/local/cuda-12.2
NUM_WORKERS: 0
- run: dp --version
- run: python -m pytest -s --cov=deepmd source/tests --durations=0
- run: source/install/test_cc_local.sh
Expand Down
2 changes: 0 additions & 2 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,11 @@ def collate_batch(batch):
result[key] = torch.zeros(
(n_frames, natoms_extended, 3),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.PREPROCESS_DEVICE,
)
else:
result[key] = torch.zeros(
(n_frames, natoms_extended),
dtype=torch.long,
device=env.PREPROCESS_DEVICE,
)
for i in range(len(batch)):
natoms_tmp = list[i].shape[0]
Expand Down
57 changes: 14 additions & 43 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,21 +477,15 @@ def preprocess(self, batch):
if "find_" in kk:
pass
else:
batch[kk] = torch.tensor(
batch[kk],
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.PREPROCESS_DEVICE,
)
batch[kk] = torch.tensor(batch[kk], dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if self._data_dict[kk]["atomic"]:
batch[kk] = batch[kk].view(
n_frames, -1, self._data_dict[kk]["ndof"]
)

for kk in ["type", "real_natoms_vec"]:
if kk in batch.keys():
batch[kk] = torch.tensor(
batch[kk], dtype=torch.long, device=env.PREPROCESS_DEVICE
)
batch[kk] = torch.tensor(batch[kk], dtype=torch.long)
batch["atype"] = batch.pop("type")

keys = ["nlist", "nlist_loc", "nlist_type", "shift", "mapping"]
Expand Down Expand Up @@ -524,13 +518,9 @@ def preprocess(self, batch):
batch["nlist_type"] = nlist_type
natoms_extended = max([item.shape[0] for item in shift])
batch["shift"] = torch.zeros(
(n_frames, natoms_extended, 3),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.PREPROCESS_DEVICE,
)
batch["mapping"] = torch.zeros(
(n_frames, natoms_extended), dtype=torch.long, device=env.PREPROCESS_DEVICE
(n_frames, natoms_extended, 3), dtype=env.GLOBAL_PT_FLOAT_PRECISION
)
batch["mapping"] = torch.zeros((n_frames, natoms_extended), dtype=torch.long)
for i in range(len(shift)):
natoms_tmp = shift[i].shape[0]
batch["shift"][i, :natoms_tmp] = shift[i]
Expand Down Expand Up @@ -566,17 +556,13 @@ def single_preprocess(self, batch, sid):
pass
else:
batch[kk] = torch.tensor(
batch[kk][sid],
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.PREPROCESS_DEVICE,
batch[kk][sid], dtype=env.GLOBAL_PT_FLOAT_PRECISION
)
if self._data_dict[kk]["atomic"]:
batch[kk] = batch[kk].view(-1, self._data_dict[kk]["ndof"])
for kk in ["type", "real_natoms_vec"]:
if kk in batch.keys():
batch[kk] = torch.tensor(
batch[kk][sid], dtype=torch.long, device=env.PREPROCESS_DEVICE
)
batch[kk] = torch.tensor(batch[kk][sid], dtype=torch.long)
clean_coord = batch.pop("coord")
clean_type = batch.pop("type")
nloc = clean_type.shape[0]
Expand Down Expand Up @@ -670,30 +656,22 @@ def single_preprocess(self, batch, sid):
NotImplementedError(f"Unknown noise type {self.noise_type}!")
noised_coord = _clean_coord.clone().detach()
noised_coord[coord_mask] += noise_on_coord
batch["coord_mask"] = torch.tensor(
coord_mask, dtype=torch.bool, device=env.PREPROCESS_DEVICE
)
batch["coord_mask"] = torch.tensor(coord_mask, dtype=torch.bool)
else:
noised_coord = _clean_coord
batch["coord_mask"] = torch.tensor(
np.zeros_like(coord_mask, dtype=bool),
dtype=torch.bool,
device=env.PREPROCESS_DEVICE,
np.zeros_like(coord_mask, dtype=bool), dtype=torch.bool
)

# add mask for type
if self.mask_type:
masked_type = clean_type.clone().detach()
masked_type[type_mask] = self.mask_type_idx
batch["type_mask"] = torch.tensor(
type_mask, dtype=torch.bool, device=env.PREPROCESS_DEVICE
)
batch["type_mask"] = torch.tensor(type_mask, dtype=torch.bool)
else:
masked_type = clean_type
batch["type_mask"] = torch.tensor(
np.zeros_like(type_mask, dtype=bool),
dtype=torch.bool,
device=env.PREPROCESS_DEVICE,
np.zeros_like(type_mask, dtype=bool), dtype=torch.bool
)
if self.pbc:
_coord = normalize_coord(noised_coord, region, nloc)
Expand Down Expand Up @@ -803,7 +781,7 @@ def __len__(self):
def __getitem__(self, index):
"""Get a frame from the selected system."""
b_data = self._data_system._get_item(index)
b_data["natoms"] = torch.tensor(self._natoms_vec, device=env.PREPROCESS_DEVICE)
b_data["natoms"] = torch.tensor(self._natoms_vec)
return b_data


Expand Down Expand Up @@ -878,9 +856,7 @@ def __getitem__(self, index=None):
if index is None:
index = dp_random.choice(np.arange(self.nsystems), p=self.probs)
b_data = self._data_systems[index].get_batch(self._batch_size)
b_data["natoms"] = torch.tensor(
self._natoms_vec[index], device=env.PREPROCESS_DEVICE
)
b_data["natoms"] = torch.tensor(self._natoms_vec[index])
batch_size = b_data["coord"].shape[0]
b_data["natoms"] = b_data["natoms"].unsqueeze(0).expand(batch_size, -1)
return b_data
Expand All @@ -891,9 +867,7 @@ def get_training_batch(self, index=None):
if index is None:
index = dp_random.choice(np.arange(self.nsystems), p=self.probs)
b_data = self._data_systems[index].get_batch_for_train(self._batch_size)
b_data["natoms"] = torch.tensor(
self._natoms_vec[index], device=env.PREPROCESS_DEVICE
)
b_data["natoms"] = torch.tensor(self._natoms_vec[index])
batch_size = b_data["coord"].shape[0]
b_data["natoms"] = b_data["natoms"].unsqueeze(0).expand(batch_size, -1)
return b_data
Expand All @@ -902,10 +876,7 @@ def get_batch(self, sys_idx=None):
"""TF-compatible batch for testing."""
pt_batch = self[sys_idx]
np_batch = {}
for key in ["coord", "box", "force", "energy", "virial"]:
if key in pt_batch.keys():
np_batch[key] = pt_batch[key].cpu().numpy()
for key in ["atype", "natoms"]:
for key in ["coord", "box", "force", "energy", "virial", "atype", "natoms"]:
if key in pt_batch.keys():
np_batch[key] = pt_batch[key].cpu().numpy()
batch_size = pt_batch["coord"].shape[0]
Expand Down
5 changes: 0 additions & 5 deletions deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@
else:
DEVICE = torch.device(f"cuda:{LOCAL_RANK}")

if os.environ.get("PREPROCESS_DEVICE") == "gpu":
PREPROCESS_DEVICE = torch.device(f"cuda:{LOCAL_RANK}")
else:
PREPROCESS_DEVICE = torch.device("cpu")

JIT = False
CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory
ENERGY_BIAS_TRAINABLE = True
Expand Down
47 changes: 15 additions & 32 deletions deepmd/pt/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def build_inside_clist(coord, region: Region3D, ncell):
cell_offset[cell_offset < 0] = 0
delta = cell_offset - ncell
a2c = compute_serial_cid(cell_offset, ncell) # cell id of atoms
arange = torch.arange(0, loc_ncell, 1, device=env.PREPROCESS_DEVICE)
arange = torch.arange(0, loc_ncell, 1)
cellid = a2c == arange.unsqueeze(-1) # one hot cellid
c2a = cellid.nonzero()
lst = []
Expand Down Expand Up @@ -131,18 +131,12 @@ def append_neighbors(coord, region: Region3D, atype, rcut: float):

# add ghost atoms
a2c, c2a = build_inside_clist(coord, region, ncell)
xi = torch.arange(-ngcell[0], ncell[0] + ngcell[0], 1, device=env.PREPROCESS_DEVICE)
yi = torch.arange(-ngcell[1], ncell[1] + ngcell[1], 1, device=env.PREPROCESS_DEVICE)
zi = torch.arange(-ngcell[2], ncell[2] + ngcell[2], 1, device=env.PREPROCESS_DEVICE)
xyz = xi.view(-1, 1, 1, 1) * torch.tensor(
[1, 0, 0], dtype=torch.long, device=env.PREPROCESS_DEVICE
)
xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor(
[0, 1, 0], dtype=torch.long, device=env.PREPROCESS_DEVICE
)
xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor(
[0, 0, 1], dtype=torch.long, device=env.PREPROCESS_DEVICE
)
xi = torch.arange(-ngcell[0], ncell[0] + ngcell[0], 1)
yi = torch.arange(-ngcell[1], ncell[1] + ngcell[1], 1)
zi = torch.arange(-ngcell[2], ncell[2] + ngcell[2], 1)
xyz = xi.view(-1, 1, 1, 1) * torch.tensor([1, 0, 0], dtype=torch.long)
xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor([0, 1, 0], dtype=torch.long)
xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor([0, 0, 1], dtype=torch.long)
xyz = xyz.view(-1, 3)
mask_a = (xyz >= 0).all(dim=-1)
mask_b = (xyz < ncell).all(dim=-1)
Expand All @@ -165,9 +159,7 @@ def append_neighbors(coord, region: Region3D, atype, rcut: float):
merged_coord = torch.cat([coord, tmp_coord])
merged_coord_shift = torch.cat([torch.zeros_like(coord), coord_shift[tmp]])
merged_atype = torch.cat([atype, tmp_atype])
merged_mapping = torch.cat(
[torch.arange(atype.numel(), device=env.PREPROCESS_DEVICE), aid]
)
merged_mapping = torch.cat([torch.arange(atype.numel()), aid])
return merged_coord_shift, merged_atype, merged_mapping


Expand All @@ -188,22 +180,16 @@ def build_neighbor_list(
distance = coord_l - coord_r
distance = torch.linalg.norm(distance, dim=-1)
DISTANCE_INF = distance.max().detach() + rcut
distance[:nloc, :nloc] += (
torch.eye(nloc, dtype=torch.bool, device=env.PREPROCESS_DEVICE) * DISTANCE_INF
)
distance[:nloc, :nloc] += torch.eye(nloc, dtype=torch.bool) * DISTANCE_INF
if min_check:
if distance.min().abs() < 1e-6:
RuntimeError("Atom dist too close!")
if not type_split:
sec = sec[-1:]
lst = []
nlist = torch.zeros((nloc, sec[-1].item()), device=env.PREPROCESS_DEVICE).long() - 1
nlist_loc = (
torch.zeros((nloc, sec[-1].item()), device=env.PREPROCESS_DEVICE).long() - 1
)
nlist_type = (
torch.zeros((nloc, sec[-1].item()), device=env.PREPROCESS_DEVICE).long() - 1
)
nlist = torch.zeros((nloc, sec[-1].item())).long() - 1
nlist_loc = torch.zeros((nloc, sec[-1].item())).long() - 1
nlist_type = torch.zeros((nloc, sec[-1].item())).long() - 1
for i, nnei in enumerate(sec):
if i > 0:
nnei = nnei - sec[i - 1]
Expand All @@ -216,11 +202,8 @@ def build_neighbor_list(
_sorted, indices = torch.topk(tmp, nnei, dim=1, largest=False)
else:
# when nnei > nall
indices = torch.zeros((nloc, nnei), device=env.PREPROCESS_DEVICE).long() - 1
_sorted = (
torch.ones((nloc, nnei), device=env.PREPROCESS_DEVICE).long()
* DISTANCE_INF
)
indices = torch.zeros((nloc, nnei)).long() - 1
_sorted = torch.ones((nloc, nnei)).long() * DISTANCE_INF
_sorted_nnei, indices_nnei = torch.topk(
tmp, tmp.shape[1], dim=1, largest=False
)
Expand Down Expand Up @@ -284,7 +267,7 @@ def make_env_mat(
else:
merged_coord_shift = torch.zeros_like(coord)
merged_atype = atype.clone()
merged_mapping = torch.arange(atype.numel(), device=env.PREPROCESS_DEVICE)
merged_mapping = torch.arange(atype.numel())
merged_coord = coord.clone()

# build nlist
Expand Down
7 changes: 1 addition & 6 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,9 @@ def make_stat_input(datasets, dataloaders, nbatches):
shape = torch.zeros(
(n_frames, extend, 3),
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.PREPROCESS_DEVICE,
)
else:
shape = torch.zeros(
(n_frames, extend),
dtype=torch.long,
device=env.PREPROCESS_DEVICE,
)
shape = torch.zeros((n_frames, extend), dtype=torch.long)
for i in range(len(item)):
natoms_tmp = l[i].shape[0]
shape[i, :natoms_tmp] = l[i]
Expand Down
23 changes: 14 additions & 9 deletions source/tests/pt/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from deepmd.pt.utils import (
dp_random,
env,
)
from deepmd.pt.utils.dataset import (
DeepmdDataSet,
Expand Down Expand Up @@ -112,29 +113,33 @@ def setUp(self):

def test_consistency(self):
avg_zero = torch.zeros(
[self.ntypes, self.nnei * 4], dtype=GLOBAL_PT_FLOAT_PRECISION
[self.ntypes, self.nnei * 4],
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
)
std_ones = torch.ones(
[self.ntypes, self.nnei * 4], dtype=GLOBAL_PT_FLOAT_PRECISION
[self.ntypes, self.nnei * 4],
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
)
base_d, base_force, nlist = base_se_a(
rcut=self.rcut,
rcut_smth=self.rcut_smth,
sel=self.sel,
batch=self.np_batch,
mean=avg_zero,
stddev=std_ones,
mean=avg_zero.detach().cpu(),
stddev=std_ones.detach().cpu(),
)

pt_coord = self.pt_batch["coord"]
pt_coord = self.pt_batch["coord"].to(env.DEVICE)
pt_coord.requires_grad_(True)
index = self.pt_batch["mapping"].unsqueeze(-1).expand(-1, -1, 3)
index = self.pt_batch["mapping"].unsqueeze(-1).expand(-1, -1, 3).to(env.DEVICE)
extended_coord = torch.gather(pt_coord, dim=1, index=index)
extended_coord = extended_coord - self.pt_batch["shift"]
extended_coord = extended_coord - self.pt_batch["shift"].to(env.DEVICE)
my_d, _, _ = prod_env_mat_se_a(
extended_coord.to(DEVICE),
self.pt_batch["nlist"],
self.pt_batch["atype"],
self.pt_batch["nlist"].to(env.DEVICE),
self.pt_batch["atype"].to(env.DEVICE),
avg_zero.reshape([-1, self.nnei, 4]).to(DEVICE),
std_ones.reshape([-1, self.nnei, 4]).to(DEVICE),
self.rcut,
Expand Down
8 changes: 4 additions & 4 deletions source/tests/pt/test_descriptor_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def test_descriptor_block(self):
dparams["ntypes"] = ntypes
des = DescrptBlockSeAtten(
**dparams,
)
).to(env.DEVICE)
des.load_state_dict(torch.load(self.file_model_param))
rcut = dparams["rcut"]
nsel = dparams["sel"]
Expand All @@ -260,7 +260,7 @@ def test_descriptor_block(self):
extended_coord, extended_atype, nloc, rcut, nsel, distinguish_types=False
)
# handle type_embedding
type_embedding = TypeEmbedNet(ntypes, 8)
type_embedding = TypeEmbedNet(ntypes, 8).to(env.DEVICE)
type_embedding.load_state_dict(torch.load(self.file_type_embed))

## to save model parameters
Expand Down Expand Up @@ -293,7 +293,7 @@ def test_descriptor(self):
dparams["concat_output_tebd"] = False
des = DescrptDPA1(
**dparams,
)
).to(env.DEVICE)
target_dict = des.state_dict()
source_dict = torch.load(self.file_model_param)
type_embd_dict = torch.load(self.file_type_embed)
Expand Down Expand Up @@ -337,7 +337,7 @@ def test_descriptor(self):
dparams["concat_output_tebd"] = True
des = DescrptDPA1(
**dparams,
)
).to(env.DEVICE)
descriptor, env_mat, diff, rot_mat, sw = des(
extended_coord,
extended_atype,
Expand Down
Loading

0 comments on commit 7f069cc

Please sign in to comment.