From f7ce3172f32dd65eb257d303f7ced1d453c5d85d Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 1 Apr 2020 02:46:03 -0700 Subject: [PATCH 1/2] Clean up unused args --- torchani/aev.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchani/aev.py b/torchani/aev.py index d55160223..c3505313d 100644 --- a/torchani/aev.py +++ b/torchani/aev.py @@ -170,8 +170,7 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor, return molecule_index + atom_index1, molecule_index + atom_index2, shifts -def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cell: Tensor, - shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]: +def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]: """Compute pairs of atoms that are neighbors (doesn't use PBC) This function bypasses the calculation of shifts and duplication @@ -199,7 +198,7 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cell: Tensor atom_index1 = p1_all[pair_index] + molecule_index atom_index2 = p2_all[pair_index] + molecule_index # shifts - shifts = shifts.new_zeros((p1_all.shape[0], 3)).index_select(0, pair_index) + shifts = p1_all.new_zeros((p1_all.shape[0], 3)).index_select(0, pair_index) return atom_index1, atom_index2, shifts From d5ea109e7ada4aedaaf0c760a69096fde5ef3c21 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 1 Apr 2020 02:52:38 -0700 Subject: [PATCH 2/2] fix --- torchani/aev.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchani/aev.py b/torchani/aev.py index c3505313d..27c1391c4 100644 --- a/torchani/aev.py +++ b/torchani/aev.py @@ -274,7 +274,7 @@ def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor, num_species_pairs = angular_length // angular_sublength # PBC calculation is bypassed if there are no shifts if shifts.numel() == 0: - atom_index1, atom_index2, shifts = neighbor_pairs_nopbc(species == -1, coordinates, cell, shifts, Rcr) + atom_index1, atom_index2, shifts = neighbor_pairs_nopbc(species == -1, coordinates, Rcr) else: atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr) coordinates = coordinates.flatten(0, 1)