Skip to content

Commit

Permalink
Improve the implementation of three-body interactions
Browse files Browse the repository at this point in the history
  • Loading branch information
kenko911 committed Jul 3, 2024
1 parent e144ae2 commit 9b17963
Showing 1 changed file with 41 additions and 22 deletions.
63 changes: 41 additions & 22 deletions src/matgl/layers/_three_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class ThreeBodyInteractions(nn.Module):
"""Include 3D interactions to the bond update."""

def __init__(self, update_network_atom: nn.Module, update_network_bond: nn.Module, **kwargs):
"""Init ThreeBodyInteractions.
"""
Initialize ThreeBodyInteractions.
Args:
update_network_atom: MLP for node features in Eq.2
Expand All @@ -31,45 +32,63 @@ def __init__(self, update_network_atom: nn.Module, update_network_bond: nn.Modul
self.update_network_bond = update_network_bond

def forward(
self,
graph: dgl.DGLGraph,
line_graph: dgl.DGLGraph,
three_basis: torch.Tensor,
three_cutoff: float,
node_feat: torch.Tensor,
edge_feat: torch.Tensor,
self,
graph: dgl.DGLGraph,
line_graph: dgl.DGLGraph,
three_basis: torch.Tensor,
three_cutoff: torch.Tensor,
node_feat: torch.Tensor,
edge_feat: torch.Tensor,
):
"""
Forward function for ThreeBodyInteractions.
Args:
graph: dgl graph
line_graph: line graph.
line_graph: line graph
three_basis: three body basis expansion
three_cutoff: cutoff radius
node_feat: node features
edge_feat: edge features.
edge_feat: edge features
"""
end_atom_index = torch.gather(graph.edges()[1], 0, line_graph.edges()[1].to(torch.int64))
atoms = self.update_network_atom(node_feat)
end_atom_index = torch.unsqueeze(end_atom_index, 1)
atoms = torch.squeeze(atoms[end_atom_index])
basis = three_basis * atoms
three_cutoff = torch.unsqueeze(three_cutoff, dim=1) # type: ignore
weights = three_cutoff[torch.stack(list(line_graph.edges()), dim=1)].view(-1, 2) # type: ignore
weights = torch.prod(weights, dim=-1) # type: ignore
# Get the indices of the end atoms for each bond in the line graph
end_atom_indices = graph.edges()[1][line_graph.edges()[1]].to(matgl.int_th)

# Update node features using the atom update network
updated_atoms = self.update_network_atom(node_feat)

# Gather updated atom features for the end atoms
end_atom_features = updated_atoms[end_atom_indices]

# Compute the basis term
basis = three_basis * end_atom_features

# Reshape and compute weights based on the three-cutoff tensor
three_cutoff = three_cutoff.unsqueeze(1)
edge_indices = torch.stack(list(line_graph.edges()), dim=1)
weights = three_cutoff[edge_indices].view(-1, 2)
weights = weights.prod(dim=-1)

# Compute the weighted basis
basis = basis * weights[:, None]

# Aggregate the new bonds using scatter_sum
segment_ids = get_segment_indices_from_n(line_graph.ndata["n_triple_ij"])
new_bonds = scatter_sum(
basis.to(matgl.float_th),
segment_ids=get_segment_indices_from_n(line_graph.ndata["n_triple_ij"]),
segment_ids=segment_ids,
num_segments=graph.num_edges(),
dim=0,
)
if not new_bonds.data.shape[0]:

# If no new bonds are generated, return the original edge features
if new_bonds.shape[0] == 0:
return edge_feat
edge_feat_updated = edge_feat + self.update_network_bond(new_bonds)
return edge_feat_updated

# Update edge features using the bond update network
updated_edge_feat = edge_feat + self.update_network_bond(new_bonds)

return updated_edge_feat

def combine_sbf_shf(sbf, shf, max_n: int, max_l: int, use_phi: bool):
"""Combine the spherical Bessel function and the spherical Harmonics function.
Expand Down

0 comments on commit 9b17963

Please sign in to comment.