From 72bf8f5af9610d1aa54d3378e6799ac7f2ad9e17 Mon Sep 17 00:00:00 2001
From: gileshd
Date: Thu, 12 Sep 2024 15:02:02 +0100
Subject: [PATCH] Remove unused imports

---
 dynamax/generalized_gaussian_ssm/inference.py |  1 -
 dynamax/hidden_markov_model/inference.py      |  2 +-
 dynamax/hidden_markov_model/inference_test.py |  1 -
 .../models/abstractions.py                    |  2 +-
 .../models/categorical_glm_hmm.py             |  1 -
 .../hidden_markov_model/parallel_inference.py |  3 +--
 .../demos/kf_linreg_jax_vs_pt.py              | 15 +--------------
 dynamax/linear_gaussian_ssm/models.py         |  4 ++--
 dynamax/linear_gaussian_ssm/models_test.py    |  1 -
 .../inference_ukf_test.py                     |  1 -
 dynamax/nonlinear_gaussian_ssm/sarkka_lib.py  |  1 -
 dynamax/parameters.py                         |  5 ++---
 dynamax/slds/inference_test.py                |  1 -
 dynamax/slds/models.py                        | 20 ++++----------------
 dynamax/ssm.py                                |  2 +-
 dynamax/types.py                              |  3 +--
 dynamax/utils/distributions_test.py           |  1 -
 dynamax/utils/utils.py                        |  1 -
 18 files changed, 14 insertions(+), 51 deletions(-)

diff --git a/dynamax/generalized_gaussian_ssm/inference.py b/dynamax/generalized_gaussian_ssm/inference.py
index fa9af8fa..d86d54ec 100644
--- a/dynamax/generalized_gaussian_ssm/inference.py
+++ b/dynamax/generalized_gaussian_ssm/inference.py
@@ -3,7 +3,6 @@
 from jax import jacfwd, vmap, lax
 import jax.numpy as jnp
 from jax import lax
-from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
 from jaxtyping import Array, Float
 from typing import NamedTuple, Optional, Union, Callable
diff --git a/dynamax/hidden_markov_model/inference.py b/dynamax/hidden_markov_model/inference.py
index aa4657d6..9b6f945f 100644
--- a/dynamax/hidden_markov_model/inference.py
+++ b/dynamax/hidden_markov_model/inference.py
@@ -8,7 +8,7 @@
 from typing import Callable, Optional, Tuple, Union, NamedTuple
 from jaxtyping import Int, Float, Array
-from dynamax.types import Scalar, PRNGKey
+from dynamax.types import Scalar
 _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
diff --git a/dynamax/hidden_markov_model/inference_test.py b/dynamax/hidden_markov_model/inference_test.py
index 9a4babfe..6f814279 100644
--- a/dynamax/hidden_markov_model/inference_test.py
+++ b/dynamax/hidden_markov_model/inference_test.py
@@ -1,4 +1,3 @@
-import pytest
 import itertools as it
 import jax.numpy as jnp
 import jax.random as jr
diff --git a/dynamax/hidden_markov_model/models/abstractions.py b/dynamax/hidden_markov_model/models/abstractions.py
index 958cfe25..f36cbe7e 100644
--- a/dynamax/hidden_markov_model/models/abstractions.py
+++ b/dynamax/hidden_markov_model/models/abstractions.py
@@ -3,7 +3,7 @@
 from dynamax.types import Scalar
 from dynamax.parameters import to_unconstrained, from_unconstrained
 from dynamax.parameters import ParameterSet, PropertySet
-from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered
+from dynamax.hidden_markov_model.inference import HMMPosterior
 from dynamax.hidden_markov_model.inference import hmm_filter
 from dynamax.hidden_markov_model.inference import hmm_posterior_mode
 from dynamax.hidden_markov_model.inference import hmm_smoother
diff --git a/dynamax/hidden_markov_model/models/categorical_glm_hmm.py b/dynamax/hidden_markov_model/models/categorical_glm_hmm.py
index 56692641..af49cd74 100644
--- a/dynamax/hidden_markov_model/models/categorical_glm_hmm.py
+++ b/dynamax/hidden_markov_model/models/categorical_glm_hmm.py
@@ -1,4 +1,3 @@
-import jax.numpy as jnp
 import jax.random as jr
 import tensorflow_probability.substrates.jax.distributions as tfd
 from jaxtyping import Float, Array
diff --git a/dynamax/hidden_markov_model/parallel_inference.py b/dynamax/hidden_markov_model/parallel_inference.py
index 36461299..5db1ab70 100644
--- a/dynamax/hidden_markov_model/parallel_inference.py
+++ b/dynamax/hidden_markov_model/parallel_inference.py
@@ -2,8 +2,7 @@
 import jax.random as jr
 from jax import lax, vmap, value_and_grad
 from jaxtyping import Array, Float, Int
-from typing import NamedTuple, Union
-from functools import partial
+from typing import NamedTuple
 from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered
diff --git a/dynamax/linear_gaussian_ssm/demos/kf_linreg_jax_vs_pt.py b/dynamax/linear_gaussian_ssm/demos/kf_linreg_jax_vs_pt.py
index 85565928..88dfcc11 100644
--- a/dynamax/linear_gaussian_ssm/demos/kf_linreg_jax_vs_pt.py
+++ b/dynamax/linear_gaussian_ssm/demos/kf_linreg_jax_vs_pt.py
@@ -1,29 +1,16 @@
 import numpy as np
-from jaxtyping import Float, Array
-from typing import Callable, NamedTuple, Union, Tuple, Any
-from functools import partial
-import chex
-import optax
 import jax
 import jax.numpy as jnp
 import jax.random as jr
-from jax import lax, jacfwd, vmap, grad, jit
-from jax.tree_util import tree_map, tree_reduce
-from jax.flatten_util import ravel_pytree
+from jax import lax
 import jax.numpy as jnp
 import jax.random as jr
 from jax import lax
 import time
 import platform
-import matplotlib.pyplot as plt
-import matplotlib.cm as cm
-from dataclasses import dataclass
-from itertools import cycle
-import tensorflow as tf
-import tensorflow_probability as tfp
 from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
 import torch
diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py
index f7c61330..7b05c788 100644
--- a/dynamax/linear_gaussian_ssm/models.py
+++ b/dynamax/linear_gaussian_ssm/models.py
@@ -4,7 +4,7 @@
 import jax.numpy as jnp
 import jax.random as jr
 from jax.tree_util import tree_map
-from jaxtyping import Array, Float, PyTree
+from jaxtyping import Array, Float
 import tensorflow_probability.substrates.jax.distributions as tfd
 from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
 from typing import Any, Optional, Tuple, Union
@@ -14,7 +14,7 @@
 from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample
 from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions
 from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
-from dynamax.parameters import ParameterProperties, ParameterSet
+from dynamax.parameters import ParameterProperties
 from dynamax.types import PRNGKey, Scalar
 from dynamax.utils.bijectors import RealToPSDBijector
 from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW
diff --git a/dynamax/linear_gaussian_ssm/models_test.py b/dynamax/linear_gaussian_ssm/models_test.py
index c4394858..fefe8b9e 100644
--- a/dynamax/linear_gaussian_ssm/models_test.py
+++ b/dynamax/linear_gaussian_ssm/models_test.py
@@ -1,5 +1,4 @@
 import pytest
-from datetime import datetime
 import jax.random as jr
 from dynamax.linear_gaussian_ssm import LinearGaussianSSM
 from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM
diff --git a/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py b/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py
index 6900a08d..bb8f9a11 100644
--- a/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py
+++ b/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py
@@ -1,5 +1,4 @@
 import jax.numpy as jnp
-import jax.random as jr
 from dynamax.nonlinear_gaussian_ssm.inference_ukf import unscented_kalman_smoother, UKFHyperParams
 from dynamax.nonlinear_gaussian_ssm.sarkka_lib import ukf, uks
diff --git a/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py b/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py
index ff65405d..875bf5f9 100644
--- a/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py
+++ b/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py
@@ -6,7 +6,6 @@
 """
 import jax.numpy as jnp
-import jax.random as jr
 from jax import vmap
 from jax import lax
 from jax import jacfwd
diff --git a/dynamax/parameters.py b/dynamax/parameters.py
index c829c586..b0dfde9e 100644
--- a/dynamax/parameters.py
+++ b/dynamax/parameters.py
@@ -2,11 +2,10 @@
 from jax import lax
 from jax.tree_util import tree_reduce, tree_map, register_pytree_node_class
 import tensorflow_probability.substrates.jax.bijectors as tfb
-from typing import Optional, Union
+from typing import Optional
 from typing_extensions import Protocol
-from jaxtyping import Array, Float
-from dynamax.types import PRNGKey, Scalar
+from dynamax.types import Scalar
 class ParameterSet(Protocol):
     """A :class:`NamedTuple` with parameters stored as :class:`jax.DeviceArray` in the leaf nodes.
diff --git a/dynamax/slds/inference_test.py b/dynamax/slds/inference_test.py
index 6986492e..5e592fe2 100644
--- a/dynamax/slds/inference_test.py
+++ b/dynamax/slds/inference_test.py
@@ -2,7 +2,6 @@
 import jax.random as jr
 from dynamax.slds import SLDS, DiscreteParamsSLDS, LGParamsSLDS, ParamsSLDS, rbpfilter, rbpfilter_optimal
 from functools import partial
-import matplotlib.pyplot as plt
 import dynamax.slds.mixture_kalman_filter_demo as kflib
 from functools import partial
 from jax.scipy.special import logit
diff --git a/dynamax/slds/models.py b/dynamax/slds/models.py
index cbbd7170..67858fa5 100644
--- a/dynamax/slds/models.py
+++ b/dynamax/slds/models.py
@@ -1,27 +1,15 @@
-from fastprogress.fastprogress import progress_bar
-from functools import partial
-from jax import jit, lax
+from jax import lax
 import jax.numpy as jnp
 import jax.random as jr
 from jax.tree_util import tree_map
-from jaxtyping import Array, Float, PyTree
+from jaxtyping import Array, Float
 import tensorflow_probability.substrates.jax.distributions as tfd
 from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
-from typing import Any, Optional, Tuple, Union
-from typing_extensions import Protocol
+from typing import Optional, Tuple
 from dynamax.ssm import SSM
-from dynamax.linear_gaussian_ssm.models import LinearGaussianSSM
-from dynamax.linear_gaussian_ssm.inference import lgssm_filter, lgssm_smoother, lgssm_posterior_sample
 from dynamax.slds.inference import ParamsSLDS
-from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
-from dynamax.parameters import ParameterProperties, ParameterSet
-from dynamax.types import PRNGKey, Scalar
-from dynamax.utils.bijectors import RealToPSDBijector
-from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW
-from dynamax.utils.distributions import NormalInverseWishart as NIW
-from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update
-from dynamax.utils.utils import pytree_stack, psd_solve
+from dynamax.types import PRNGKey
 class SLDS(SSM):
diff --git a/dynamax/ssm.py b/dynamax/ssm.py
index 72107a6d..c5a41a6a 100644
--- a/dynamax/ssm.py
+++ b/dynamax/ssm.py
@@ -6,7 +6,7 @@
 import jax.random as jr
 from jax import jit, lax, vmap
 from jax.tree_util import tree_map
-from jaxtyping import Float, Array, PyTree
+from jaxtyping import Float, Array
 import optax
 from tensorflow_probability.substrates.jax import distributions as tfd
 from typing import Optional, Union, Tuple, Any
diff --git a/dynamax/types.py b/dynamax/types.py
index 3fff53f3..4847471f 100644
--- a/dynamax/types.py
+++ b/dynamax/types.py
@@ -1,5 +1,4 @@
-from typing import Optional, Union
-from typing_extensions import Protocol
+from typing import Union
 from jaxtyping import Array, Float
 import jax._src.random as prng
diff --git a/dynamax/utils/distributions_test.py b/dynamax/utils/distributions_test.py
index 88355e19..fc2f5d37 100644
--- a/dynamax/utils/distributions_test.py
+++ b/dynamax/utils/distributions_test.py
@@ -1,4 +1,3 @@
-import pytest
 import jax.numpy as jnp
 import jax.random as jr
 from jax.tree_util import tree_map
diff --git a/dynamax/utils/utils.py b/dynamax/utils/utils.py
index d694365e..52135fda 100644
--- a/dynamax/utils/utils.py
+++ b/dynamax/utils/utils.py
@@ -8,7 +8,6 @@
 import jaxlib
 from jaxtyping import Array, Int
 from scipy.optimize import linear_sum_assignment
-from typing import Optional
 from jax.scipy.linalg import cho_factor, cho_solve
 def has_tpu():
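
Verification note: every name deleted above should no longer be referenced anywhere in the file it is removed from. The sketch below is a minimal, approximate unused-import checker built only on the standard-library ast module; the script name and its behaviour are illustrative assumptions for this note, not part of this patch or of the dynamax repository. In practice a linter such as `ruff check --select F401` or `flake8 --select=F401` is the more robust way to do the same check, since the sketch misses dynamic uses (names re-exported via __all__, names referenced only inside strings, and so on).

# check_unused_imports.py -- hypothetical helper, not part of the dynamax repo.
# Flags import bindings that are never referenced as a bare name in the module.
import ast
import sys

def unused_imports(path):
    """Return 'path:line: name' strings for import bindings that are never used."""
    with open(path) as f:
        tree = ast.parse(f.read(), filename=path)

    bound = {}  # bound name -> line number of the import that created it
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # `import a.b.c` binds `a`; `import a.b.c as x` binds `x`.
                bound[alias.asname or alias.name.split(".")[0]] = node.lineno
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.name != "*":
                    bound[alias.asname or alias.name] = node.lineno

    # Any ast.Name node counts as a use, e.g. the `jnp` in `jnp.zeros(3)`.
    used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
    return [f"{path}:{line}: {name}"
            for name, line in sorted(bound.items(), key=lambda kv: kv[1])
            if name not in used]

if __name__ == "__main__":
    for path in sys.argv[1:]:
        for report in unused_imports(path):
            print(report)

Example invocation (paths as in this patch): `python check_unused_imports.py dynamax/slds/models.py dynamax/types.py`. After the patch is applied, none of the names removed above can appear in its output, because they are no longer imported.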