Skip to content

Commit

Permalink
chore: moments and entropy were extra and removed
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalbash committed Jul 12, 2024
1 parent 1192ca7 commit 0e4d35f
Showing 1 changed file with 0 additions and 58 deletions.
58 changes: 0 additions & 58 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2986,31 +2986,6 @@ def sample(self, key, sample_shape=()):
samples = self.icdf(u)
return samples

def _kth_moment(self, k):
Z = jnp.exp(self._logZ)

def index_eq_neg1():
return (jnp.log(self.high) - jnp.log(self.low)) / Z

def index_neq_neg1():
power_index = k + self.alpha + 1.0
return (
jnp.power(self.high, power_index) - jnp.power(self.low, power_index)
) / ((power_index) * Z)

return jnp.where(
jnp.equal(self.alpha + k, -1.0), index_eq_neg1(), index_neq_neg1()
)

@lazy_property
def mean(self):
return self._kth_moment(1)

@lazy_property
def variance(self):
moment2 = self._kth_moment(2)
return moment2 - jnp.square(self.mean)


class LowerTruncatedPowerLaw(Distribution):
r"""Lower truncated power law distribution with :math:`\alpha` index.
Expand Down Expand Up @@ -3086,36 +3061,3 @@ def sample(self, key, sample_shape=()):
u = random.uniform(key, sample_shape + self.batch_shape)
samples = self.icdf(u)
return samples

def _kth_moment(self, k):
return jnp.where(
jnp.less(k, self.alpha - 1),
(self.alpha - 1) / (self.alpha - 1 - k) * jnp.power(self.low, k),
jnp.inf,
)

@lazy_property
def mean(self):
return self._kth_moment(1)

@lazy_property
def variance(self):
return self._kth_moment(2) - jnp.square(self.mean)

def entropy(self):
# The simplified expression for the entorpy is,
# H(x) = (alpha-1)log(a) - log(a-1)
# + alpha(alpha-1)a^(alpha-1)\int_{a}^{\infty}x^{-alpha}log(x)dx
# The integral term can be reshaped into a lower incomplete gamma function.
# After simplification, we get the following expression.
# H(x) = (alpha-1)log(a) - log(a-1)
# + (alpha/(alpha-1))a^(alpha-1)(1-gamma(2, (alpha-1)log(a)))
# I followed the definition of lower incomplete gamma function from wikipedia.
# https://en.wikipedia.org/wiki/Incomplete_gamma_function#Definition
# In the very same wikipedia article there was a recursive formula,
# https://en.wikipedia.org/wiki/Incomplete_gamma_function#Properties
# which I used to simplify the expression. It gave me the following expression.
# H(x) = log(a) - log(alpha-1) + (alpha/(alpha-1))(a^(alpha-1) + 1)
Hx = jnp.log(self.low) - jnp.log(self.alpha - 1)
Hx += self.alpha / (self.alpha - 1) * (jnp.power(self.low, self.alpha - 1) + 1)
return Hx

0 comments on commit 0e4d35f

Please sign in to comment.