Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scan logprob fails when unvalued stochastic outputs are returned #6909

Open
ricardoV94 opened this issue Sep 13, 2023 · 0 comments
Open

Scan logprob fails when unvalued stochastic outputs are returned #6909

ricardoV94 opened this issue Sep 13, 2023 · 0 comments

Comments

@ricardoV94
Copy link
Member

Description

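Found by @lucianopaz

In the example below, `ar_dist1` and `ar_dist2` build the same autoregressive scan; the only difference is that the scan in `ar_dist2` also returns the innovation `eps_t`, which is never given a value. Both are observed with the same data, so their logps are expected to be equal, but the final assertion fails.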

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

from pymc.pytensorf import collect_default_updates

# Number of scan iterations; also used as the length of the observed data below.
steps = 4

def ar_dist1(rho, sigma, size):
    """Build an autoregressive sequence with a scan that returns only the state.

    Each step draws Normal noise with the given ``sigma`` and adds it to the
    previous state multiplied by ``rho``. The recursion starts from a scalar
    zero and runs for ``steps`` iterations.

    Parameters
    ----------
    rho
        Multiplier applied to the previous state at every step.
    sigma
        Standard deviation passed to the per-step Normal noise.
    size
        Accepted as part of the signature but not used in the body.

    Returns
    -------
    The scan output holding the state at every step.
    """

    def ar_step(prev_state, coef, scale):
        noise = pm.Normal.dist(sigma=scale)
        new_state = prev_state * coef + noise
        # Hand the RNG updates of the noise draw back to scan.
        return new_state, collect_default_updates([new_state])

    # Single recurrent output, seeded with a scalar zero and reading one step back.
    state_spec = {"initial": pt.zeros(()), "taps": [-1]}

    scan_result = pytensor.scan(
        fn=ar_step,
        outputs_info=[state_spec],
        non_sequences=[rho, sigma],
        n_steps=steps,
        strict=True,
    )
    states, _updates = scan_result

    return states


def ar_dist2(rho, sigma, size):
    """Build the same autoregressive sequence, but let scan also return the noise.

    The step function is the same recursion as in ``ar_dist1`` (previous state
    times ``rho`` plus Normal noise with the given ``sigma``), except that the
    noise draw is returned as a second, non-recurrent scan output. Only the
    state sequence is returned to the caller; the noise output is discarded.

    Parameters
    ----------
    rho
        Multiplier applied to the previous state at every step.
    sigma
        Standard deviation passed to the per-step Normal noise.
    size
        Accepted as part of the signature but not used in the body.

    Returns
    -------
    The first scan output, holding the state at every step.
    """

    def ar_step(prev_state, coef, scale):
        noise = pm.Normal.dist(sigma=scale)
        new_state = prev_state * coef + noise
        # The noise is exposed as an extra output next to the state.
        return [new_state, noise], collect_default_updates([new_state])

    # The state is recurrent (one step back, seeded with zero); the noise
    # output has no recurrence, hence the ``None`` entry.
    state_spec = {"initial": pt.zeros(()), "taps": [-1]}

    scan_result = pytensor.scan(
        fn=ar_step,
        outputs_info=[state_spec, None],
        non_sequences=[rho, sigma],
        n_steps=steps,
        strict=True,
    )
    (states, _noise), _updates = scan_result

    return states


# Register both scan graphs as observed CustomDists with identical parameters
# and identical observed data, so that their logps can be compared directly.
with pm.Model() as m:
    rho = 0.1
    sigma = 0.1
    # One observation per scan step: 0, 1, ..., steps - 1.
    observed = np.arange(steps)

    # Variant whose scan returns only the state sequence.
    pm.CustomDist(
        "ar_dist1",
        rho,
        sigma,
        dist=ar_dist1,
        observed=observed,
    )

    # Variant whose scan additionally returns the per-step noise draw.
    pm.CustomDist(
        "ar_dist2",
        rho,
        sigma,
        dist=ar_dist2,
        observed=observed,
    )

# Compile the unsummed logp and evaluate it at an empty point; the result is
# unpacked as one entry per CustomDist.
logp1, logp2 = m.compile_logp(sum=False)({})
# Expected to pass, since both variables describe the same process; the
# failure output pasted below is what this issue reports.
np.testing.assert_allclose(logp1, logp2)
"""
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0
Mismatched elements: 2 / 4 (50%)
Max absolute difference: 58.
Max relative difference: 0.12928641
 x: array([   1.383647,  -48.616353, -179.116353, -390.616353])
 y: array([   1.383647,  -48.616353, -198.616353, -448.616353])
"""
@ricardoV94 ricardoV94 changed the title Wrong derived logprob for scan when unused deterministic of rv is returned Wrong scan logprob when unused deterministic of rv is returned Sep 13, 2023
@ricardoV94 ricardoV94 changed the title Wrong scan logprob when unused deterministic of rv is returned Wrong scan logprob when unused base rv is returned Sep 6, 2024
@ricardoV94 ricardoV94 changed the title Wrong scan logprob when unused base rv is returned Scan logprob fails when unused base rv is returned Sep 6, 2024
@ricardoV94 ricardoV94 changed the title Scan logprob fails when unused base rv is returned Scan logprob fails when unvalued stochastic outputs are returned Sep 7, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

1 participant