Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement caching for Clebsch-Gordan matrix construction. #37

Merged
merged 6 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions anisoap/utils/cyclic_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import List


class CGRCacheList:
"""
This is a simple class that only exists to be used as a "private" cache
for computed ClebschGordanReal matrices (self._cg)

It only stores last n CG matrices computed for distinct l_max values. Since
it is specialized to store (l_max, cg matrix) pair, it is NOT supposed to be
used to cache any other things. While type of "value" (in `(key, value)` pair)
does not matter as much, the type of `key` MUST BE A 32-BIT INTEGER with
most significant bit unused, as it performs bitwise operations on the most
significant bit to store and check the replacement flags.

It will replace the entries using something called "clock algorithm" for page
replacement, which mimics the "least recently used" algorithm but with less
computation requirement.

[Link to clock algorithm](https://en.wikipedia.org/wiki/Page_replacement_algorithm#Clock)
"""

# Perform bitwise AND with this number to remove the replacement flag (set to 0).
_REMOVE_RFLAG = 0x7FFFFFFF
# Perform bitwise AND with this number to get replacement flag (0: replace, not 0: do not replace)
# Or perform bitwise OR then assign the value to set the flag to 1.
_GET_RFLAG = 0x80000000

def __init__(self, size: int):
"""
A constructor that makes an empty cyclic list.
"""
self._size = size
self.clear_cache()

def keys(self) -> List[int]:
"""
Return list of keys, with keys
"""
return [
entry[0] & CGRCacheList._REMOVE_RFLAG
for entry in self._cyclic_list
if entry is not None
]

def insert(self, key: int, value: dict) -> None:
"""
This insert algorithm mimics the "clock algorithm" for page replacement
technique. The algorithm works like this:
1. Get the entry that the "clock hand" (self._ins_index) points to.
2. Keep advancing the "clock hand" until key with replacement flag = 0
is found. Replacement flag, in this case, is the most significant bit
3. Insert the new [key, value] pair into the list, with replacement flag = 1.
Advance "clock hand" by one position.
Note that clock hand wraps around the list, should it reach the end of the list.
"""
if key not in self.keys():
curr_entry = self._cyclic_list[self._ins_index]

# Advance _ins_index until replacement flag is 0.
while (
curr_entry is not None and curr_entry[0] & CGRCacheList._GET_RFLAG != 0
):
curr_entry[0] = curr_entry[0] & CGRCacheList._REMOVE_RFLAG
self._ins_index = (self._ins_index + 1) % self._size
curr_entry = self._cyclic_list[self._ins_index]

# Store [key with replace_flag = 0, value] pair in cyclic list
self._cyclic_list[self._ins_index] = [key, value]

# Advance the "clock hand" once more after insertion.
self._ins_index = (self._ins_index + 1) % self._size

def get_val(self, key: int) -> dict:
"""
Obtains the value for given key. Raises IndexError if the given key
is not found in the list.
"""
for element in self._cyclic_list:
if element is not None and key == element[0] & CGRCacheList._REMOVE_RFLAG:
element[0] = element[0] | CGRCacheList._GET_RFLAG
return element[1]
raise IndexError(
f"The specified key {key} is not in the list. Current keys in the list are: {self.keys()}"
)

def clear_cache(self) -> None:
"""
Resets the cache (makes list empty)
"""
self._ins_index = 0
# will be a list of list [key | replace_flag, value].
self._cyclic_list: list = [None] * self._size
27 changes: 25 additions & 2 deletions anisoap/utils/metatensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,33 @@
TensorMap,
)

from .cyclic_list import CGRCacheList


class ClebschGordanReal:
# Size of 5 is an arbitrary choice -- it can be changed to any number.
# Set to None to disable caching.
cache_list = CGRCacheList(5)

def __init__(self, l_max):
self._l_max = l_max
self._cg = {}
self._cg = dict()

# Check if the caching feature is activated.
if ClebschGordanReal.cache_list is not None:
# Check if the given l_max is already in the cache.
if self._l_max in ClebschGordanReal.cache_list.keys():
# If so, load from the cache.
self._cg = ClebschGordanReal.cache_list.get_val(self._l_max)
else:
# Otherwise, compute the matrices and store it to the cache.
self._init_cg()
ClebschGordanReal.cache_list.insert(self._l_max, self._cg)
else:
# If caching is deactivated, then just compute the matrices normally.
self._init_cg()

def _init_cg(self):
# real-to-complex and complex-to-real transformations as matrices
r2c = {}
c2r = {}
Expand Down Expand Up @@ -86,7 +107,9 @@ def combine_einsum(self, rho1, rho2, L, combination_string):
for M in range(2 * L + 1):
for m1, m2, cg in self._cg[(l1, l2, L)][M]:
rho[:, M, ...] += np.einsum(
combination_string, rho1[:, m1, ...], rho2[:, m2, ...] * cg
combination_string,
rho1[:, m1, ...],
rho2[:, m2, ...] * cg,
)

return rho
Expand Down
25 changes: 25 additions & 0 deletions tests/test_cgr_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np
from anisoap.utils import ClebschGordanReal
from time import perf_counter


class TestCGRCache:
def test_cgr(self):
"""
Tests that the CGR cache is correct, and that using the CGR cache results
in a performance boost.
"""
lmax = 5
start = perf_counter()
mycg1 = ClebschGordanReal(lmax)
end = perf_counter()
time_nocache = end - start
start = perf_counter()
mycg2 = ClebschGordanReal(lmax) # we should expect these to be cached!
end = perf_counter()
time_cache = end - start
assert time_cache < time_nocache
assert mycg1.get_cg().keys() == mycg2.get_cg().keys()
for key1, key2 in zip(mycg1.get_cg(), mycg2.get_cg()):
for entry1, entry2 in zip(mycg1.get_cg()[key1], mycg2.get_cg()[key2]):
assert np.all(entry1 == entry2)
Loading