improved OOP with dataclasses, error handling, and added unit-level contrasts 'average_by=True'
GStechschulte committed Jun 28, 2023
1 parent 3ccb833 commit 40ecc50
Showing 4 changed files with 118 additions and 86 deletions.
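
The headline change is easiest to see from the caller's side. The sketch below is a hypothetical usage example of the refactored comparisons() in bambi/plots/effects.py: the keyword arguments (contrast, conditional, average_by, comparison_type) all appear in the diff on this page, while the dataset, formula, fit settings, and import path are illustrative assumptions rather than part of the commit.

# Hypothetical usage sketch; not code from this commit.
import bambi as bmb
import numpy as np
import pandas as pd

from bambi.plots.effects import comparisons  # assumed import path

rng = np.random.default_rng(0)
data = pd.DataFrame(
    {
        "y": rng.normal(size=200),
        "x": rng.normal(size=200),
        "group": rng.choice(["a", "b"], size=200),
    }
)

model = bmb.Model("y ~ x + group", data)
idata = model.fit(draws=500, chains=2)

# Unit-level contrasts averaged over the observed data: no 'conditional' grid is
# passed, so the empirical data is used, and average_by=True (per the commit
# message) averages the per-observation estimates into an overall summary.
summary_df = comparisons(
    model,
    idata,
    contrast="group",        # a single contrast predictor
    conditional=None,        # unit-level (empirical) data path
    average_by=True,
    comparison_type="diff",  # or "ratio"
)
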
64 changes: 34 additions & 30 deletions bambi/plots/create_data.py
@@ -7,14 +7,13 @@
from bambi.models import Model
from bambi.utils import clean_formula_lhs
from bambi.plots.utils import (
ComparisonInfo,
ConditionalInfo,
ContrastInfo,
enforce_dtypes,
get_covariates,
get_model_covariates,
make_group_panel_values,
make_main_values,
set_default_contrast_values,
set_default_values,
)

@@ -51,17 +50,21 @@ def create_cap_data(model: Model, covariates: dict) -> pd.DataFrame:
return enforce_dtypes(data, pd.DataFrame(data_dict))


def create_comparisons_data(comparisons: ComparisonInfo, user_passed: bool = False):
def create_comparisons_data(
condition: ConditionalInfo,
contrast: ContrastInfo,
user_passed: bool = False
) -> pd.DataFrame:
"""Create data for a Conditional Adjusted Comparisons
Parameters
----------
model : bambi.Model
An instance of a Bambi model
comparisons : ComparisonInfo
The name of the predictor to be used in the comparisons.
A dataclass instance containing the model, contrast, and conditional
covariates to be used in the comparisons.
user_passed : bool, optional
Whether the user passed data to the model. Defaults to False.
Whether the user passed their own 'conditional' data to determine the
conditional data. Defaults to False.
Returns
-------
@@ -70,24 +73,22 @@ def create_comparisons_data(comparisons: ComparisonInfo, user_passed: bool = Fal
plotting.
"""

def grid_level(comparisons: ComparisonInfo, contrast: ContrastInfo):
def _grid_level(condition: ConditionalInfo, contrast: ContrastInfo):
"""
Creates the data for grid-level contrasts by using the covariates passed
into the `conditional` arg. Values for the grid are either: (1) computed
using an equally spaced grid, mean, and/or mode (depending on the covariate
dtype), and (2) a user specified value or range of values.
"""
covariates = get_covariates(comparisons.conditional)
model_covariates = clean_formula_lhs(str(comparisons.model.formula.main)).strip()
model_covariates = model_covariates.split(" ")
covariates = get_covariates(condition.conditional)

# if user passed data, then only need to compute default values for
# unspecified covariates in the model
if user_passed:
data_dict = {**comparisons.conditional}
data_dict = {**condition.conditional}
else:
# if user did not pass data, then compute default values for the
# covariates specified in the `conditional` arg.
main_values = make_main_values(comparisons.model.data[covariates.main])
main_values = make_main_values(condition.model.data[covariates.main])
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
comparisons.model.data,
condition.model.data,
data_dict,
covariates.main,
covariates.group,
@@ -96,36 +97,39 @@ def grid_level(comparisons: ComparisonInfo, contrast: ContrastInfo):
)

data_dict[contrast.name] = contrast.values
comparison_data = set_default_values(comparisons.model, data_dict, kind="comparison")
comparison_data = set_default_values(condition.model, data_dict, kind="comparison")
# use cartesian product (cross join) to create contrasts
keys, values = zip(*comparison_data.items())
contrast_dict = [dict(zip(keys, v)) for v in itertools.product(*values)]

return enforce_dtypes(comparisons.model.data, pd.DataFrame(contrast_dict))
return enforce_dtypes(condition.model.data, pd.DataFrame(contrast_dict))


def unit_level(comparisons: ComparisonInfo, contrast: ContrastInfo):
def _unit_level(comparisons: ConditionalInfo, contrast: ContrastInfo):
"""
Creates the data for unit-level contrasts by using the observed (empirical)
data. All covariates in the model are included in the data, except for the
contrast predictor. The contrast predictor is replaced with either: (1) the
default contrast value, or (2) the user specified contrast value.
"""
covariates = get_model_covariates(comparisons.model)
df = comparisons.model.data[covariates].drop(labels=contrast.name, axis=1)
covariates = get_model_covariates(contrast.model)
df = contrast.model.data[covariates].drop(labels=contrast.name, axis=1)

contrast_vals = np.array(contrast.values)[..., None]
contrast_vals = np.repeat(contrast_vals, comparisons.model.data.shape[0], axis=1)
contrast_vals = np.repeat(contrast_vals, contrast.model.data.shape[0], axis=1)

contrast_df_dict = {}
for idx, value in enumerate(contrast_vals):
contrast_df_dict[f"contrast_{idx}"] = df.copy()
contrast_df_dict[f"contrast_{idx}"][contrast.name] = value

return pd.concat(contrast_df_dict.values())

return pd.concat(contrast_df_dict.values())

contrast = ContrastInfo(comparisons.contrast_predictor, comparisons.model)

if not comparisons.conditional:
df = unit_level(comparisons, contrast)
if not condition.conditional:
df = _unit_level(condition, contrast)
else:
df = grid_level(comparisons, contrast)
df = _grid_level(condition, contrast)

return df
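
The refactor above replaces the old single ComparisonInfo argument with two dataclasses. Below is a minimal sketch of the new call pattern, inferred from this diff: ContrastInfo(model, contrast) and ConditionalInfo(model, conditional) mirror the constructor calls in effects.py further down the page, while the model, covariate names, and values are assumptions for illustration.

# Sketch of the new call pattern; not code from this commit.
import bambi as bmb
import numpy as np
import pandas as pd

from bambi.plots.create_data import create_comparisons_data
from bambi.plots.utils import ConditionalInfo, ContrastInfo

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "y": rng.normal(size=100),
        "x": rng.normal(size=100),
        "group": rng.choice(["a", "b"], size=100),
    }
)
model = bmb.Model("y ~ x + group", df)

contrast_info = ContrastInfo(model, "group")                      # read via .name / .values / .model
condition_info = ConditionalInfo(model, {"x": [-1.0, 0.0, 1.0]})  # read via .conditional / .model

# Explicit values for the conditional covariate were supplied, so user_passed=True
# and _grid_level() only fills in defaults for the remaining model covariates.
comparisons_df = create_comparisons_data(condition_info, contrast_info, user_passed=True)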

89 changes: 41 additions & 48 deletions bambi/plots/effects.py
@@ -12,7 +12,7 @@

from bambi.models import Model
from bambi.plots.create_data import create_cap_data, create_comparisons_data
from bambi.plots.utils import average_over, ComparisonInfo, identity
from bambi.plots.utils import average_over, ConditionalInfo, ContrastInfo, identity
from bambi.utils import get_aliased_name, listify


@@ -121,12 +121,6 @@ def predictions(
return cap_data


@dataclass
class Contrast:
term: Union[str, list]
value: Union[list, np.ndarray, None]


@dataclass
class Response:
name: str
@@ -190,47 +184,45 @@ def comparisons(
When ``prob`` is not > 0 and < 1.
"""

if isinstance(contrast, (dict, list)):
if not isinstance(contrast, (dict, list, str)):
raise ValueError("'contrast' must be a string, dictionary, or list.")
elif isinstance(contrast, (dict, list)):
if len(contrast) > 1:
raise ValueError(
f"Only one contrast predictor can be passed. {len(contrast)} were passed."
)
if isinstance(contrast, dict):
contrast_name = next(iter(contrast.keys()))
else:
contrast_name = contrast[0]
elif isinstance(contrast, str):
contrast_name = contrast
else:
raise ValueError("'contrast' must be a string, dictionary, or list.")

if comparison_type not in ("diff", "ratio"):
raise ValueError("'comparison_type' must be 'diff' or 'ratio'")

comparison_functions = {"diff": lambda x, y: x - y, "ratio": lambda x, y: x / y}

if prob is None:
prob = az.rcParams["stats.hdi_prob"]
if not 0 < prob < 1:
raise ValueError(f"'prob' must be greater than 0 and smaller than 1. It is {prob}.")

comparison_functions = {"diff": lambda x, y: x - y, "ratio": lambda x, y: x / y}
lower_bound = round((1 - prob) / 2, 4)
upper_bound = 1 - lower_bound

covariate_kinds = ("main", "group", "panel")
contrast_info = ContrastInfo(model, contrast)
covariate_kinds = ("main", "group", "panel") #todo: 'comparisons' should not be restricted to these...
# if not dict, then user:
# 1.) passed a covariate, but not values to condition on, or
# 2.) passed None
if not isinstance(conditional, dict):
conditional = listify(conditional)
conditional = dict(zip(covariate_kinds, conditional))
comparisons_df = create_comparisons_data(
Comparison(model, contrast, conditional), user_passed=False
ConditionalInfo(model, conditional),
contrast_info,
user_passed=False
)
# if dict, user passed values to condition on
elif isinstance(conditional, dict):
comparisons_df = create_comparisons_data(
Comparison(model, contrast, conditional), user_passed=True
ConditionalInfo(model, conditional),
contrast_info,
user_passed=True
)
conditional = {k: listify(v) for k, v in conditional.items()}
conditional = dict(zip(covariate_kinds, conditional))
@@ -239,12 +231,13 @@ def comparisons(
transforms = {}

response_name = get_aliased_name(model.response_component.response_term)
response = Response(response_name)

# perform predictions on new data
idata = model.predict(idata, data=comparisons_df, inplace=False)

def _compute_contrast_estimate(
contrast: Contrast,
contrast: ContrastInfo,
response: Response,
comparisons_df: pd.DataFrame,
idata: az.InferenceData,
@@ -257,8 +250,8 @@ def _compute_contrast_estimate(

# subset draw by observation using contrast mask
draws = {}
for idx, val in enumerate(contrast.value):
mask = np.array(comparisons_df[contrast.term] == contrast.value[idx])
for idx, val in enumerate(contrast.values):
mask = np.array(comparisons_df[contrast.name] == contrast.values[idx])
select_draw = idata.posterior[f"{response.name}_{response.target}"].sel(
{f"{response.name}_obs": mask}
)
@@ -268,7 +261,7 @@ def _compute_contrast_estimate(
draws[val] = select_draw

# iterate over pairwise combinations of contrast values
pairwise_contrasts = list(itertools.combinations(contrast.value, 2))
pairwise_contrasts = list(itertools.combinations(contrast.values, 2))

# compute mean comparison and HDI for each pairwise comparison
comparison_mean = {}
@@ -287,7 +280,7 @@ def _compute_contrast_estimate(
return ContrastEstimate(comparison_mean, comparison_bounds)

def _build_contrasts_df(
contrast: Contrast,
contrast: ContrastInfo,
response: Response,
comparisons_df: pd.DataFrame,
idata: az.InferenceData,
@@ -301,57 +294,57 @@ def _build_contrasts_df(
contrast_estimate = _compute_contrast_estimate(contrast, response, comparisons_df, idata)

# if two contrast values, then can drop duplicates to build contrast_df
if len(contrast.value) < 3:
if len(contrast.values) < 3:
if not any(conditional.values()):
contrast_df = model.data[comparisons_df.columns].drop(columns=contrast.term)
contrast_df = model.data[comparisons_df.columns].drop(columns=contrast.name)
num_rows = contrast_df.shape[0]
contrast_df.insert(0, "term", contrast.term)
contrast_df.insert(0, "term", contrast.name)
contrast_df.insert(
1, "contrast", list(np.tile(contrast.value, num_rows).reshape(num_rows, 2))
1, "contrast", list(np.tile(contrast.values, num_rows).reshape(num_rows, 2))
)
contrast_df["estimate"] = contrast_estimate.comparison[tuple(contrast.value)].to_numpy()
contrast_df["estimate"] = contrast_estimate.comparison[tuple(contrast.values)].to_numpy()
else:
contrast_df = comparisons_df.drop_duplicates(list(conditional.values())).reset_index(
drop=True
)
contrast_df = contrast_df.drop(columns=contrast.term)
contrast_df = contrast_df.drop(columns=contrast.name)
num_rows = contrast_df.shape[0]
contrast_df.insert(0, "term", contrast.term)
contrast_df.insert(0, "term", contrast.name)
contrast_df.insert(
1, "contrast", list(np.tile(contrast.value, num_rows).reshape(num_rows, 2))
1, "contrast", list(np.tile(contrast.values, num_rows).reshape(num_rows, 2))
)
contrast_df["estimate"] = contrast_estimate.comparison[tuple(contrast.value)].to_numpy()
contrast_df["estimate"] = contrast_estimate.comparison[tuple(contrast.values)].to_numpy()

if use_hdi:
contrast_df[f"hdi_{lower_bound}%"] = (
contrast_estimate.hdi[tuple(contrast.value)][
contrast_estimate.hdi[tuple(contrast.values)][
f"{response.name}_{response.target}"
]
.sel(hdi="lower")
.values
)
contrast_df[f"hdi_{upper_bound}%"] = (
contrast_estimate.hdi[tuple(contrast.value)][
contrast_estimate.hdi[tuple(contrast.values)][
f"{response.name}_{response.target}"
]
.sel(hdi="higher")
.values
)
else:
contrast_df[f"lower_{lower_bound}%"] = contrast_estimate.hdi[
tuple(contrast.value)
tuple(contrast.values)
].sel(quantile=lower_bound)
contrast_df[f"upper_{upper_bound}%"] = contrast_estimate.hdi[
tuple(contrast.value)
tuple(contrast.values)
].sel(quantile=upper_bound)

# if > 2 contrast values, then need the full dataframe to build contrast_df
elif len(contrast.value) >= 3:
elif len(contrast.values) >= 3:
num_rows = comparisons_df.shape[0]
contrast_df = comparisons_df.drop(columns=contrast.term)
contrast_df.insert(0, "term", contrast.term)
contrast_df = comparisons_df.drop(columns=contrast.name)
contrast_df.insert(0, "term", contrast.name)
contrast_keys = [list(elem) for elem in list(contrast_estimate.comparison.keys())]
contrast_df.insert(1, "contrast", contrast_keys * (num_rows // len(contrast.value)))
contrast_df.insert(1, "contrast", contrast_keys * (num_rows // len(contrast.values)))

estimates = []
for val in contrast_estimate.comparison.values():
@@ -388,15 +381,15 @@ def _build_contrasts_df(
if average_by:
if average_by is True:
contrast_df_avg = average_over(contrast_df, None)
contrast_df_avg.insert(0, "term", contrast.term)
contrast_df_avg.insert(0, "term", contrast.name)
contrast_df_avg.insert(
1,
"contrast",
np.tile(contrast_df["contrast"].drop_duplicates(), len(contrast_df_avg))
)
else:
contrast_df_avg = average_over(contrast_df, average_by)
contrast_df_avg.insert(0, "term", contrast.term)
contrast_df_avg.insert(0, "term", contrast.name)
contrast_df_avg.insert(
1,
"contrast",
Expand All @@ -406,10 +399,10 @@ def _build_contrasts_df(
else:
return contrast_df

contrast_vals = np.sort(np.unique(comparisons_df[contrast_name]))

contrast_df = _build_contrasts_df(
ContrastInfo(contrast_name, contrast_vals),
Response(response_name),
contrast_info,
response,
comparisons_df,
idata,
average_by,
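
The pairwise logic in _compute_contrast_estimate() above can be read as: subset the posterior draws by contrast value, apply the "diff" (or "ratio") comparison function to every pairwise combination of contrast values, then summarise each comparison with its mean and an HDI. Below is a standalone sketch of that idea using synthetic draws; it is an illustration, not the library code.

# Standalone sketch with synthetic draws; not code from this commit.
import itertools

import arviz as az
import numpy as np

rng = np.random.default_rng(0)
# posterior predictive draws per contrast value, shape (chain, draw)
draws = {value: rng.normal(loc=float(value), size=(2, 500)) for value in (0, 1, 2)}

comparison_functions = {"diff": lambda x, y: x - y, "ratio": lambda x, y: x / y}
compare = comparison_functions["diff"]

comparison_mean, comparison_hdi = {}, {}
for low, high in itertools.combinations(draws, 2):  # (0, 1), (0, 2), (1, 2)
    estimate = compare(draws[high], draws[low])
    comparison_mean[(low, high)] = estimate.mean()
    comparison_hdi[(low, high)] = az.hdi(estimate, hdi_prob=0.94)  # array([lower, upper])

print(comparison_mean[(0, 1)], comparison_hdi[(0, 1)])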
