Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fourier #49

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion docs/developers_notes/01-basis_module.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ Abstract Class Basis
│ │
│ └─ Concrete Subclass RaisedCosineBasisLog
└─ Concrete Subclass OrthExponentialBasis
├─ Concrete Subclass OrthExponentialBasis
└─ Concrete Subclass FourierBasis
```

The super-class `Basis` provides two public methods, [`evaluate`](#the-public-method-evaluate) and [`evaluate_on_grid`](#the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the private abstract method `_evaluate` that is specific for each concrete class. See below for more details.
Expand Down
92 changes: 90 additions & 2 deletions docs/examples/plot_1D_basis_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@
# -----------------
# Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description,
# please refer to the [Code References](../../../reference/nemos/basis). After instantiation, all classes
# share the same syntax for basis evaluation. The following is an example of how to instantiate and
# evaluate a log-spaced cosine raised function basis.
# share the same syntax for basis evaluation.
#
# ### The Log-Spaced Raised Cosine Basis
# The following is an example of how to instantiate and evaluate a log-spaced cosine raised function basis.

# Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter
raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50)
Expand All @@ -81,3 +83,89 @@
plt.plot(samples, eval_basis)
plt.show()

# %%
# ### The Fourier Basis
# Another type of basis available is the Fourier Basis. Fourier basis are ideal to capture periodic and
# quasi-periodic patterns. Such oscillatory, rhythmic behavior is a common signature of many neural signals.
# Additionally, the Fourier basis has the advantage of being orthogonal, which simplifies the estimation and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably explain orthogonal here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still think this. at least a foot note or link

# interpretation of the model parameters, each of which will represent the relative contribution of a specific
# oscillation frequency to the overall signal.
#
# A Fourier basis can be instantiated with the following syntax:
# the user can provide the maximum frequency of the cosine and negative
# sine pairs by setting the `max_freq` parameter.
# The sinusoidal basis elements will have frequencies from 0 to `max_freq`.


fourier_basis = nmo.basis.FourierBasis(max_freq=3)

# evaluate on equi-spaced samples
samples, eval_basis = fourier_basis.evaluate_on_grid(1000)

# plot the `sin` and `cos` separately
plt.figure(figsize=(6, 3))
plt.subplot(121)
plt.title("Cos")
plt.plot(samples, eval_basis[:, :4])
plt.subplot(122)
plt.title("Sin")
plt.plot(samples, eval_basis[:, 4:])
plt.tight_layout()

# %%
# ## Fourier Basis Convolution and Fourier Transform
BalzaniEdoardo marked this conversation as resolved.
Show resolved Hide resolved
# The Fourier transform of a signal $ s(t) $ restricted to a temporal window $ [t_0,\;t_1] $ is
# $$ \\hat{x}(\\omega) = \\int_{t_0}^{t_1} s(\\tau) e^{-j\\omega \\tau} d\\tau. $$
# where $ e^{-j\\omega \\tau} = \\cos(\\omega \\tau) - j \\sin (\\omega \\tau) $.
#
# When computing the cross-correlation of a signal with the Fourier basis functions,
# we essentially measure how well the signal correlates with sinusoids of different frequencies,
# within a specified temporal window. This process mirrors the operation performed by the Fourier transform.
# Therefore, it becomes clear that computing the cross-correlation of a signal with the Fourier basis defined here
# is equivalent to computing the discrete Fourier transform on a sliding window of the same size
# as that of the basis.
BalzaniEdoardo marked this conversation as resolved.
Show resolved Hide resolved


n_samples = 1000
max_freq = 20

# define a signal
signal = np.random.normal(size=n_samples)

# evaluate the basis
_, eval_basis = nmo.basis.FourierBasis(max_freq=max_freq).evaluate_on_grid(n_samples)

# compute the cross-corr with the signal and the basis
# Note that we are inverting the time axis of the basis because we are aiming
# for a cross-correlation, while np.convolve compute a convolution which would flip the time axis.
Comment on lines +139 to +140
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we just compute the correlation directly to avoid this confusion? It's true, but provides an extra hurdle for folks (and then we could call out this equivalency in an admonition)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still think this

xcorr = np.array(
[
np.convolve(eval_basis[::-1, k], signal, mode="valid")[0]
for k in range(2 * max_freq + 1)
]
)

# compute the power (add back sin(0 * t) = 0)
fft_complex = np.fft.fft(signal)
fft_amplitude = np.abs(fft_complex[:max_freq + 1])
fft_phase = np.angle(fft_complex[:max_freq + 1])
# compute the phase and amplitude from the convolution
xcorr_phase = np.arctan2(np.hstack([[0], xcorr[max_freq+1:]]), xcorr[:max_freq+1])
xcorr_aplitude = np.sqrt(xcorr[:max_freq+1] ** 2 + np.hstack([[0], xcorr[max_freq+1:]]) ** 2)

fig, ax = plt.subplots(1, 2)
ax[0].set_aspect("equal")
ax[0].set_title("Signal amplitude")
ax[0].scatter(fft_amplitude, xcorr_aplitude)
ax[0].set_xlabel("FFT")
ax[0].set_ylabel("cross-correlation")

ax[1].set_aspect("equal")
ax[1].set_title("Signal phase")
ax[1].scatter(fft_phase, xcorr_phase)
ax[1].set_xlabel("FFT")
ax[1].set_ylabel("cross-correlation")
plt.tight_layout()

print(f"Max Error Amplitude: {np.abs(fft_amplitude - xcorr_aplitude).max()}")
print(f"Max Error Phase: {np.abs(fft_phase - xcorr_phase).max()}")
96 changes: 94 additions & 2 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"OrthExponentialBasis",
"AdditiveBasis",
"MultiplicativeBasis",
"FourierBasis",
]


Expand Down Expand Up @@ -103,7 +104,7 @@ def _check_evaluate_input(self, *xi: ArrayLike) -> Tuple[NDArray]:
# make sure array is at least 1d (so that we succeed when only
# passed a scalar)
xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi)
except TypeError:
except (TypeError, ValueError):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what additionally is being caught here?

raise TypeError("Input samples must be array-like of floats!")

# check for non-empty samples
Expand Down Expand Up @@ -1086,7 +1087,8 @@ def _check_rates(self):
"linearly dependent set of function for the basis."
)

def _check_sample_range(self, sample_pts: NDArray):
@staticmethod
def _check_sample_range(sample_pts: NDArray):
"""
Check if the sample points are all positive.

Expand Down Expand Up @@ -1177,6 +1179,96 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
return super().evaluate_on_grid(n_samples)


class FourierBasis(Basis):
"""Set of 1D Fourier basis.

This class defines a cosine and negative sine basis (quadrature pair)
with frequencies ranging 0 to max_freq.

Parameters
----------
max_freq
Highest frequency of the cosine, negative sine pairs.
The number of basis function will be 2*max_freq + 1.
"""

def __init__(self, max_freq: int):
super().__init__(n_basis_funcs=2 * max_freq + 1)

self._frequencies = np.arange(max_freq + 1, dtype=float)
self._n_input_dimensionality = 1

def _check_n_basis_min(self) -> None:
"""Check that the user required enough basis elements.

Checks that the number of basis is at least 1.

Raises
------
ValueError
If an insufficient number of basis element is requested for the basis type.
"""
if self.n_basis_funcs < 0:
raise ValueError(
f"Object class {self.__class__.__name__} requires >= 1 basis elements. "
f"{self.n_basis_funcs} basis elements specified instead"
)

def evaluate(self, sample_pts: ArrayLike) -> NDArray:
"""Generate basis functions with given spacing.

Parameters
----------
sample_pts
Spacing for basis functions.

Returns
-------
basis_funcs
Evaluated Fourier basis, shape (n_samples, n_basis_funcs).

Notes
-----
The frequencies are set to np.arange(max_freq+1), convolving a signal
of length n_samples with this basis is equivalent, but slower,
then computing the FFT truncated to the first max_freq components.

Therefore, convolving a signal with this basis is equivalent
to compute the FFT over a sliding window.

Examples
--------
>>> import nemos as nmo
>>> import numpy as np
>>> n_samples, max_freq = 1000, 10
>>> basis = nmo.basis.FourierBasis(max_freq)
>>> eval_basis = basis.evaluate(np.linspace(0, 1, n_samples))
>>> sinusoid = np.cos(3 * np.arange(0, 1000) * np.pi * 2 / 1000.)
>>> conv = [np.convolve(eval_basis[::-1, k], sinusoid, mode='valid')[0] for k in range(2*max_freq+1)]
>>> fft = np.fft.fft(sinusoid)
>>> print('FFT power: ', np.round(np.real(fft[:max_freq]), 4))
>>> print('Convolution: ', np.round(conv[:max_freq], 4))
Comment on lines +1241 to +1250

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way in mkdocs to set this as a python codeblock? Can you remove the >>>? Because right now in the docs the >>> shows up and when you copy it copies with the >>> so a user can't just copy and paste and run from the docs 😄

for example at the bottom here the arrows render and get copied: https://nemos.readthedocs.io/en/latest/reference/nemos/utils/#nemos.utils.pytree_map_and_reduce

"""
(sample_pts,) = self._check_evaluate_input(sample_pts)
# assumes equi-spaced samples.
if sample_pts.shape[0] / np.max(self._frequencies) < 2:
raise ValueError("Not enough samples, aliasing likely to occur!")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe report sample_pts.shape[0] and max(self._frequencies) to the user?


# rescale to [0, 2pi)
mn, mx = np.nanmin(sample_pts), np.nanmax(sample_pts)
# first sample in 0, last sample in 2 pi - 2 pi / n_samples.
sample_pts = (
2
* np.pi
* (sample_pts - mn)
/ (mx - mn)
* (1.0 - 1.0 / sample_pts.shape[0])
)
# create the basis
angles = np.einsum("i,j->ij", sample_pts, self._frequencies)
return np.concatenate([np.cos(angles), -np.sin(angles[:, 1:])], axis=1)


def mspline(x: NDArray, k: int, i: int, T: NDArray):
"""Compute M-spline basis function.

Expand Down
Loading