Skip to content

Commit

Permalink
Remove torch.unique for finding the maximum three body index and litt…
Browse files Browse the repository at this point in the history
…le cleanup in united tests (#161)

* Optimize the Atoms2Graph and fixed the np.meshgrid

* put unittests

* improve the _three_body.py and test_M3GNetCalculator in test_ase.py

* add cpu() in ase.py and compute.py to enable the GPU usage for MatGL-LAMMPS interface

* included the unit-test for hessian test_ase.py to improve the coverage score

* remove reducdant torch.unique for finding the maximum three_body index and little cleanup in united tests

---------

Co-authored-by: Shyue Ping Ong <[email protected]>
  • Loading branch information
kenko911 and shyuep authored Sep 5, 2023
1 parent a74d22a commit 974d333
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 15 deletions.
2 changes: 1 addition & 1 deletion matgl/graph/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def compute_3body(g: dgl.DGLGraph):
src_id = torch.tensor(triple_bond_indices[:, 0], dtype=matgl.int_th)
dst_id = torch.tensor(triple_bond_indices[:, 1], dtype=matgl.int_th)
l_g = dgl.graph((src_id, dst_id))
three_body_id = torch.unique(torch.concatenate(l_g.edges()))
three_body_id = torch.concatenate(l_g.edges())
n_triple_ij = torch.tensor(n_triple_ij, dtype=matgl.int_th)
max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0
l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id]
Expand Down
2 changes: 1 addition & 1 deletion matgl/models/_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def forward(
l_g.ndata["bond_dist"] = g.edata["bond_dist"][valid_three_body]
l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body]
else:
three_body_id = torch.unique(torch.concatenate(l_g.edges()))
three_body_id = torch.concatenate(l_g.edges())
max_three_body_id = torch.max(three_body_id) + 1 if three_body_id.numel() > 0 else 0
l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id]
l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id]
Expand Down
14 changes: 1 addition & 13 deletions tests/layers/test_graph_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
import torch
from torch import nn

from matgl.graph.compute import (
compute_theta_and_phi,
create_line_graph,
)
from matgl.layers import BondExpansion, EmbeddingBlock, SphericalBesselWithHarmonics
from matgl.layers import BondExpansion, EmbeddingBlock
from matgl.layers._graph_convolution import (
MLP,
M3GNetBlock,
Expand Down Expand Up @@ -88,10 +84,6 @@ def test_m3gnet_graph_conv(self, graph_MoS):
bond_expansion = BondExpansion(max_l=3, max_n=3, cutoff=5.0, rbf_type="SphericalBessel", smooth=False)
bond_basis = bond_expansion(bond_dist)
g1.edata["rbf"] = bond_basis
sb_and_sh = SphericalBesselWithHarmonics(max_n=3, max_l=3, cutoff=5.0, use_smooth=False, use_phi=False)
l_g1 = create_line_graph(g1, threebody_cutoff=4.0)
l_g1.apply_edges(compute_theta_and_phi)
sb_and_sh(l_g1)
max_n = 3
max_l = 3
num_node_feats = 16
Expand Down Expand Up @@ -137,10 +129,6 @@ def test_m3gnet_block(self, graph_MoS):
bond_expansion = BondExpansion(max_l=3, max_n=3, cutoff=5.0, rbf_type="SphericalBessel", smooth=False)
bond_basis = bond_expansion(g1.edata["bond_dist"])
g1.edata["rbf"] = bond_basis
sb_and_sh = SphericalBesselWithHarmonics(max_n=3, max_l=3, cutoff=5.0, use_smooth=False, use_phi=False)
l_g1 = create_line_graph(g1, threebody_cutoff=4.0)
l_g1.apply_edges(compute_theta_and_phi)
sb_and_sh(l_g1)
num_node_feats = 16
num_edge_feats = 32
num_state_feats = 64
Expand Down

0 comments on commit 974d333

Please sign in to comment.