
Update ARProcess and InfectionInitializationProcess to handle batched input #423

Closed
sbidari opened this issue Aug 30, 2024 · 20 comments
Labels
bug Something isn't working clean up Good code that could be better

Comments

@sbidari
Collaborator

sbidari commented Aug 30, 2024

Currently, ARProcess and InfectionInitializationProcess (and possibly other classes) require scalar inputs, which prevents them from being used inside a numpyro.plate.

This emerged while trying to use numpyro.plate to compute site-level dynamics for each site in CDCgov/pyrenew-hew#7.

I am open to other ways of handling this, and it may be worth discussing the best approach. For reference, I tried using a for loop, but that required modifying the name arguments of the RandomVariables (each loop iteration needs distinct sample-site names). @dylanhmorris mentioned potentially using jax.lax.scan.
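One way the jax.lax.scan idea might look is sketched below. It is a minimal AR(1) recursion, not PyRenew's ARProcess API; the function name ar1_scan and its arguments are hypothetical. The point is that autoreg, noise_sd and the initial value can all carry a trailing subpopulation axis, so the same code runs for one site or for a whole plate of sites, with no per-site loop and therefore no per-site site names.

```python
import jax
import jax.numpy as jnp


def ar1_scan(init, autoreg, noise_sd, innovations):
    """Hypothetical batched AR(1) recursion (not the PyRenew API).

    x[t] = autoreg * x[t-1] + noise_sd * innovations[t]

    init, autoreg, noise_sd: scalars or arrays of shape (n_subpops,).
    innovations: standard-normal draws of shape (n_steps, n_subpops).
    Returns x[1..n_steps] (init is x[0]), shape (n_steps, n_subpops).
    """
    # The scan carry must keep a fixed shape and dtype, so broadcast init up front.
    batch_shape = innovations.shape[1:]
    carry0 = jnp.broadcast_to(jnp.asarray(init, dtype=innovations.dtype), batch_shape)

    def step(prev, eps):
        new = autoreg * prev + noise_sd * eps  # elementwise over subpopulations
        return new, new  # (next carry, value stacked into the output)

    _, trajectory = jax.lax.scan(step, carry0, innovations)
    return trajectory


# Example: 10 steps for 5 subpopulations with per-subpopulation parameters
eps = jax.random.normal(jax.random.PRNGKey(0), (10, 5))
traj = ar1_scan(0.0, jnp.full(5, 0.8), jnp.full(5, 0.3), eps)
print(traj.shape)  # (10, 5)
```

Inside a model, the innovations would come from a sample site under the plate; that wiring is left out of this sketch.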

@dylanhmorris
Collaborator

dylanhmorris commented Aug 30, 2024

Minimal reproducible example? Which inputs are causing plate to object?

Also, what happens if you instantiate the process RV outside the plate but call its sample() method within the plate?

@sbidari
Collaborator Author

sbidari commented Aug 30, 2024

Inside numpyro.plate, noise_sd is vectorized to the plate size (one value per element of the plate), but ARProcess expects noise_sd to be a scalar.

@sbidari
Collaborator Author

sbidari commented Aug 30, 2024

This reproduces the issue with ARProcess:

import numpyro
from pyrenew.process import ARProcess
from pyrenew.randomvariable import DistributionalVariable
from pyrenew.metaclass import Model
import numpyro.distributions as dist

class my_model(Model):
    def __init__(
            self, 
            autoreg_rt_site_rv,
            sigma_rt_rv,
            rtu_site_ar_proc,
            rtu_site_ar_init_rv,
            ):
        
        self.autoreg_rt_site_rv = autoreg_rt_site_rv
        self.sigma_rt_rv = sigma_rt_rv
        self.rtu_site_ar_proc = rtu_site_ar_proc
        self.rtu_site_ar_init_rv = rtu_site_ar_init_rv

    def validate(self):
        pass

    def sample(self, n_weeks):
        with numpyro.plate("subpop", 5):
            autoreg_rt_site = self.autoreg_rt_site_rv()[0].value
            sigma_rt = self.sigma_rt_rv()[0].value
            rtu_site_ar_init = self.rtu_site_ar_init_rv()[0].value

            rtu_site_ar_weekly = self.rtu_site_ar_proc(
                n=n_weeks,
                init_vals=rtu_site_ar_init,
                autoreg=autoreg_rt_site,
                noise_sd=sigma_rt,
            )
            return rtu_site_ar_weekly


autoreg_rt_site_rv = DistributionalVariable("autoreg_rt", dist.Beta(1,1))
sigma_rt_rv = DistributionalVariable("sigma_rt", dist.TruncatedNormal(1,1,low=0))
rtu_site_ar_init_rv = DistributionalVariable("rtu_site_ar_init",dist.Normal(0,1))
rtu_site_ar_proc = ARProcess(noise_rv_name="rtu_ar_proc")

model1 = my_model(autoreg_rt_site_rv,sigma_rt_rv,rtu_site_ar_proc,rtu_site_ar_init_rv)

with numpyro.handlers.seed(rng_seed=5):
    model1.sample(n_weeks=10)

One-line error summary:

ValueError: noise_sd must be a scalar. Got [0.40856445 0.3173288  1.530092   0.78180265 1.627873  ]
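For comparison, a check that accepts batched input could test broadcast compatibility instead of scalar-ness. This is a hypothetical helper, not the fix adopted in PyRenew; the name validate_noise_sd and the batch_shape argument are assumptions for illustration.

```python
import jax.numpy as jnp
import numpy as np


def validate_noise_sd(noise_sd, batch_shape):
    """Hypothetical relaxed check (not PyRenew's actual validation).

    Accepts a scalar or any array that broadcasts to batch_shape (for example,
    the plate size) and returns noise_sd broadcast to that shape.
    """
    noise_sd = jnp.asarray(noise_sd)
    batch_shape = tuple(batch_shape)
    try:
        combined = tuple(np.broadcast_shapes(noise_sd.shape, batch_shape))
    except ValueError as err:
        raise ValueError(
            f"noise_sd with shape {noise_sd.shape} does not broadcast to {batch_shape}"
        ) from err
    if combined != batch_shape:
        # Broadcasting succeeded, but the result has extra or larger dimensions
        raise ValueError(
            f"noise_sd with shape {noise_sd.shape} would broadcast to {combined}, "
            f"not {batch_shape}"
        )
    return jnp.broadcast_to(noise_sd, batch_shape)


print(validate_noise_sd(0.5, (5,)).shape)  # (5,)
```

A scalar and a plate-sized vector both pass, while a length-3 vector inside a size-5 plate is still rejected.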

@damonbayer
Collaborator

This is blocking CDCgov/pyrenew-hew#7

@dylanhmorris
Collaborator

dylanhmorris commented Sep 3, 2024

Suggest splitting this into sub-issues, one for each of the two RV classes.

@damonbayer
Collaborator

damonbayer commented Sep 3, 2024

Rewriting the initial exponential growth in terms of matrix operations (broadcasting rate against a column of timepoints) seems to do what we want. However, it still feels a bit too automatic. For instance, how does it know how to batch y into the correct subpopulation?

I get

UserWarning: Missing a plate statement for batch dimension -2 at site 'obs'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(rng_key, y=y_data)

but I'm not sure where to put this -2.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

n_subpops = 3
rate = 1 + jnp.pow(10.0, -(jnp.arange(n_subpops) + 1))
n_timepoints = 10
i0 = jnp.arange(n_subpops) + 1

y_data = i0 * jnp.exp(rate * jnp.expand_dims(jnp.arange(n_timepoints), 1))


def my_model(y):
    with numpyro.plate("subpop", n_subpops):
        rate = numpyro.sample("rate", dist.HalfNormal())
        i0 = numpyro.sample("i0", dist.HalfNormal())

        mean_infec = i0 * jnp.exp(
            rate * jnp.expand_dims(jnp.arange(n_timepoints), 1)
        )
        numpyro.sample("obs", dist.Poisson(mean_infec), obs=y)


# Posterior Sampling
nuts_kernel = NUTS(my_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, y=y_data)

# Check results
mcmc.print_summary()
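On the question of how y gets batched: nothing matches y to subpopulations by name; it is purely positional broadcasting. The sketch below (plain jax.numpy, with made-up rate and i0 values) shows that the subpopulation axis ends up last, which is the axis the single subpop plate claims (dim -1), while the time axis lands at dim -2, the dimension the warning says has no plate. So y has to be laid out as (n_timepoints, n_subpops) for the columns to line up.

```python
import jax.numpy as jnp

n_subpops, n_timepoints = 3, 10
rate = jnp.array([0.1, 0.2, 0.3])  # shape (3,): one value per subpopulation
i0 = jnp.array([1.0, 2.0, 3.0])    # shape (3,)
times = jnp.expand_dims(jnp.arange(n_timepoints), 1)  # shape (10, 1)

# (3,) * (10, 1) broadcasts to (10, 3): rows are timepoints, columns are subpopulations
mean_infec = i0 * jnp.exp(rate * times)
print(mean_infec.shape)  # (10, 3)
```

Column j depends only on rate[j] and i0[j], so y_data[:, j] must hold the series for subpopulation j.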

@AFg6K7h4fhy2
Collaborator

AFg6K7h4fhy2 commented Sep 3, 2024

Maybe a second numpyro.plate over timepoints is needed here to address the warning? (DB also remarked in DMs that inference is being done properly.)

@AFg6K7h4fhy2
Collaborator

For the warning, here is the output without with numpyro.plate("timepoint", n_timepoints): wrapping these lines:

mean_infec = i0 * jnp.exp(
    rate * jnp.expand_dims(jnp.arange(n_timepoints), 1)
)
numpyro.sample("obs", dist.Poisson(mean_infec), obs=y)

Output:

db_plate_eg.py:30: UserWarning: Missing a plate statement for batch dimension -2 at site 'obs'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
  mcmc.run(rng_key, y=y_data)
sample: 100%|████████████████████████████████████| 1500/1500 [00:01<00:00, 1042.99it/s, 63 steps of size 6.52e-02. acc. prob=0.89]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
     i0[0]      1.00      0.06      1.00      0.91      1.09    472.95      1.00
     i0[1]      1.99      0.11      1.99      1.80      2.16    394.08      1.00
     i0[2]      2.96      0.13      2.96      2.75      3.17    524.65      1.00
   rate[0]      1.10      0.01      1.10      1.09      1.11    466.99      1.00
   rate[1]      1.01      0.01      1.01      1.00      1.02    401.20      1.00
   rate[2]      1.00      0.01      1.00      0.99      1.01    528.82      1.00

Number of divergences: 0

And with those lines wrapped in with numpyro.plate("timepoint", n_timepoints):

sample: 100%|████████████████████████████████████| 1500/1500 [00:01<00:00, 1127.89it/s, 63 steps of size 6.52e-02. acc. prob=0.89]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
     i0[0]      1.00      0.06      1.00      0.91      1.09    472.95      1.00
     i0[1]      1.99      0.11      1.99      1.80      2.16    394.08      1.00
     i0[2]      2.96      0.13      2.96      2.75      3.17    524.65      1.00
   rate[0]      1.10      0.01      1.10      1.09      1.11    466.99      1.00
   rate[1]      1.01      0.01      1.01      1.00      1.02    401.20      1.00
   rate[2]      1.00      0.01      1.00      0.99      1.01    528.82      1.00

Number of divergences: 0

@AFg6K7h4fhy2
Collaborator

Modified version, with no warning:

def my_model(y):
    with numpyro.plate("subpop", n_subpops):
        rate = numpyro.sample("rate", dist.HalfNormal())
        i0 = numpyro.sample("i0", dist.HalfNormal())
        with numpyro.plate("timepoint", n_timepoints):
            mean_infec = i0 * jnp.exp(
                rate * jnp.expand_dims(jnp.arange(n_timepoints), 1)
            )
            numpyro.sample("obs", dist.Poisson(mean_infec), obs=y)

@damonbayer
Collaborator

Thanks @AFg6K7h4fhy2. I'm glad we have something working without a warning, but this all feels a bit funny to me. In particular, it seems like we should be able to come up with a solution that doesn't require the time series to be of equal length.

@AFg6K7h4fhy2
Collaborator

Agree. I'm not going to investigate the dim= option right now, but hopefully there is something there that better supports what we need.

@damonbayer
Collaborator

damonbayer commented Sep 3, 2024

This works with time series of arbitrary, unequal lengths. It doesn't feel very modular, though.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS


n_subpops = 4
rates = 1 + jnp.pow(10.0, -(jnp.arange(n_subpops) + 1))
n_timepoints = jnp.arange(n_subpops) + 10
i0s = jnp.arange(n_subpops) + 1

y_data = jnp.concatenate(
    [
        i0 * jnp.exp(rate * jnp.arange(n_t))
        for rate, i0, n_t in zip(rates, i0s, n_timepoints)
    ]
)
y_ind = jnp.repeat(jnp.arange(n_subpops), n_timepoints)
y_time = jnp.concatenate(
    [jnp.arange(n_timepoint) for n_timepoint in n_timepoints]
)


def my_model(y_data, y_ind, y_time):
    with numpyro.plate("subpop", n_subpops):
        rate = numpyro.sample("rate", dist.HalfNormal())
        i0 = numpyro.sample("i0", dist.HalfNormal())
    mean_infec = i0[y_ind] * jnp.exp(rate[y_ind] * y_time)
    numpyro.sample("obs", dist.Poisson(mean_infec), obs=y_data)


# Posterior Sampling
nuts_kernel = NUTS(my_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, y_data=y_data, y_ind=y_ind, y_time=y_time)

# Check results
mcmc.print_summary()
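The core trick here is the gather i0[y_ind], rate[y_ind]: each observation row looks up its own subpopulation's parameters, so series of different lengths never need to share a rectangular array. A toy illustration with two series of lengths 3 and 2 (values made up):

```python
import jax.numpy as jnp

rate = jnp.array([0.1, 0.2])  # one rate per subpopulation
i0 = jnp.array([1.0, 2.0])    # one initial value per subpopulation
lengths = [3, 2]              # unequal series lengths

# Long format: one entry per observation
y_ind = jnp.repeat(jnp.arange(len(lengths)), jnp.array(lengths))  # [0 0 0 1 1]
y_time = jnp.concatenate([jnp.arange(n) for n in lengths])        # [0 1 2 0 1]

# Gather per-observation parameters, then evaluate the mean for each row
mean_infec = i0[y_ind] * jnp.exp(rate[y_ind] * y_time)
print(mean_infec.shape)  # (5,)
```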

@damonbayer
Collaborator

damonbayer commented Sep 4, 2024

After discussing with @dylanhmorris, we have agreed on

  1. Models should accept tidy data. (Ex: time, location, count columns)
  2. Internal replication of model components should be achieved with numpyro.plates (Ex: with numpyro.plate("location", n_locations)).
  3. Latent computations should be done internal to plates and/or using broadcasting when possible. This may result in cases where we compute "unused" variables (Ex. location 1 has observations at 20 timepoints and location 2 has observations at 10 timepoints. We will compute the full latent time series for both locations).
  4. Observation samples should occur outside of plates using indexing based on the tidy data.

I am working on a revised version of the above model to demonstrate these recommendations.
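Recommendations 1, 3 and 4 fit together roughly as follows: tidy columns are mapped to integer indices once, the latent quantity is computed on the full (time, location) grid, and each observation picks its cell by indexing. A minimal NumPy sketch with made-up data; np.unique(..., return_inverse=True) is just one way to build the location index, and the latent array is a stand-in for whatever the plates compute.

```python
import numpy as np

# Tidy input: one row per observation; location "a" has no row at time 1 (implicitly missing)
location = np.array(["b", "a", "b", "c", "a"])
time = np.array([0, 0, 1, 0, 2])

# Map location labels to integer codes (levels are sorted: a, b, c)
levels, loc_ind = np.unique(location, return_inverse=True)
n_locations = len(levels)
n_time = time.max() + 1

# Stand-in for a latent quantity computed inside plates on the full grid (recommendation 3)
latent = np.arange(n_time * n_locations).reshape(n_time, n_locations)

# Observation-level values gathered by tidy indices, outside any plate (recommendation 4)
obs_mean = latent[time, loc_ind]
print(obs_mean)  # [1 0 4 2 6]
```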

@damonbayer
Collaborator

damonbayer commented Sep 4, 2024

An updated example based on the above recommendations is below. It's a bit unclear what this implies for the rest of PyRenew; I think we may need to add some atleast_1d calls throughout.

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
import polars as pl
import string

n_groups = 4
rates = 1 + jnp.pow(10.0, -(jnp.arange(n_groups) + 1))
n_timepoints = jnp.arange(n_groups) + 10
i0s = jnp.arange(n_groups) + 1


input_data = pl.DataFrame(
    {
        "group": pl.Series(np.array(list(string.ascii_lowercase))[
            np.repeat(np.arange(n_groups), n_timepoints)
        ], dtype = pl.Categorical),
        "time": np.concatenate(
            [np.arange(n_timepoint) for n_timepoint in n_timepoints]
        ),
        "obs": np.concatenate(
            [
                i0 * np.exp(rate * np.arange(n_t))
                for rate, i0, n_t in zip(rates, i0s, n_timepoints)
            ]
        ),
    }
).filter(~((pl.col("group") == "a") & (pl.col("time") == 4)))
# drop group "a" at time 4 to create some implicitly missing data

y_group = input_data["group"].to_numpy()
y_time = input_data["time"].to_numpy()
y_obs = input_data["obs"].to_numpy()


# This would be done at model instantiation:
y_group_ind = input_data["group"].to_physical().to_numpy()
y_time_max = input_data["time"].max()

def my_model(y_group, y_time, y_obs):
    with numpyro.plate("group", n_groups):
        rate = numpyro.sample("rate", dist.HalfNormal())
        i0 = numpyro.sample("i0", dist.HalfNormal())
        mean_infec = i0 * jnp.exp(rate * jnp.arange(y_time_max+1)[:, jnp.newaxis])
        
    numpyro.sample("obs", dist.Poisson(mean_infec[y_time, y_group]), obs=y_obs)

# Posterior Sampling
nuts_kernel = NUTS(my_model, find_heuristic_step_size = True)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, y_group = y_group_ind, y_time = y_time, y_obs = y_obs)

# Check results
mcmc.print_summary()
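As a sketch of the atleast_1d idea: promoting inputs to at least one dimension lets the same code path handle a single group and a batch of groups. The function below is hypothetical (it is not PyRenew's initialization API) and uses the same exponential-growth form as the examples above.

```python
import jax.numpy as jnp


def exp_growth_init(i0, rate, n_timepoints):
    """Hypothetical exponential-growth initialization (not the PyRenew API).

    i0, rate: scalars or arrays of shape (n_groups,).
    Returns initial infections of shape (n_timepoints, n_groups); a scalar
    input is treated as a single group.
    """
    i0 = jnp.atleast_1d(i0)      # scalar -> shape (1,)
    rate = jnp.atleast_1d(rate)  # scalar -> shape (1,)
    times = jnp.arange(n_timepoints)[:, jnp.newaxis]  # shape (n_timepoints, 1)
    return i0 * jnp.exp(rate * times)


print(exp_growth_init(10.0, 0.1, 5).shape)  # (5, 1)
print(exp_growth_init(jnp.array([10.0, 20.0, 30.0]), 0.1, 5).shape)  # (5, 3)
```

The trade-off is that scalar callers get a trailing length-1 axis back, which downstream code then has to expect.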

@damonbayer damonbayer added bug Something isn't working clean up Good code that could be better labels Sep 6, 2024
@dylanhmorris
Collaborator

dylanhmorris commented Sep 6, 2024

Was the infection initialization process part of this issue handled by #432, @damonbayer? Or do other initialization schemes, such as InitializeInfectionsZeroPad, also need reworking?

@damonbayer
Collaborator

It's a non-trivial fix. I would like to first consider removing the InfectionInitializationMethod class and the InitializeInfectionsFromVec and InitializeInfectionsZeroPad functions.

@dylanhmorris
Copy link
Collaborator

I think that's reasonable. InitializeInfectionsFromVec is the only functionality I can see us using, but it's not clear that it needs a helper class.

@damonbayer
Copy link
Collaborator

damonbayer commented Sep 6, 2024

I think the initialize functions should just be RandomVariables. Their shared interface was not very elegant and somewhat confusing. InitializeInfectionsFromVec can just be a DeterministicVariable.

@damonbayer
Copy link
Collaborator

Partially closed by #432

@damonbayer damonbayer modified the milestones: 🦖 Rajasaurus, S Sprint Sep 16, 2024
@damonbayer
Collaborator

Fully closed by #439

4 participants