Skip to content

Commit

Permalink
Rust moments (#46)
Browse files Browse the repository at this point in the history
* added some test rust-python scripts

* changed the name of pymodule

* unified names

* namechange

* namechange to have module be called anisoap_rust.fibbers

* init commit for rust_moments

* added a timing test

* linter

* Update tests.yml

removed -e.

* linter

* update pyproject.toml

* clean up pyproject.toml

* clean up pyproject.toml again

* pyproject.toml

* deleted rust folder

* made rtol more permissive

* linter

* renamed src folder to rust, and removed old functions in lib.rs

* fixed Cargo.toml

* fixed lib.rs

* made rust_moments a kwarg in edp.transform() so it's actually togglable, and added docstrings

* linter

* time per-call averages rather than totals

* rewrote test_compute_moments to include realistic example
  • Loading branch information
arthur-lin1027 authored Sep 6, 2024
1 parent e00904d commit 3e6080f
Show file tree
Hide file tree
Showing 12 changed files with 335 additions and 64 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r tests/requirements.txt
pip install -e .
pip install .
- name: Run tests
run: |
pytest -v tests/.
Expand Down
19 changes: 19 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Cargo.toml
[package]
name = "anisoap_rust"
version = "0.0.0"
edition = "2021"

[dependencies]
pyo3 = "0.21.0"
numpy = "0.21.0"

[lib]
name = "anisoap_rust_lib" # private module to be nested into Python package,
# needs to match the name of the function with the `#[pymodule]` attribute

path = "rust/lib.rs"
crate-type = ["cdylib"] # required for shared library for Python to import from.

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
# See also PyO3 docs on writing Cargo.toml files at https://pyo3.rs
3 changes: 3 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# MANIFEST.in
include Cargo.toml
recursive-include rust *.rs
26 changes: 19 additions & 7 deletions anisoap/representations/ellipsoidal_density_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from itertools import product

import numpy as np
from anisoap_rust_lib import compute_moments
from metatensor import (
Labels,
TensorBlock,
Expand Down Expand Up @@ -30,6 +31,7 @@ def pairwise_ellip_expansion(
sph_to_cart,
radial_basis,
show_progress=False,
rust_moments=True,
):
r"""Computes pairwise expansion
Expand Down Expand Up @@ -61,6 +63,10 @@ def pairwise_ellip_expansion(
appropriately with the cutoff radius, radial basis type.
show_progress : bool
Show progress bar for frame analysis and feature generation
rust_moments : bool
Use the ported rust code, which should result in increased speed. Default = True.
In the future, once we ensure integrity checks with the original python code,
this kwarg will be deprecated, and the rust version will always be used.
Returns
-------
Expand Down Expand Up @@ -136,13 +142,14 @@ def pairwise_ellip_expansion(
constant,
) = radial_basis.compute_gaussian_parameters(r_ij, lengths, rot)

moments = (
np.exp(-0.5 * constant)
* length_norm
* compute_moments_inefficient_implementation(
if rust_moments:
moments = compute_moments(precision, center, maxdeg)
else:
moments = compute_moments_inefficient_implementation(
precision, center, maxdeg=maxdeg
)
)
moments *= np.exp(-0.5 * constant) * length_norm

for l in range(lmax + 1):
deg = l + 2 * (num_ns[l] - 1)
moments_l = moments[: deg + 1, : deg + 1, : deg + 1]
Expand Down Expand Up @@ -545,7 +552,7 @@ def __init__(

self.rotation_key = rotation_key

def transform(self, frames, show_progress=False, normalize=True):
def transform(self, frames, show_progress=False, normalize=True, rust_moments=True):
"""Computes features and gradients for frames
Computes the features and (if compute_gradients == True) gradients
Expand All @@ -558,8 +565,12 @@ def transform(self, frames, show_progress=False, normalize=True):
List containing all ase.Atoms types
show_progress : bool
Show progress bar for frame analysis and feature generation
normalize: bool
normalize : bool
Whether to perform Lowdin Symmetric Orthonormalization or not.
rust_moments : bool
Use the ported rust code, which should result in increased speed. Default = True.
In the future, once we ensure integrity checks with the original python code,
this kwarg will be deprecated, and the rust version will always be used.
Returns
-------
Expand Down Expand Up @@ -638,6 +649,7 @@ def transform(self, frames, show_progress=False, normalize=True):
self.sph_to_cart,
self.radial_basis,
show_progress,
rust_moments=rust_moments,
)

features = contract_pairwise_feat(pairwise_ellip_feat, types, show_progress)
Expand Down
57 changes: 39 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,23 +1,44 @@
[tool.tox]
legacy_tox_ini = """
[tox]
[project]
name = "anisoap"
version = "0.0.0"
requires-python = ">=3.8"
authors = [
{name = "Arthur Lin", email = "[email protected]"},
{name = "Kevin Kazuki Huguenin-Dumittan"},
{name = "Jigyasa Nigam"},
{name = "Yong-Cheol Cho"},
{name = "Lucas Ortengren"},
{name = "Seonwoo Hwang"},
{name = "Rose K. Cersonsky"}
]
description = "A package for computing anisotropic extensions to the SOAP formalism"
readme = "README.md"
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering",
"License :: OSI Approved :: Apache License 2.0",
"Natural Language :: English",

[testenv:tests]
changedir = tests
deps = -rtests/requirements.txt
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
]

commands =
coverage run -m unittest discover -p "*.py"
coverage xml
[build-system]
requires = ["setuptools", "setuptools-rust"]
build-backend = "setuptools.build_meta"

"""
[tool.setuptools.packages]
# Pure Python packages/modules
find = { where = ["."] }

[tool.coverage.run]
branch = true
data_file = "tests/.coverage"
[[tool.setuptools-rust.ext-modules]]
# Private Rust extension module to be nested into the Python package
target = "anisoap_rust_lib" # The last part of the name (e.g. "_lib") has to match lib.name in Cargo.toml,
# but you can add a prefix to nest it inside of a Python package.
path = "Cargo.toml" # Default value, can be omitted
binding = "PyO3" # Default value, can be omitted

[tool.coverage.report]
include = ["anisoap/*"]

[tool.coverage.xml]
output = "tests/coverage.xml"
175 changes: 175 additions & 0 deletions rust/ellip_expansion/compute_moments.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
use numpy::ndarray::{Array3, ArrayView1, ArrayView2};
use pyo3::exceptions::PyAssertionError;
use pyo3::prelude::*;

/// Evaluates the moments <x^n0 y^n1 z^n2> of a tri-variate Gaussian for every
/// index triple with n0 + n1 + n2 <= `max_deg`.
///
/// The result is a dense cube of side `max_deg + 1`, addressed as `[n0, n1, n2]`.
/// Entries whose total degree exceeds `max_deg` are never written by the
/// recurrence and keep their zero initialisation before the final scaling, so
/// the layout trades memory for plain array-style indexing.
///
/// # Arguments
/// * `dil_mat` - Symmetric 3x3 dilation matrix (an `np.ndarray` on the Python side).
/// * `gau_cen` - Length-3 vector holding the center of the Gaussian.
/// * `max_deg` - Highest total degree to compute; must be at least 1.
///
/// # Errors
/// Returns a `PyAssertionError` (raised as an exception on the Python side) when
/// * `dil_mat` is not 3x3,
/// * any pair of mirrored entries of `dil_mat` differs by a squared amount of
///   `1e-14` or more,
/// * `gau_cen` does not have exactly 3 entries,
/// * `max_deg` is zero or negative,
/// * the absolute value of the determinant of `dil_mat` is below `1e-14`.
pub fn compute_moments_rust(
    dil_mat: ArrayView2<'_, f64>,
    gau_cen: ArrayView1<'_, f64>,
    max_deg: i32,
) -> PyResult<Array3<f64>> {
    // Every validation failure is reported through the same exception type.
    let reject = |msg: &'static str| -> PyResult<Array3<f64>> {
        Err(PyErr::new::<PyAssertionError, _>(msg))
    };

    // The shape has to be validated first: everything below indexes up to [2, 2].
    if dil_mat.shape() != &[3, 3] {
        return reject("Dilation matrix needs to be 3x3");
    }

    // The squared difference is the same for (i, j) and (j, i) and is never
    // >= 1e-14 on the diagonal, so the three off-diagonal pairs are sufficient.
    let asymmetric = [(0, 1), (0, 2), (1, 2)]
        .iter()
        .any(|&(i, j)| (dil_mat[[i, j]] - dil_mat[[j, i]]).powi(2) >= 1e-14);
    if asymmetric {
        return reject("Dilation matrix needs to be symmetric");
    }

    if gau_cen.shape() != &[3] {
        return reject("Center of Gaussian has to be given by a 3-dim. vector.");
    }

    if max_deg <= 0 {
        return reject("The maximum degree needs to be at least 1.");
    }

    // Keep the center components in locals; the recurrence reads them constantly.
    let (mu_x, mu_y, mu_z) = (gau_cen[0], gau_cen[1], gau_cen[2]);

    // A symmetric 3x3 matrix is fully described by its upper triangle:
    //   [m00, m01, m02]
    //   [m01, m11, m12]
    //   [m02, m12, m22]
    // so only these six entries are needed for the determinant and the inverse.
    let (m00, m01, m02) = (dil_mat[[0, 0]], dil_mat[[0, 1]], dil_mat[[0, 2]]);
    let (m11, m12, m22) = (dil_mat[[1, 1]], dil_mat[[1, 2]], dil_mat[[2, 2]]);

    // Signed cofactors of the first row, shared by the determinant and the inverse.
    let cof00 = m11 * m22 - m12 * m12;
    let cof01 = m02 * m12 - m01 * m22;
    let cof02 = m01 * m12 - m02 * m11;

    // Laplace expansion of the determinant along the first row.
    let det = m00 * cof00 + m01 * cof01 + m02 * cof02;
    if det.abs() < 1e-14 {
        return reject("The given dilation matrix is singular.");
    }
    // NOTE(review): only |det| < 1e-14 is rejected. A negative determinant passes
    // this check and makes the square root below NaN - confirm whether callers
    // can ever supply such a matrix.

    // Upper triangle of the inverse (the covariance), one scalar per entry
    // because each of them is used in the innermost loop.
    let cov00 = cof00 / det;
    let cov01 = cof01 / det;
    let cov02 = cof02 / det;
    let cov11 = (m00 * m22 - m02 * m02) / det;
    let cov12 = (m01 * m02 - m00 * m12) / det;
    let cov22 = (m00 * m11 - m01 * m01) / det;

    // Normalisation applied to every moment just before returning:
    //   (2 PI)^1.5 / SQRT(det) = SQRT(8 PI^3 / det)
    let global_factor = (8.0 * (std::f64::consts::PI).powi(3) / det).sqrt();

    // `max_deg` is known to be positive here, so the cast cannot wrap.
    let order = max_deg as usize;
    let side = order + 1;
    let mut moments = Array3::<f64>::zeros((side, side, side));

    // Degree 0 and degree 1 seeds.
    moments[[0, 0, 0]] = 1.0;
    moments[[1, 0, 0]] = mu_x;
    moments[[0, 1, 0]] = mu_y;
    moments[[0, 0, 1]] = mu_z;

    // Degree 2 seeds.
    if order >= 2 {
        moments[[2, 0, 0]] = cov00 + mu_x * mu_x;
        moments[[0, 2, 0]] = cov11 + mu_y * mu_y;
        moments[[0, 0, 2]] = cov22 + mu_z * mu_z;
        moments[[1, 1, 0]] = cov01 + mu_x * mu_y;
        moments[[0, 1, 1]] = cov12 + mu_y * mu_z;
        moments[[1, 0, 1]] = cov02 + mu_x * mu_z;
    }

    // One term of the recurrence: `coeff * n * moments[idx with idx[axis] - 1]`
    // where n = idx[axis]. The term is 0 when n is 0, which also avoids
    // indexing below zero.
    fn lowered_term(moments: &Array3<f64>, coeff: f64, idx: [usize; 3], axis: usize) -> f64 {
        match idx[axis] {
            0 => 0.0,
            n => {
                let mut below = idx;
                below[axis] -= 1;
                coeff * n as f64 * moments[below]
            }
        }
    }

    // Build degree `deg + 1` from degrees `deg` and `deg - 1`. The range is empty
    // for `order <= 2`, in which case the seeds above are already complete.
    for deg in 2..order {
        for n0 in 0..=deg {
            for n1 in 0..=(deg - n0) {
                // The third index is fixed by n0 + n1 + n2 = deg.
                let idx = [n0, n1, deg - n0 - n1];
                let n2 = idx[2];
                let here = moments[idx];

                // Raise the x index. The leading `0.0 +` is kept so the
                // floating-point evaluation order is unchanged.
                let x_shift = 0.0
                    + lowered_term(&moments, cov00, idx, 0)
                    + lowered_term(&moments, cov01, idx, 1)
                    + lowered_term(&moments, cov02, idx, 2);
                moments[[n0 + 1, n1, n2]] = mu_x * here + x_shift;

                // Entries with a zero x index are not reachable by raising x,
                // so raise the y index for them instead.
                if n0 > 0 {
                    continue;
                }
                let y_shift = 0.0
                    + lowered_term(&moments, cov11, idx, 1)
                    + lowered_term(&moments, cov12, idx, 2);
                moments[[n0, n1 + 1, n2]] = mu_y * here + y_shift;

                // The pure-z entry is only reachable by raising the z index.
                if n1 > 0 {
                    continue;
                }
                moments[[n0, n1, n2 + 1]] = mu_z * here + lowered_term(&moments, cov22, idx, 2);
            }
        }
    }

    Ok(moments * global_factor)
}
1 change: 1 addition & 0 deletions rust/ellip_expansion/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod compute_moments;
24 changes: 24 additions & 0 deletions rust/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
mod ellip_expansion;

use ellip_expansion::compute_moments::compute_moments_rust;
use numpy::ndarray::{Array2, ArrayView2};
use numpy::{IntoPyArray, PyArray2, PyArray3, PyReadonlyArray1, PyReadonlyArray2, PyUntypedArrayMethods};
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;


/// Python extension module `anisoap_rust_lib`; registers the `compute_moments`
/// function on the module object.
#[pymodule]
fn anisoap_rust_lib(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    /// Python-facing wrapper around `compute_moments_rust`.
    ///
    /// Takes a 2-D float64 array `mat`, a 1-D float64 array `g_vec` and an
    /// integer `max_deg`, and returns the 3-D float64 array produced by
    /// `compute_moments_rust`. Any error from that function is propagated
    /// unchanged.
    #[pyfn(m)]
    fn compute_moments<'py>(
        py: Python<'py>,
        mat: PyReadonlyArray2<'_, f64>,
        g_vec: PyReadonlyArray1<'_, f64>,
        max_deg: i32,
    ) -> PyResult<&'py PyArray3<f64>> {
        // View the read-only NumPy inputs as ndarray views for the Rust routine.
        let dilation = mat.as_array();
        let center = g_vec.as_array();

        // `?` forwards validation failures as Python exceptions.
        let moments = compute_moments_rust(dilation, center, max_deg)?;

        // Convert the resulting array into a NumPy array for Python.
        Ok(moments.into_pyarray(py))
    }

    Ok(())
}
Loading

0 comments on commit 3e6080f

Please sign in to comment.