Reimplement several RandomVariables as SymbolicRandomVariables
This allows sampling from multiple backends without having to dispatch for each one
ricardoV94 committed Apr 10, 2024
1 parent e0a82bc commit 2e6c2c4
Showing 3 changed files with 130 additions and 119 deletions.
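For context, a minimal sketch of what this enables (not part of the diff): because draws are now expressed as a PyTensor graph rather than a NumPy `rng_fn`, the same variable can be compiled to other backends. The `mode="JAX"` argument assumes JAX is installed and that `pm.draw` forwards compile kwargs:

```python
import pymc as pm

x = pm.Kumaraswamy.dist(a=2.0, b=3.0)

pm.draw(x, draws=5)              # default C/NumPy backend
pm.draw(x, draws=5, mode="JAX")  # same symbolic graph, compiled with JAX
```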
158 changes: 90 additions & 68 deletions pymc/distributions/continuous.py
@@ -52,10 +52,12 @@
     vonmises,
 )
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.variable import TensorConstant

 from pymc.logprob.abstract import _logprob_helper
 from pymc.logprob.basic import icdf
+from pymc.pytensorf import normalize_rng_param

 try:
     from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma
@@ -73,7 +75,6 @@ def polyagamma_cdf(*args, **kwargs):

 from scipy import stats
 from scipy.interpolate import InterpolatedUnivariateSpline
-from scipy.special import expit

 from pymc.distributions import transforms
 from pymc.distributions.dist_math import (
@@ -90,8 +91,8 @@ def polyagamma_cdf(*args, **kwargs):
     normal_lcdf,
     zvalue,
 )
-from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous
-from pymc.distributions.shape_utils import rv_size_is_none
+from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable
+from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
 from pymc.distributions.transforms import _default_transform
 from pymc.math import invlogit, logdiffexp, logit
@@ -1236,20 +1237,28 @@ def icdf(value, alpha, beta):
         )


-class KumaraswamyRV(RandomVariable):
+class KumaraswamyRV(SymbolicRandomVariable):
     name = "kumaraswamy"
-    ndim_supp = 0
-    ndims_params = [0, 0]
-    dtype = "floatX"
+    signature = "[rng],[size],(),()->[rng],()"
     _print_name = ("Kumaraswamy", "\\operatorname{Kumaraswamy}")

     @classmethod
-    def rng_fn(cls, rng, a, b, size) -> np.ndarray:
-        u = rng.uniform(size=size)
-        return np.asarray((1 - (1 - u) ** (1 / b)) ** (1 / a))
+    def rv_op(cls, a, b, *, size=None, rng=None):
+        a = pt.as_tensor(a)
+        b = pt.as_tensor(b)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)

+        if rv_size_is_none(size):
+            size = implicit_size_from_params(a, b, ndims_params=cls.ndims_params)

-kumaraswamy = KumaraswamyRV()
+        next_rng, u = uniform(size=size, rng=rng).owner.outputs
+        draws = (1 - (1 - u) ** (1 / b)) ** (1 / a)
+
+        return cls(
+            inputs=[rng, size, a, b],
+            outputs=[next_rng, draws],
+        )(rng, size, a, b)


 class Kumaraswamy(UnitContinuous):
@@ -1296,13 +1305,11 @@ class Kumaraswamy(UnitContinuous):
     b > 0.
     """

-    rv_op = kumaraswamy
+    rv_type = KumaraswamyRV
+    rv_op = KumaraswamyRV.rv_op

     @classmethod
     def dist(cls, a: DIST_PARAMETER_TYPES, b: DIST_PARAMETER_TYPES, *args, **kwargs):
-        a = pt.as_tensor_variable(a)
-        b = pt.as_tensor_variable(b)
-
         return super().dist([a, b], *args, **kwargs)

     def support_point(rv, size, a, b):
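The inverse-CDF transform in `rv_op` above can be sanity-checked in plain NumPy against the Kumaraswamy mean, E[X] = b·B(1 + 1/a, b). A standalone sketch with arbitrary parameter values (not part of the diff):

```python
import numpy as np
from scipy.special import beta as beta_fn

rng = np.random.default_rng(42)
a, b = 2.0, 3.0
u = rng.uniform(size=100_000)
draws = (1 - (1 - u) ** (1 / b)) ** (1 / a)  # same transform as rv_op
assert np.isclose(draws.mean(), b * beta_fn(1 + 1 / a, b), atol=1e-2)
```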
@@ -1533,24 +1540,32 @@ def icdf(value, mu, b):
         return check_icdf_parameters(res, b > 0, msg="b > 0")


-class AsymmetricLaplaceRV(RandomVariable):
+class AsymmetricLaplaceRV(SymbolicRandomVariable):
     name = "asymmetriclaplace"
-    ndim_supp = 0
-    ndims_params = [0, 0, 0]
-    dtype = "floatX"
+    signature = "[rng],[size],(),(),()->[rng],()"
     _print_name = ("AsymmetricLaplace", "\\operatorname{AsymmetricLaplace}")

     @classmethod
-    def rng_fn(cls, rng, b, kappa, mu, size=None) -> np.ndarray:
-        u = rng.uniform(size=size)
+    def rv_op(cls, b, kappa, mu, *, size=None, rng=None):
+        b = pt.as_tensor(b)
+        kappa = pt.as_tensor(kappa)
+        mu = pt.as_tensor(mu)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
+
+        if rv_size_is_none(size):
+            size = implicit_size_from_params(b, kappa, mu, ndims_params=cls.ndims_params)
+
+        next_rng, u = uniform(size=size, rng=rng).owner.outputs
         switch = kappa**2 / (1 + kappa**2)
-        non_positive_x = mu + kappa * np.log(u * (1 / switch)) / b
-        positive_x = mu - np.log((1 - u) * (1 + kappa**2)) / (kappa * b)
+        non_positive_x = mu + kappa * pt.log(u * (1 / switch)) / b
+        positive_x = mu - pt.log((1 - u) * (1 + kappa**2)) / (kappa * b)
         draws = non_positive_x * (u <= switch) + positive_x * (u > switch)
-        return np.asarray(draws)
-
-
-asymmetriclaplace = AsymmetricLaplaceRV()
+
+        return cls(
+            inputs=[rng, size, b, kappa, mu],
+            outputs=[next_rng, draws],
+        )(rng, size, b, kappa, mu)


 class AsymmetricLaplace(Continuous):
@@ -1599,15 +1614,12 @@ class AsymmetricLaplace(Continuous):
     of interest.
     """

-    rv_op = asymmetriclaplace
+    rv_type = AsymmetricLaplaceRV
+    rv_op = AsymmetricLaplaceRV.rv_op

     @classmethod
     def dist(cls, kappa=None, mu=None, b=None, q=None, *args, **kwargs):
         kappa = cls.get_kappa(kappa, q)
-        b = pt.as_tensor_variable(b)
-        kappa = pt.as_tensor_variable(kappa)
-        mu = pt.as_tensor_variable(mu)
-
         return super().dist([b, kappa, mu], *args, **kwargs)

     @classmethod
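In the piecewise transform above, `switch = kappa**2 / (1 + kappa**2)` is exactly P(X <= mu), so the two branches split the uniform draws at the distribution's mass below the location. A standalone NumPy sketch with arbitrary values (not part of the diff):

```python
import numpy as np

rng = np.random.default_rng(0)
b, kappa, mu = 1.5, 2.0, 0.0
u = rng.uniform(size=200_000)
switch = kappa**2 / (1 + kappa**2)  # P(X <= mu) = 0.8 for kappa = 2
non_positive_x = mu + kappa * np.log(u / switch) / b
positive_x = mu - np.log((1 - u) * (1 + kappa**2)) / (kappa * b)
draws = np.where(u <= switch, non_positive_x, positive_x)
assert np.isclose(np.mean(draws <= mu), switch, atol=1e-2)
```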
@@ -2475,7 +2487,6 @@ def dist(cls, nu, **kwargs):
         return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs)


-# TODO: Remove this once logp for multiplication is working!
 class WeibullBetaRV(RandomVariable):
     name = "weibull"
     ndim_supp = 0
@@ -2597,19 +2608,22 @@ def icdf(value, alpha, beta):
         )


-class HalfStudentTRV(RandomVariable):
+class HalfStudentTRV(SymbolicRandomVariable):
     name = "halfstudentt"
-    ndim_supp = 0
-    ndims_params = [0, 0]
-    dtype = "floatX"
+    signature = "[rng],[size],(),()->[rng],()"
     _print_name = ("HalfStudentT", "\\operatorname{HalfStudentT}")

     @classmethod
-    def rng_fn(cls, rng, nu, sigma, size=None) -> np.ndarray:
-        return np.asarray(np.abs(stats.t.rvs(nu, scale=sigma, size=size, random_state=rng)))
+    def rv_op(cls, nu, sigma, *, size=None, rng=None):
+        nu = pt.as_tensor(nu)
+        sigma = pt.as_tensor(sigma)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)

+        next_rng, t_draws = t(df=nu, scale=sigma, size=size, rng=rng).owner.outputs
+        draws = pt.abs(t_draws)

-halfstudentt = HalfStudentTRV()
+        return cls(inputs=[rng, size, nu, sigma], outputs=[next_rng, draws])(rng, size, nu, sigma)


 class HalfStudentT(PositiveContinuous):
@@ -2671,14 +2685,12 @@ class HalfStudentT(PositiveContinuous):
         x = pm.HalfStudentT('x', lam=4, nu=10)
     """

-    rv_op = halfstudentt
+    rv_type = HalfStudentTRV
+    rv_op = HalfStudentTRV.rv_op

     @classmethod
     def dist(cls, nu, sigma=None, lam=None, *args, **kwargs):
-        nu = pt.as_tensor_variable(nu)
         lam, sigma = get_tau_sigma(lam, sigma)
-        sigma = pt.as_tensor_variable(sigma)
-
         return super().dist([nu, sigma], *args, **kwargs)

     def support_point(rv, size, nu, sigma):
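Since the draws are just absolute values of Student-t draws, a NumPy version can be checked against the folded-t mean, E[X] = 2σ√ν Γ((ν+1)/2) / (√π (ν−1) Γ(ν/2)) for ν > 1. A standalone sketch (not part of the diff; the closed-form mean is a textbook result, stated here from memory):

```python
import numpy as np
from scipy import special, stats

rng = np.random.default_rng(1)
nu, sigma = 10.0, 2.0
draws = np.abs(stats.t.rvs(nu, scale=sigma, size=200_000, random_state=rng))
expected = (
    2 * sigma * np.sqrt(nu) * special.gamma((nu + 1) / 2)
    / (np.sqrt(np.pi) * (nu - 1) * special.gamma(nu / 2))
)
assert np.isclose(draws.mean(), expected, atol=0.05)
```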
@@ -2710,19 +2722,29 @@ def logp(value, nu, sigma):
         )


-class ExGaussianRV(RandomVariable):
+class ExGaussianRV(SymbolicRandomVariable):
     name = "exgaussian"
-    ndim_supp = 0
-    ndims_params = [0, 0, 0]
-    dtype = "floatX"
+    signature = "[rng],[size],(),(),()->[rng],()"
     _print_name = ("ExGaussian", "\\operatorname{ExGaussian}")

     @classmethod
-    def rng_fn(cls, rng, mu, sigma, nu, size=None) -> np.ndarray:
-        return np.asarray(rng.normal(mu, sigma, size=size) + rng.exponential(scale=nu, size=size))
+    def rv_op(cls, mu, sigma, nu, *, size=None, rng=None):
+        mu = pt.as_tensor(mu)
+        sigma = pt.as_tensor(sigma)
+        nu = pt.as_tensor(nu)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)

+        if rv_size_is_none(size):
+            size = implicit_size_from_params(mu, sigma, nu, ndims_params=cls.ndims_params)

-exgaussian = ExGaussianRV()
+        next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs
+        final_rng, exponential_draws = exponential(scale=nu, size=size, rng=next_rng).owner.outputs
+        draws = normal_draws + exponential_draws
+
+        return cls(inputs=[rng, size, mu, sigma, nu], outputs=[final_rng, draws])(
+            rng, size, mu, sigma, nu
+        )


 class ExGaussian(Continuous):
@@ -2792,14 +2814,11 @@ class ExGaussian(Continuous):
     Vol. 4, No. 1, pp 35-45.
     """

-    rv_op = exgaussian
+    rv_type = ExGaussianRV
+    rv_op = ExGaussianRV.rv_op

     @classmethod
     def dist(cls, mu=0.0, sigma=None, nu=None, *args, **kwargs):
-        mu = pt.as_tensor_variable(mu)
-        sigma = pt.as_tensor_variable(sigma)
-        nu = pt.as_tensor_variable(nu)
-
         return super().dist([mu, sigma, nu], *args, **kwargs)

     def support_point(rv, size, mu, sigma, nu):
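The normal-plus-exponential construction makes the first two moments easy to check: E[X] = μ + ν and Var[X] = σ² + ν². A standalone NumPy sketch with arbitrary values (not part of the diff):

```python
import numpy as np

rng = np.random.default_rng(2)
mu, sigma, nu = 1.0, 0.5, 2.0
draws = rng.normal(mu, sigma, size=200_000) + rng.exponential(scale=nu, size=200_000)
assert np.isclose(draws.mean(), mu + nu, atol=0.05)         # E[X] = mu + nu
assert np.isclose(draws.var(), sigma**2 + nu**2, atol=0.1)  # Var[X] = sigma**2 + nu**2
```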
@@ -3477,19 +3496,25 @@ def icdf(value, mu, s):
         )


-class LogitNormalRV(RandomVariable):
+class LogitNormalRV(SymbolicRandomVariable):
     name = "logit_normal"
-    ndim_supp = 0
-    ndims_params = [0, 0]
-    dtype = "floatX"
+    signature = "[rng],[size],(),()->[rng],()"
     _print_name = ("logitNormal", "\\operatorname{logitNormal}")

     @classmethod
-    def rng_fn(cls, rng, mu, sigma, size=None) -> np.ndarray:
-        return np.asarray(expit(stats.norm.rvs(loc=mu, scale=sigma, size=size, random_state=rng)))
+    def rv_op(cls, mu, sigma, *, size=None, rng=None):
+        mu = pt.as_tensor(mu)
+        sigma = pt.as_tensor(sigma)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)

+        next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs
+        draws = pt.expit(normal_draws)

-logit_normal = LogitNormalRV()
+        return cls(
+            inputs=[rng, size, mu, sigma],
+            outputs=[next_rng, draws],
+        )(rng, size, mu, sigma)


 class LogitNormal(UnitContinuous):
@@ -3540,15 +3565,12 @@ class LogitNormal(UnitContinuous):
         Defaults to 1.
     """

-    rv_op = logit_normal
+    rv_type = LogitNormalRV
+    rv_op = LogitNormalRV.rv_op

     @classmethod
     def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
-        mu = pt.as_tensor_variable(mu)
-        tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
-        sigma = pt.as_tensor_variable(sigma)
-        tau = pt.as_tensor_variable(tau)
-
+        _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
         return super().dist([mu, sigma], **kwargs)

     def support_point(rv, size, mu, sigma):
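Applying `logit` to the draws should recover the underlying normal, which gives a quick check of the transform. A standalone NumPy sketch (not part of the diff):

```python
import numpy as np
from scipy.special import expit, logit

rng = np.random.default_rng(3)
mu, sigma = 0.3, 0.8
draws = expit(rng.normal(mu, sigma, size=200_000))
z = logit(draws)  # inverts the expit, recovering Normal(mu, sigma)
assert np.isclose(z.mean(), mu, atol=0.01) and np.isclose(z.std(), sigma, atol=0.01)
```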
35 changes: 21 additions & 14 deletions pymc/distributions/discrete.py
@@ -18,7 +18,6 @@

 from pytensor.tensor import TensorConstant
 from pytensor.tensor.random.basic import (
-    RandomVariable,
     ScipyRandomVariable,
     bernoulli,
     betabinom,
@@ -28,7 +27,9 @@
     hypergeometric,
     nbinom,
     poisson,
+    uniform,
 )
+from pytensor.tensor.random.utils import normalize_size_param
 from scipy import stats

 import pymc as pm
@@ -45,8 +46,8 @@
     normal_lccdf,
     normal_lcdf,
 )
-from pymc.distributions.distribution import Discrete
-from pymc.distributions.shape_utils import rv_size_is_none
+from pymc.distributions.distribution import Discrete, SymbolicRandomVariable
+from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
 from pymc.logprob.basic import logcdf, logp
 from pymc.math import sigmoid

@@ -65,6 +66,8 @@
     "OrderedProbit",
 ]

+from pymc.pytensorf import normalize_rng_param
+

 class Binomial(Discrete):
     R"""
@@ -387,20 +390,25 @@ def logcdf(value, p):
         )


-class DiscreteWeibullRV(RandomVariable):
+class DiscreteWeibullRV(SymbolicRandomVariable):
     name = "discrete_weibull"
-    ndim_supp = 0
-    ndims_params = [0, 0]
-    dtype = "int64"
+    signature = "[rng],[size],(),()->[rng],()"
     _print_name = ("dWeibull", "\\operatorname{dWeibull}")

     @classmethod
-    def rng_fn(cls, rng, q, beta, size):
-        p = rng.uniform(size=size)
-        return np.ceil(np.power(np.log(1 - p) / np.log(q), 1.0 / beta)) - 1
+    def rv_op(cls, q, beta, *, size=None, rng=None):
+        q = pt.as_tensor(q)
+        beta = pt.as_tensor(beta)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
+
+        if rv_size_is_none(size):
+            size = implicit_size_from_params(q, beta, ndims_params=cls.ndims_params)
+
+        next_rng, p = uniform(size=size, rng=rng).owner.outputs
+        draws = pt.ceil(pt.power(pt.log(1 - p) / pt.log(q), 1.0 / beta)) - 1

-discrete_weibull = DiscreteWeibullRV()
+        return cls(inputs=[rng, size, q, beta], outputs=[next_rng, draws])(rng, size, q, beta)


 class DiscreteWeibull(Discrete):
@@ -452,12 +460,11 @@ def DiscreteWeibull(q, b, x):
     """

-    rv_op = discrete_weibull
+    rv_type = DiscreteWeibullRV
+    rv_op = DiscreteWeibullRV.rv_op

     @classmethod
     def dist(cls, q, beta, *args, **kwargs):
-        q = pt.as_tensor_variable(q)
-        beta = pt.as_tensor_variable(beta)
        return super().dist([q, beta], **kwargs)

     def support_point(rv, size, q, beta):
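The ceil-of-inverse-CDF transform above can be checked against the discrete Weibull CDF, F(x) = 1 − q^((x+1)^β). A standalone NumPy sketch with arbitrary values (not part of the diff):

```python
import numpy as np

rng = np.random.default_rng(4)
q, beta = 0.9, 1.5
u = rng.uniform(size=200_000)
draws = np.ceil(np.power(np.log(1 - u) / np.log(q), 1.0 / beta)) - 1
for x in range(3):
    # empirical CDF vs F(x) = 1 - q ** ((x + 1) ** beta)
    assert np.isclose(np.mean(draws <= x), 1 - q ** ((x + 1) ** beta), atol=1e-2)
```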
