Skip to content

Commit

Permalink
Small modifications in code and comments to clarify (#11) (#539)
Browse files Browse the repository at this point in the history
Co-authored-by: Ignacio Pickering <[email protected]>
  • Loading branch information
zasdfgbnm and IgnacioJPickering committed Nov 13, 2020
1 parent 25bd59f commit 2e5032a
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ def radial_terms(Rcr: float, EtaR: Tensor, ShfR: Tensor, distances: Tensor) -> T
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
distances = distances.unsqueeze(-1).unsqueeze(-1)
distances = distances.view(-1, 1, 1)
fc = cutoff_cosine(distances, Rcr)
# Note that in the equation in the paper there is no 0.25
# coefficient, but in NeuroChem there is such a coefficient.
# We choose to be consistent with NeuroChem instead of the paper here.
ret = 0.25 * torch.exp(-EtaR * (distances - ShfR)**2) * fc
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?) where ? depend on constants.
# We then should flat the last 2 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-2)
# At this point, ret now has shape
# (conformations x atoms, ?, ?) where ? depend on constants.
# We then should flat the last 2 dimensions to view the subAEV as a two
# dimensional tensor (onnx doesn't support negative indices in flatten)
return ret.flatten(start_dim=1)


def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
Expand All @@ -63,7 +63,7 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
vectors12 = vectors12.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors12 = vectors12.view(2, -1, 3, 1, 1, 1, 1)
distances12 = vectors12.norm(2, dim=-5)

cos_angles = vectors12.prod(0).sum(1) / distances12.prod(0)
Expand All @@ -74,11 +74,11 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta
factor2 = torch.exp(-EtaA * (distances12.sum(0) / 2 - ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj12.prod(0)
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-4)
# At this point, ret now has shape
# (conformations x atoms, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as a two
# dimensional tensor (onnx doesn't support negative indices in flatten)
return ret.flatten(start_dim=1)


def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
Expand Down

0 comments on commit 2e5032a

Please sign in to comment.