Skip to content

Commit

Permalink
add weight_on option to SO3
Browse files Browse the repository at this point in the history
  • Loading branch information
qzhu2017 committed Jun 5, 2023
1 parent b9820b2 commit afaf729
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions pyxtal_ff/descriptors/SO3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ class SO3:
rcut: float, cutoff radius for neighbor calculation
alpha: float, gaussian width parameter
derivative: bool, whether to calculate the gradient of not
weight_on: bool, if True, the neighbors with different type will be counted as negative
'''

def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0, derivative=True, stress=False, cutoff_function='cosine'):
def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0, derivative=True, stress=False, cutoff_function='cosine', weight_on=False):
# populate attributes
self.nmax = nmax
self.lmax = lmax
Expand All @@ -29,6 +30,7 @@ def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0, derivative=True, stress=
self.stress = stress
self._type = "SO3"
self.cutoff_function = cutoff_function
self.weight_on = weight_on
return

def __str__(self):
Expand Down Expand Up @@ -176,7 +178,7 @@ def clear_memory(self):
'''
attrs = list(vars(self).keys())
for attr in attrs:
if attr not in {'_nmax', '_lmax', '_rcut', '_alpha', '_derivative', '_stress', '_cutoff_function'}:
if attr not in {'_nmax', '_lmax', '_rcut', '_alpha', '_derivative', '_stress', '_cutoff_function', 'weight_on'}:
delattr(self, attr)
return

Expand Down Expand Up @@ -339,7 +341,11 @@ def build_neighbor_list(self, atom_ids=None):
pos = atoms.positions[j] + np.dot(offset,atoms.get_cell()) - center_atom
center_atoms.append(center_atom)
neighbors.append(pos)
atomic_weights.append(atoms[j].number)
if self.weight_on and atoms[j].number != atoms[i].number:
factor = -1
else:
factor = 1
atomic_weights.append(factor*atoms[j].number)
neighbor_indices.append([i,j])

neighbor_indices = np.array(neighbor_indices, dtype=np.int64)
Expand Down

0 comments on commit afaf729

Please sign in to comment.