Skip to content

Commit

Permalink
Merge pull request #451 from dynamicslab/axesarray-indexing
Browse files Browse the repository at this point in the history
Make AxesArray handle slicing correctly
  • Loading branch information
Jacob-Stevens-Haas authored Jan 31, 2024
2 parents e6752e2 + bc3fbd2 commit 9c73768
Show file tree
Hide file tree
Showing 15 changed files with 1,379 additions and 136 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ venv/
ENV/
env.bak/
venv.bak/
env8

# automatically generated by setuptools-scm
pysindy/version.py
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ docs = [
"ipython",
"pandoc",
"sphinx-rtd-theme",
"sphinx==5.3.0",
"sphinx==7.1.2",
"sphinxcontrib-apidoc",
"nbsphinx"
]
Expand Down
58 changes: 38 additions & 20 deletions pysindy/differentiation/finite_difference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import List
from typing import Union

import numpy as np
from numpy.typing import NDArray
from scipy.special import factorial

from .base import BaseDifferentiation
from pysindy.utils.axes import AxesArray


class FiniteDifference(BaseDifferentiation):
Expand Down Expand Up @@ -55,7 +60,7 @@ class FiniteDifference(BaseDifferentiation):
def __init__(
self,
order=2,
d=1,
d: int = 1,
axis=0,
is_uniform=False,
drop_endpoints=False,
Expand Down Expand Up @@ -85,22 +90,27 @@ def __init__(

def _coefficients(self, t):
nt = len(t)
self.stencil_inds = np.array(
[np.arange(i, nt - self.n_stencil + i + 1) for i in range(self.n_stencil)]
self.stencil_inds = AxesArray(
np.array(
[
np.arange(i, nt - self.n_stencil + i + 1)
for i in range(self.n_stencil)
]
),
{"ax_offset": 0, "ax_ti": 1},
)
self.stencil = AxesArray(
np.transpose(t[self.stencil_inds]), {"ax_time": 0, "ax_offset": 1}
)
self.stencil = np.transpose(t[self.stencil_inds])

pows = np.arange(self.n_stencil)[np.newaxis, :, np.newaxis]
matrices = (
dt_endpoints = (
self.stencil
- t[
(self.n_stencil - 1) // 2 : -(self.n_stencil - 1) // 2,
np.newaxis,
]
)[:, np.newaxis, :] ** pows
b = np.zeros(self.n_stencil)
b[self.d] = factorial(self.d)
return np.linalg.solve(matrices, [b])
- t[(self.n_stencil - 1) // 2 : -(self.n_stencil - 1) // 2, "offset"]
)
matrices = dt_endpoints[:, "power", :] ** pows
b = AxesArray(np.zeros((1, self.n_stencil)), {"ax_time": 0, "ax_power": 1})
b[0, self.d] = factorial(self.d)
return np.linalg.solve(matrices, b)

def _coefficients_boundary_forward(self, t):
# use the same stencil for each boundary point,
Expand Down Expand Up @@ -202,23 +212,30 @@ def _constant_coefficients(self, dt):

def _accumulate(self, coeffs, x):
# slice to select the stencil indices
s = [slice(None)] * len(x.shape)
s = [slice(None)] * x.ndim
s[self.axis] = self.stencil_inds

# a new axis is introduced after self.axis for the stencil indices
# a new axis is introduced before self.axis for the stencil indices
# To contract with the coefficients, roll by -self.axis to put it first
# Then roll back by self.axis to return the order
trans = np.roll(np.arange(len(x.shape) + 1), -self.axis)
trans = np.roll(np.arange(x.ndim + 1), -self.axis)
# TODO: assign x's axes much earlier in the call stack
x = AxesArray(x, {"ax_unk": list(range(x.ndim))})
x_expanded = AxesArray(
np.transpose(x[tuple(s)], axes=trans), x.insert_axis(0, "ax_offset")
)
return np.transpose(
np.einsum(
"ij...,ij->j...",
np.transpose(x[tuple(s)], axes=trans),
x_expanded,
np.transpose(coeffs),
),
np.roll(np.arange(len(x.shape)), self.axis),
np.roll(np.arange(x.ndim), self.axis),
)

def _differentiate(self, x, t):
def _differentiate(
self, x: NDArray, t: Union[NDArray, float, List[float]]
) -> NDArray:
"""
Apply finite difference method.
"""
Expand Down Expand Up @@ -249,6 +266,7 @@ def _differentiate(self, x, t):
s[self.axis] = slice(start, stop)
interior = interior + x[tuple(s)] * coeffs[i]
else:
t = AxesArray(np.array(t), axes={"ax_time": 0})
coeffs = self._coefficients(t)
interior = self._accumulate(coeffs, x)
s[self.axis] = slice((self.n_stencil - 1) // 2, -(self.n_stencil - 1) // 2)
Expand Down
9 changes: 4 additions & 5 deletions pysindy/feature_library/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,9 @@ def correct_shape(self, x: AxesArray):
return x

def calc_trajectory(self, diff_method, x, t):
axes = x.__dict__
x_dot = diff_method(x, t=t)
x = AxesArray(diff_method.smoothed_x_, axes)
return x, AxesArray(x_dot, axes)
x = AxesArray(diff_method.smoothed_x_, x.axes)
return x, AxesArray(x_dot, x.axes)

def get_spatial_grid(self):
return None
Expand Down Expand Up @@ -337,7 +336,7 @@ def __init__(
self.libraries = libraries
self.inputs_per_library = inputs_per_library

def _combinations(self, lib_i, lib_j):
def _combinations(self, lib_i: AxesArray, lib_j: AxesArray) -> AxesArray:
"""
Compute combinations of the numerical libraries.
Expand All @@ -351,7 +350,7 @@ def _combinations(self, lib_i, lib_j):
lib_i.shape[lib_i.ax_coord] * lib_j.shape[lib_j.ax_coord]
)
lib_full = np.reshape(
lib_i[..., :, np.newaxis] * lib_j[..., np.newaxis, :],
lib_i[..., :, "coord"] * lib_j[..., "coord", :],
shape,
)

Expand Down
2 changes: 1 addition & 1 deletion pysindy/feature_library/generalized_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def transform(self, x_full):
else:
xps.append(lib.transform([x])[0])

xp = AxesArray(np.concatenate(xps, axis=xps[0].ax_coord), xps[0].__dict__)
xp = AxesArray(np.concatenate(xps, axis=xps[0].ax_coord), xps[0].axes)
xp_full = xp_full + [xp]
return xp_full

Expand Down
14 changes: 4 additions & 10 deletions pysindy/feature_library/pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,7 @@ def get_feature_names(self, input_features=None):
def derivative_string(multiindex):
ret = ""
for axis in range(self.ind_range):
if self.implicit_terms and (
axis
in [
self.spatiotemporal_grid.ax_time,
self.spatiotemporal_grid.ax_sample,
]
):
if self.implicit_terms and (axis == self.spatiotemporal_grid.ax_time,):
str_deriv = "t"
else:
str_deriv = str(axis + 1)
Expand Down Expand Up @@ -345,7 +339,7 @@ def transform(self, x_full):

# derivative terms
shape[-1] = n_features * self.num_derivatives
library_derivatives = np.empty(shape, dtype=x.dtype)
library_derivatives = AxesArray(np.empty(shape, dtype=x.dtype), x.axes)
library_idx = 0
for multiindex in self.multiindices:
derivs = x
Expand Down Expand Up @@ -395,8 +389,8 @@ def transform(self, x_full):
library_idx : library_idx
+ n_library_terms * self.num_derivatives * n_features,
] = np.reshape(
library_functions[..., np.newaxis, :]
* library_derivatives[..., :, np.newaxis],
library_functions[..., "coord", :]
* library_derivatives[..., :, "coord"],
shape,
)
library_idx += n_library_terms * self.num_derivatives * n_features
Expand Down
2 changes: 1 addition & 1 deletion pysindy/feature_library/polynomial_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def transform(self, x_full):
dtype=x.dtype,
order=self.order,
),
x.__dict__,
x.axes,
)
for i, comb in enumerate(combinations):
xp[..., i] = x[..., comb].prod(-1)
Expand Down
2 changes: 1 addition & 1 deletion pysindy/feature_library/sindy_pi_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,5 +404,5 @@ def transform(self, x_full):
*[x[:, comb] for comb in f_combs]
) * f_dot(*[x_dot[:, comb] for comb in f_dot_combs])
library_idx += 1
xp_full = xp_full + [AxesArray(xp, x.__dict__)]
xp_full = xp_full + [AxesArray(xp, x.axes)]
return xp_full
41 changes: 26 additions & 15 deletions pysindy/feature_library/weak_pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sklearn.utils.validation import check_is_fitted

from ..utils import AxesArray
from ..utils import comprehend_axes
from .base import BaseFeatureLibrary
from .base import x_sequence_or_item
from .polynomial_library import PolynomialLibrary
Expand Down Expand Up @@ -218,7 +219,10 @@ def __init__(

self.num_derivatives = num_derivatives
self.multiindices = multiindices
self.spatiotemporal_grid = spatiotemporal_grid

self.spatiotemporal_grid = AxesArray(
spatiotemporal_grid, axes=comprehend_axes(spatiotemporal_grid)
)

# Weak form checks and setup
self._weak_form_setup()
Expand All @@ -228,12 +232,14 @@ def _weak_form_setup(self):
L_xt = xt2 - xt1
if self.H_xt is not None:
if np.isscalar(self.H_xt):
self.H_xt = np.array(self.grid_ndim * [self.H_xt])
self.H_xt = AxesArray(
np.array(self.grid_ndim * [self.H_xt]), {"ax_coord": 0}
)
if self.grid_ndim != len(self.H_xt):
raise ValueError(
"The user-defined grid (spatiotemporal_grid) and "
"the user-defined sizes of the subdomains for the "
"weak form, do not have the same # of spatiotemporal "
"weak form do not have the same # of spatiotemporal "
"dimensions. For instance, if spatiotemporal_grid is 4D, "
"then H_xt should be a 4D list of the subdomain lengths."
)
Expand All @@ -258,8 +264,8 @@ def _weak_form_setup(self):
self._set_up_weights()

def _get_spatial_endpoints(self):
x1 = np.zeros(self.grid_ndim)
x2 = np.zeros(self.grid_ndim)
x1 = AxesArray(np.zeros(self.grid_ndim), {"ax_coord": 0})
x2 = AxesArray(np.zeros(self.grid_ndim), {"ax_coord": 0})
for i in range(self.grid_ndim):
inds = [slice(None)] * (self.grid_ndim + 1)
for j in range(self.grid_ndim):
Expand All @@ -279,7 +285,9 @@ def _set_up_weights(self):

# Sample the random domain centers
xt1, xt2 = self._get_spatial_endpoints()
domain_centers = np.zeros((self.K, self.grid_ndim))
domain_centers = AxesArray(
np.zeros((self.K, self.grid_ndim)), {"ax_sample": 0, "ax_coord": 1}
)
for i in range(self.grid_ndim):
domain_centers[:, i] = np.random.uniform(
xt1[i] + self.H_xt[i], xt2[i] - self.H_xt[i], size=self.K
Expand All @@ -294,15 +302,12 @@ def _set_up_weights(self):
s = [0] * (self.grid_ndim + 1)
s[i] = slice(None)
s[-1] = i
newinds = np.intersect1d(
np.where(
self.spatiotemporal_grid[tuple(s)]
>= domain_centers[k][i] - self.H_xt[i]
),
np.where(
self.spatiotemporal_grid[tuple(s)]
<= domain_centers[k][i] + self.H_xt[i]
),
ax_vals = self.spatiotemporal_grid[tuple(s)]
cell_left = domain_centers[k][i] - self.H_xt[i]
cell_right = domain_centers[k][i] + self.H_xt[i]
newinds = AxesArray(
((ax_vals > cell_left) & (ax_vals < cell_right)).nonzero()[0],
ax_vals.axes,
)
# If less than two indices along any axis, resample
if len(newinds) < 2:
Expand All @@ -319,6 +324,7 @@ def _set_up_weights(self):
self.inds_k = self.inds_k + [inds]
k = k + 1

# TODO: fix meaning of axes in XT_k
# Values of the spatiotemporal grid on the domain cells
XT_k = [
self.spatiotemporal_grid[np.ix_(*self.inds_k[k])] for k in range(self.K)
Expand Down Expand Up @@ -441,6 +447,11 @@ def _set_up_weights(self):
)
weights1 = weights1 + [weights2]

# TODO: get rest of code to work with AxesArray. Too unsure of
# which axis labels to use at this point to continue
tweights = [np.asarray(arr) for arr in tweights]
weights0 = [np.asarray(arr) for arr in weights0]
weights1 = [[np.asarray(arr) for arr in sublist] for sublist in weights1]
# Product weights over the axes for time derivatives, shaped as inds_k
self.fulltweights = []
deriv = np.zeros(self.grid_ndim)
Expand Down
3 changes: 2 additions & 1 deletion pysindy/optimizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws):
self : returns an instance of self
"""
x_ = AxesArray(np.asarray(x_), {"ax_sample": 0, "ax_coord": 1})
y = AxesArray(np.asarray(y), {"ax_sample": 0, "ax_coord": 1})
y_axes = {"ax_sample": 0} if y.ndim == 1 else {"ax_sample": 0, "ax_coord": 1}
y = AxesArray(np.asarray(y), y_axes)
x_, y = drop_nan_samples(x_, y)
x_, y = check_X_y(x_, y, accept_sparse=[], y_numeric=True, multi_output=True)

Expand Down
Loading

0 comments on commit 9c73768

Please sign in to comment.