Skip to content

Commit

Permalink
Allow conditional numba installtion.
Browse files Browse the repository at this point in the history
  • Loading branch information
mgt16-LANL committed Aug 8, 2023
1 parent a93d9b2 commit 0a3d811
Showing 1 changed file with 38 additions and 7 deletions.
45 changes: 38 additions & 7 deletions architector/io_lig.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,43 @@
from ase.optimize.bfgslinesearch import BFGSLineSearch
import ase.constraints as ase_con
from xtb.ase.calculator import XTB

from numba import jit
import warnings

warnings.filterwarnings('ignore') # Supress numpy warnings.

# Conditional Numba import
has_numba = True
try:
from numba import jit
except:
has_numba = False
def jit(nopython=True):
return nopython


def conditional_decorator(dec, condition):
"""conditional decorator
only apply a decorator if the condition is met
Parmeters
---------
dec : python decorator function
python decorator function
condition : bool
if a condition is met.
Returns
-------
decorator : python function
Either with or without a decorator
"""
def decorator(func):
if not condition:
# Return the function unchanged, not decorated.
return func
return dec(func)
return decorator

def get_oxo_refdict():
"""get_oxo_refdict
Pull the metal-oxo distance reference dictionary.
Expand Down Expand Up @@ -317,7 +348,7 @@ def get_bounds_matrix(allcoords, molgraph, natoms, catoms, shape, ml_dists, vdwr

return LB, UB

@jit(nopython=True)
@conditional_decorator(jit(nopython=True),has_numba)
def triangle(LB, UB, natoms):
"""triangle
Triangle inequality bounds smoothing. From ref [2], pp. 252-253.
Expand Down Expand Up @@ -401,7 +432,7 @@ def metrize(LB, UB, natoms, non_triangle=False, debug=False):
D[j][natoms-1] = D[natoms-1][j]
return D

@jit(nopython=True)
@conditional_decorator(jit(nopython=True),has_numba)
def get_cm_dists(D, natoms):
"""get_cm_dists
Get distances of each atom to center of mass given the distance matrix.
Expand Down Expand Up @@ -431,7 +462,7 @@ def get_cm_dists(D, natoms):
D0[i] = np.sqrt(D0[i])
return D0

@jit(nopython=True)
@conditional_decorator(jit(nopython=True),has_numba)
def get_metric_matrix(D, D0, natoms):
"""get_metric_matrix
Get metric matrix from distance matrix and cm distances
Expand Down Expand Up @@ -483,7 +514,7 @@ def get_3_eigs(G, natoms):
V[:, i] = v[:, natoms-1-i]
return L, V

@jit(nopython=True)
@conditional_decorator(jit(nopython=True),has_numba)
def distance_error(x, *args):
"""distance_error
Computes distance error function for scipy optimization.
Expand Down Expand Up @@ -514,7 +545,7 @@ def distance_error(x, *args):
E += (2*lij**2/(lij**2 + dij**2) - 1)**2
return E

@jit(nopython=True)
@conditional_decorator(jit(nopython=True),has_numba)
def dist_error_gradient(x, *args):
"""dist_error_gradient
Computes gradient of distance error function for scipy optimization.
Expand Down

0 comments on commit 0a3d811

Please sign in to comment.