Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Tomás Capretto <[email protected]>
  • Loading branch information
julianlheureux and tomicapretto authored Aug 15, 2024
1 parent 2a73648 commit 2bff2a5
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions bambi/priors/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def scale_response(self):
def scale_intercept(self, term):
if term.prior.name != "Normal":
return
# Special case for logit links with bernoulli family
# Special case for logit links with bernoulli or binomial family
if (
isinstance(self.model.family, (Bernoulli, Binomial))
and self.model.family.link["p"].name == "logit"
Expand All @@ -82,7 +82,7 @@ def scale_common(self, term):

if term.data.ndim == 1:
mu = 0
# Special case for logit links with bernoulli family
# Special case for logit links with bernoulli or binomial family
if (
isinstance(self.model.family, (Bernoulli, Binomial))
and self.model.family.link["p"].name == "logit"
Expand All @@ -102,12 +102,14 @@ def scale_common(self, term):
# Single numerical term
else:
sigma = 1 / np.std(term.data, axis=0)
# If not, fall back to the regular case
else:
sigma = self.get_slope_sigma(term.data)
# It's a term that spans multiple columns of the design matrix
else:
mu = np.zeros(term.data.shape[1])
sigma = np.zeros(term.data.shape[1])
# Special case for logit links with bernoulli family
# Special case for logit links with bernoulli or binomial family
if (
isinstance(self.model.family, (Bernoulli, Binomial))
and self.model.family.link["p"].name == "logit"
Expand All @@ -121,6 +123,7 @@ def scale_common(self, term):
)
if all_categoric:
sigma[i] = 1
# It's the standard deviation of the marginal numerical variable (_not_ by group)
else:
sigma[i] = 1 / np.std(np.sum(term.data, axis=1))
# Single categorical term
Expand Down

0 comments on commit 2bff2a5

Please sign in to comment.