From 0e4d35f8a0c3a27b51be80d88606d1a45cbdbedc Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Fri, 12 Jul 2024 23:02:42 +0500 Subject: [PATCH] chore: moments and entropy were extra and removed --- numpyro/distributions/continuous.py | 58 ----------------------------- 1 file changed, 58 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index b2139d70e..04d79a466 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2986,31 +2986,6 @@ def sample(self, key, sample_shape=()): samples = self.icdf(u) return samples - def _kth_moment(self, k): - Z = jnp.exp(self._logZ) - - def index_eq_neg1(): - return (jnp.log(self.high) - jnp.log(self.low)) / Z - - def index_neq_neg1(): - power_index = k + self.alpha + 1.0 - return ( - jnp.power(self.high, power_index) - jnp.power(self.low, power_index) - ) / ((power_index) * Z) - - return jnp.where( - jnp.equal(self.alpha + k, -1.0), index_eq_neg1(), index_neq_neg1() - ) - - @lazy_property - def mean(self): - return self._kth_moment(1) - - @lazy_property - def variance(self): - moment2 = self._kth_moment(2) - return moment2 - jnp.square(self.mean) - class LowerTruncatedPowerLaw(Distribution): r"""Lower truncated power law distribution with :math:`\alpha` index. @@ -3086,36 +3061,3 @@ def sample(self, key, sample_shape=()): u = random.uniform(key, sample_shape + self.batch_shape) samples = self.icdf(u) return samples - - def _kth_moment(self, k): - return jnp.where( - jnp.less(k, self.alpha - 1), - (self.alpha - 1) / (self.alpha - 1 - k) * jnp.power(self.low, k), - jnp.inf, - ) - - @lazy_property - def mean(self): - return self._kth_moment(1) - - @lazy_property - def variance(self): - return self._kth_moment(2) - jnp.square(self.mean) - - def entropy(self): - # The simplified expression for the entorpy is, - # H(x) = (alpha-1)log(a) - log(a-1) - # + alpha(alpha-1)a^(alpha-1)\int_{a}^{\infty}x^{-alpha}log(x)dx - # The integral term can be reshaped into a lower incomplete gamma function. - # After simplification, we get the following expression. - # H(x) = (alpha-1)log(a) - log(a-1) - # + (alpha/(alpha-1))a^(alpha-1)(1-gamma(2, (alpha-1)log(a))) - # I followed the definition of lower incomplete gamma function from wikipedia. - # https://en.wikipedia.org/wiki/Incomplete_gamma_function#Definition - # In the very same wikipedia article there was a recursive formula, - # https://en.wikipedia.org/wiki/Incomplete_gamma_function#Properties - # which I used to simplify the expression. It gave me the following expression. - # H(x) = log(a) - log(alpha-1) + (alpha/(alpha-1))(a^(alpha-1) + 1) - Hx = jnp.log(self.low) - jnp.log(self.alpha - 1) - Hx += self.alpha / (self.alpha - 1) * (jnp.power(self.low, self.alpha - 1) + 1) - return Hx