Skip to content

Commit

Permalink
Remove unused imports
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Sep 12, 2024
1 parent b7927fb commit 72bf8f5
Show file tree
Hide file tree
Showing 18 changed files with 14 additions and 51 deletions.
1 change: 0 additions & 1 deletion dynamax/generalized_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from jax import jacfwd, vmap, lax
import jax.numpy as jnp
from jax import lax
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from jaxtyping import Array, Float
from typing import NamedTuple, Optional, Union, Callable

Expand Down
2 changes: 1 addition & 1 deletion dynamax/hidden_markov_model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Callable, Optional, Tuple, Union, NamedTuple
from jaxtyping import Int, Float, Array

from dynamax.types import Scalar, PRNGKey
from dynamax.types import Scalar

_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x

Expand Down
1 change: 0 additions & 1 deletion dynamax/hidden_markov_model/inference_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import itertools as it
import jax.numpy as jnp
import jax.random as jr
Expand Down
2 changes: 1 addition & 1 deletion dynamax/hidden_markov_model/models/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dynamax.types import Scalar
from dynamax.parameters import to_unconstrained, from_unconstrained
from dynamax.parameters import ParameterSet, PropertySet
from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered
from dynamax.hidden_markov_model.inference import HMMPosterior
from dynamax.hidden_markov_model.inference import hmm_filter
from dynamax.hidden_markov_model.inference import hmm_posterior_mode
from dynamax.hidden_markov_model.inference import hmm_smoother
Expand Down
1 change: 0 additions & 1 deletion dynamax/hidden_markov_model/models/categorical_glm_hmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.distributions as tfd
from jaxtyping import Float, Array
Expand Down
3 changes: 1 addition & 2 deletions dynamax/hidden_markov_model/parallel_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import jax.random as jr
from jax import lax, vmap, value_and_grad
from jaxtyping import Array, Float, Int
from typing import NamedTuple, Union
from functools import partial
from typing import NamedTuple

from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered

Expand Down
15 changes: 1 addition & 14 deletions dynamax/linear_gaussian_ssm/demos/kf_linreg_jax_vs_pt.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,16 @@
import numpy as np
from jaxtyping import Float, Array
from typing import Callable, NamedTuple, Union, Tuple, Any
from functools import partial
import chex
import optax
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import lax, jacfwd, vmap, grad, jit
from jax.tree_util import tree_map, tree_reduce
from jax.flatten_util import ravel_pytree
from jax import lax
import jax.numpy as jnp
import jax.random as jr
from jax import lax
import time
import platform

import matplotlib.pyplot as plt
import matplotlib.cm as cm

from dataclasses import dataclass
from itertools import cycle

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN

import torch
Expand Down
4 changes: 2 additions & 2 deletions dynamax/linear_gaussian_ssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import jax.random as jr
from jax.tree_util import tree_map
from jaxtyping import Array, Float, PyTree
from jaxtyping import Array, Float
import tensorflow_probability.substrates.jax.distributions as tfd
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from typing import Any, Optional, Tuple, Union
Expand All @@ -14,7 +14,7 @@
from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample
from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions
from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
from dynamax.parameters import ParameterProperties, ParameterSet
from dynamax.parameters import ParameterProperties
from dynamax.types import PRNGKey, Scalar
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW
Expand Down
1 change: 0 additions & 1 deletion dynamax/linear_gaussian_ssm/models_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
from datetime import datetime
import jax.random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM
Expand Down
1 change: 0 additions & 1 deletion dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import jax.numpy as jnp
import jax.random as jr

from dynamax.nonlinear_gaussian_ssm.inference_ukf import unscented_kalman_smoother, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm.sarkka_lib import ukf, uks
Expand Down
1 change: 0 additions & 1 deletion dynamax/nonlinear_gaussian_ssm/sarkka_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
"""

import jax.numpy as jnp
import jax.random as jr
from jax import vmap
from jax import lax
from jax import jacfwd
Expand Down
5 changes: 2 additions & 3 deletions dynamax/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from jax import lax
from jax.tree_util import tree_reduce, tree_map, register_pytree_node_class
import tensorflow_probability.substrates.jax.bijectors as tfb
from typing import Optional, Union
from typing import Optional
from typing_extensions import Protocol
from jaxtyping import Array, Float

from dynamax.types import PRNGKey, Scalar
from dynamax.types import Scalar

class ParameterSet(Protocol):
"""A :class:`NamedTuple` with parameters stored as :class:`jax.DeviceArray` in the leaf nodes.
Expand Down
1 change: 0 additions & 1 deletion dynamax/slds/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import jax.random as jr
from dynamax.slds import SLDS, DiscreteParamsSLDS, LGParamsSLDS, ParamsSLDS, rbpfilter, rbpfilter_optimal
from functools import partial
import matplotlib.pyplot as plt
import dynamax.slds.mixture_kalman_filter_demo as kflib
from functools import partial
from jax.scipy.special import logit
Expand Down
20 changes: 4 additions & 16 deletions dynamax/slds/models.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,15 @@
from fastprogress.fastprogress import progress_bar
from functools import partial
from jax import jit, lax
from jax import lax
import jax.numpy as jnp
import jax.random as jr
from jax.tree_util import tree_map
from jaxtyping import Array, Float, PyTree
from jaxtyping import Array, Float
import tensorflow_probability.substrates.jax.distributions as tfd
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from typing import Any, Optional, Tuple, Union
from typing_extensions import Protocol
from typing import Optional, Tuple

from dynamax.ssm import SSM
from dynamax.linear_gaussian_ssm.models import LinearGaussianSSM
from dynamax.linear_gaussian_ssm.inference import lgssm_filter, lgssm_smoother, lgssm_posterior_sample
from dynamax.slds.inference import ParamsSLDS
from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
from dynamax.parameters import ParameterProperties, ParameterSet
from dynamax.types import PRNGKey, Scalar
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW
from dynamax.utils.distributions import NormalInverseWishart as NIW
from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update
from dynamax.utils.utils import pytree_stack, psd_solve
from dynamax.types import PRNGKey

class SLDS(SSM):

Expand Down
2 changes: 1 addition & 1 deletion dynamax/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.random as jr
from jax import jit, lax, vmap
from jax.tree_util import tree_map
from jaxtyping import Float, Array, PyTree
from jaxtyping import Float, Array
import optax
from tensorflow_probability.substrates.jax import distributions as tfd
from typing import Optional, Union, Tuple, Any
Expand Down
3 changes: 1 addition & 2 deletions dynamax/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional, Union
from typing_extensions import Protocol
from typing import Union
from jaxtyping import Array, Float
import jax._src.random as prng

Expand Down
1 change: 0 additions & 1 deletion dynamax/utils/distributions_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import jax.numpy as jnp
import jax.random as jr
from jax.tree_util import tree_map
Expand Down
1 change: 0 additions & 1 deletion dynamax/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import jaxlib
from jaxtyping import Array, Int
from scipy.optimize import linear_sum_assignment
from typing import Optional
from jax.scipy.linalg import cho_factor, cho_solve

def has_tpu():
Expand Down

0 comments on commit 72bf8f5

Please sign in to comment.