Skip to content

Commit

Permalink
Update docstring and args (#823)
Browse files Browse the repository at this point in the history
  • Loading branch information
GStechschulte authored Jul 8, 2024
1 parent 6d19a33 commit 3e80fb0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
6 changes: 5 additions & 1 deletion bambi/interpret/effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,11 @@ def predictions(

if pps:
idata = model.predict(
idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False, kind="pps"
idata,
data=cap_data,
sample_new_groups=sample_new_groups,
inplace=False,
kind="response",
)
y_hat = response_transform(idata["posterior_predictive"][response.name])
y_hat_mean = y_hat.mean(("chain", "draw"))
Expand Down
20 changes: 11 additions & 9 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,20 +812,22 @@ def predict(
idata : InferenceData
The ``InferenceData`` instance returned by ``.fit()``.
kind : str
Indicates the type of prediction required. Can be ``"mean"`` or ``"pps"``. The
first returns draws from the posterior distribution of the mean, while the latter
returns the draws from the posterior predictive distribution (i.e. the posterior
probability distribution for a new observation) in addition to the mean posterior
distribution. Defaults to ``"mean"``.
Indicates the type of prediction required. Can be ``"response_params"`` or
``"response"``. The first returns draws from the posterior distribution of the
likelihood parameters, while the latter returns the draws from the posterior
predictive distribution (i.e. the posterior probability distribution for a new
observation) in addition to the posterior distribution. Defaults to
``"response_params"``.
data : pandas.DataFrame or None
An optional data frame with values for the predictors that are used to obtain
out-of-sample predictions. If omitted, the original dataset is used.
inplace : bool
If ``True`` it will modify ``idata`` in-place. Otherwise, it will return a copy of
``idata`` with the predictions added. If ``kind="mean"``, a new variable ending in
``"_mean"`` is added to the ``posterior`` group. If ``kind="pps"``, it appends a
``posterior_predictive`` group to ``idata``. If any of these already exist, it will be
overwritten.
``idata`` with the predictions added. If ``kind="response_params"``, a new variable
with the name of the parent parameter, e.g. ``"mu"`` and ``"sigma" for a Gaussian
likelihood, or ``"p"`` for a Bernoulli likelihood, is added to the ``posterior`` group.
If ``kind="response"``, it appends a ``posterior_predictive`` group to ``idata``. If
any of these already exist, it will be overwritten.
include_group_specific : bool
Determines if predictions incorporate group-specific effects. If ``False``, predictions
are made with common effects only (i.e. group specific are set to zero). Defaults to
Expand Down

0 comments on commit 3e80fb0

Please sign in to comment.