Skip to content

Commit

Permalink
restructure project directory
Browse files Browse the repository at this point in the history
  • Loading branch information
GStechschulte committed Nov 3, 2023
1 parent 050543d commit 52a82f6
Show file tree
Hide file tree
Showing 13 changed files with 232 additions and 210 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The tutorials below will introduce you to Kalman Filters in filterjax.

.. toctree::
:maxdepth: 1
:caption: Kalman Filter:
:caption: Kalman Filter (Linear Dynamics):

notebooks/tracking_objects.ipynb
.. examples
Expand Down
70 changes: 26 additions & 44 deletions docs/notebooks/tracking_objects.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions filterjax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .filters import GHFilter, KalmanFilter, KalmanParams
from .models import KalmanFilter
from .plotting import plot_posterior_covariance

__all__ = ["GHFilter", "KalmanFilter", "KalmanParams", "plot_posterior_covariance"]
__all__ = ["KalmanFilter", "plot_posterior_covariance"]
4 changes: 0 additions & 4 deletions filterjax/filters/__init__.py

This file was deleted.

35 changes: 0 additions & 35 deletions filterjax/filters/gh_filter.py

This file was deleted.

118 changes: 0 additions & 118 deletions filterjax/filters/kalman_filter.py

This file was deleted.

3 changes: 3 additions & 0 deletions filterjax/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from filterjax.inference.filters import batch_filter

__all__ = ["batch_filter"]
70 changes: 70 additions & 0 deletions filterjax/inference/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import NamedTuple

import jax
import jax.numpy as jnp

from filterjax.params import KalmanParams


class PosteriorFilter(NamedTuple):
"""Posterior state estimate and covariance matrix.
Parameters
----------
m : jnp.ndarray
Mean of the posterior state estimate.
P : jnp.ndarray
Covariance of the posterior state estimate.
"""

mean: jnp.ndarray
covariance: jnp.ndarray


def update(y, x, R, H, P):
"""
Observe new measurement (emission) 'y' and update state.
"""
error = y - (H @ x)
# TODO: Use Cholesky decomposition for covariance
S = H @ P @ H.T + R
K = P @ H.T @ jnp.linalg.inv(S)
x = x + K @ error
P = P - K @ S @ K.T

return x, P


def predict(F, Q, x, P, B=None, u=None):
"""
Predict next state using the Kalman filter state propagation equations.
"""
if B is not None and u is not None:
x = F @ x + B @ u
else:
x = F @ x

P = F @ P @ F.T + Q

return x, P


def batch_filter(params: KalmanParams, emissions: jnp.ndarray):

num_timesteps = len(emissions)

def step(carry, t):
m, P = carry

# TODO: add and update log-likelihood

m, P = update(emissions[t], m, params.R, params.H, P)
m, P = predict(params.F, params.Q, m, P)

return (m, P), (m, P)

log_likelihood = 0.0
carry = (params.m, params.P)
_, (ms, Ps) = jax.lax.scan(step, carry, jnp.arange(num_timesteps))

return PosteriorFilter(ms, Ps)
3 changes: 3 additions & 0 deletions filterjax/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from filterjax.models.kalman_filter import KalmanFilter

__all__ = ["KalmanFilter"]
69 changes: 69 additions & 0 deletions filterjax/models/kalman_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import NamedTuple, Union

import jax
from jax import numpy as jnp

from filterjax.params import KalmanParams
from filterjax.inference.filters import batch_filter


class PosteriorFilter(NamedTuple):
"""Posterior state estimate and covariance matrix.
Parameters
----------
m : jnp.ndarray
Mean of the posterior state estimate.
P : jnp.ndarray
Covariance of the posterior state estimate.
"""

mean: jnp.ndarray
covariance: jnp.ndarray


class KalmanFilter:
def __init__(self, state_dim, emission_dim):
self.state_dim = state_dim
self.emission_dim = emission_dim

def initialize(
self,
init_state: jnp.ndarray,
init_transition: jnp.ndarray,
init_emission: jnp.ndarray,
init_emission_covariance: jnp.ndarray,
init_process_covariance: jnp.ndarray,
init_state_covariance: jnp.ndarray,
init_control: Union[jnp.ndarray, None] = None,
) -> KalmanParams:

# TODO: add a method that checks dims of params before initializing
params = KalmanParams(
m=init_state,
F=init_transition,
H=init_emission,
R=init_emission_covariance,
Q=init_process_covariance,
P=init_state_covariance,
B=init_control,
)

self.check_dims(params)

return params

def check_dims(self, params: KalmanParams):
assert params.m.shape == (self.state_dim,)
assert params.F.shape == (self.state_dim, self.state_dim)
assert params.H.shape == (self.emission_dim, self.state_dim)
assert params.R.shape == (self.emission_dim, self.emission_dim)
assert params.Q.shape == (self.state_dim, self.state_dim)
assert params.P.shape == (self.state_dim, self.state_dim)

if params.B is not None:
assert params.B.shape == (self.state_dim, self.state_dim)

def filter(self, params: KalmanParams, emissions: jnp.ndarray):
return batch_filter(params, emissions)

33 changes: 33 additions & 0 deletions filterjax/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import NamedTuple

import jax.numpy as jnp


class KalmanParams(NamedTuple):
"""Parameters for the Kalman filter.
Parameters
----------
m : jnp.ndarray
Mean of the prior state estimate.
F : jnp.ndarray
State transition matrix.
H : jnp.ndarray
Observation (emission) matrix.
R : jnp.ndarray
Covariance matrix of the observation (emission) noise.
Q : jnp.ndarray
Covariance matrix of the process model (dynamics) noise.
P : jnp.ndarray
Covariance matrix of the prior state estimate.
B : jnp.ndarray, optional
Control transition matrix, by default None
"""

m: jnp.ndarray
F: jnp.ndarray
H: jnp.ndarray
R: jnp.ndarray
Q: jnp.ndarray
P: jnp.ndarray
B: jnp.ndarray = None
2 changes: 1 addition & 1 deletion filterjax/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from filterjax.plotting.plot import plot_posterior_covariance

__all__ = ["plot_posterior_covariance"]
__all__ = ["plot_posterior_covariance"]
Loading

0 comments on commit 52a82f6

Please sign in to comment.