Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tackle Typing and Linting Errors #379

Draft
wants to merge 32 commits into
base: main
Choose a base branch
from
Draft

Tackle Typing and Linting Errors #379

wants to merge 32 commits into from

Conversation

gileshd
Copy link
Collaborator

@gileshd gileshd commented Sep 12, 2024

Summary

This PR is a Work In Progress.

Fix the current errors with type hints and some common linting errors.

Details

Commits

A brief categorisation of the commits:

Address review comment "Incorrect typing"

  • Add ruff ignore F722 - 7c252c2
  • Prepend space in uni-dim jaxtyping hints - b7927fb

Linting Errors:

Fix Type Errors:

  • PRNGKeys:
    • Fix jr.PRNGKey type hints - 1ded8e9
      • There are lots of places where a PRNGKey is annotated with jr.PRNGKey which, as a function, isn't suitable for type hinting here.
      • Replace these instances with jaxtyping.Array as per suggestion here
    • Rename and change PRNGKey Type - c9109e8
      • We have an internal dynamax.types module that defines a type for a key - I update this definition as per suggestion linked above.
    • The current state is a bit inconsistent
      • Sometimes keys are typed as dynamax.types.PRNGKeyT (which is an alias for jaxtyping.Array) and sometimes directly as jaxtyping.Array.
        • Should we try to be consistent? (If so what should we use?)
  • Other:
    • Fix LinearGaussianSSM.sample type hint - a890d62
    • Fix type annotations in hmm parallel inference - a8e08d4

Enable runtime type checking:

  • Add runtime_checkable decorator to Protocols - c4efc96

Unrelated Changes:

Changes:
- Replace `jnp.ones` input with `jr.normal`.
- Reduce size of hidden state to 3.
- Remove unused `datetime` import and commented lines.

Fundamentally the problem is that the solve step can be unstable. This
does not resolve that but instead chooses a set up which is less
vulnerable to the instability.
- Rename to PRNGKeyT to differentiate from `jax.random` function
- Change type to Array
  - see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#type-annotations-for-prng-keys
@gileshd gileshd changed the title Tackle Typing Tackle Typing and Linting Errors Sep 12, 2024
@gileshd gileshd force-pushed the ghd/typing branch 3 times, most recently from c11f50f to 685028a Compare September 23, 2024 12:03
@gileshd gileshd force-pushed the ghd/typing branch 3 times, most recently from 5b4f00a to 55b17f9 Compare September 24, 2024 13:50
Major changes:
- Fix an error in `_compute_all_transition_probs` which caused the
  return array to be too short.
  - The `filtered_probs` were being truncated twice instead of once.
- Replace `jaxtyping.Int` with `dynamax.typing.IntScalar` or `int`
  - this reflects when integer scalar arrays are accepted
  - `jaxtyping.[Dtype]` cannot be used directly for type checking
    instead they must be used as part of an array.
- Fix the shape of `transition_matrix`:
  - if transition_matrix has a leading timestep axis it should be of
    length T-1 not of length T.
- Add annotation indicating that `transition_matrix` is an optional argument
- Raise ValueError when neither `transition_matrix` or `transition_fn`
  provided.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant