Skip to content

Commit

Permalink
Add warnings section, fix bug with constrained_layout
Browse files Browse the repository at this point in the history
  • Loading branch information
d.a.bunin committed Mar 9, 2022
1 parent 6861b7d commit 0e48462
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,13 +849,19 @@ def plot_feature_relevance(
segments = sorted(ts.segments)

is_ascending = not relevance_table.greater_is_better
features = set(ts.columns.get_level_values("feature")) - {"target"}
features = list(set(ts.columns.get_level_values("feature")) - {"target"})
relevance_df = relevance_table(df=ts[:, :, "target"], df_exog=ts[:, :, features], **relevance_params).loc[segments]

if relevance_aggregation_mode == "per-segment":
ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize)
for i, segment in enumerate(segments):
relevance = relevance_df.loc[segment].sort_values(ascending=is_ascending)
# warning about NaNs
if relevance.isna().any():
na_relevance_features = relevance[relevance.isna()].index.tolist()
warnings.warn(
f"Relevances on segment: {segment} of features: {na_relevance_features} can't be calculated."
)
relevance = relevance.dropna()[:top_k]
if normalized:
relevance = relevance / relevance.sum()
Expand All @@ -866,10 +872,14 @@ def plot_feature_relevance(
relevance_aggregation_fn = AGGREGATION_FN[AggregationMode(relevance_aggregation_mode)]
relevance = relevance_df.apply(lambda x: relevance_aggregation_fn(x[~x.isna()])) # type: ignore
relevance = relevance.sort_values(ascending=is_ascending)
# warning about NaNs
if relevance.isna().any():
na_relevance_features = relevance[relevance.isna()].index.tolist()
warnings.warn(f"Relevances of features: {na_relevance_features} can't be calculated.")
# if top_k == None, all the values are selected
relevance = relevance.dropna()[:top_k]
if normalized:
relevance = relevance / relevance.sum()
_, ax = plt.subplots(figsize=figsize)
_, ax = plt.subplots(figsize=figsize, constrained_layout=True)
sns.barplot(x=relevance.values, y=relevance.index, orient="h", ax=ax)
ax.set_title("Feature relevance") # type: ignore

0 comments on commit 0e48462

Please sign in to comment.