Skip to content

issue-726-fix #753

Merged
merged 7 commits into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
### Changed
- Add columns and mode parameters in plot_correlation_matrix ([#726](https://github.com/tinkoff-ai/etna/pull/753))
-
-
-
Expand Down
57 changes: 48 additions & 9 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,14 +633,19 @@ def plot_anomalies(


def get_correlation_matrix(
ts: "TSDataset", segments: Optional[List[str]] = None, method: str = "pearson"
ts: "TSDataset",
columns: Optional[List[str]] = None,
segments: Optional[List[str]] = None,
method: str = "pearson",
) -> np.ndarray:
"""Compute pairwise correlation of timeseries for selected segments.

Parameters
----------
ts:
TSDataset with timeseries data
columns:
Columns to use, if None use all columns
segments:
Segments to use
method:
Expand All @@ -659,16 +664,23 @@ def get_correlation_matrix(
"""
if method not in ["pearson", "kendall", "spearman"]:
raise ValueError(f"'{method}' is not a valid method of correlation.")

if segments is None:
segments = sorted(ts.segments)
correlation_matrix = ts[:, segments, :].corr(method=method).values
if columns is None:
columns = list(set(ts.df.columns.get_level_values("feature")))

correlation_matrix = ts[:, segments, columns].corr(method=method).values
return correlation_matrix


def plot_correlation_matrix(
ts: "TSDataset",
columns: Optional[List[str]] = None,
segments: Optional[List[str]] = None,
method: str = "pearson",
mode: str = "macro",
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 10),
**heatmap_kwargs,
):
Expand All @@ -678,6 +690,8 @@ def plot_correlation_matrix(
----------
ts:
TSDataset with timeseries data
columns:
Columns to use, if None use all columns
segments:
Segments to use
method:
Expand All @@ -689,23 +703,48 @@ def plot_correlation_matrix(

* spearman: Spearman rank correlation

mode: 'macro' or 'per-segment'
Aggregation mode
columns_num:
Number of subplots columns
figsize:
size of the figure in inches
"""
if segments is None:
segments = sorted(ts.segments)
if columns is None:
columns = list(set(ts.df.columns.get_level_values("feature")))
if "vmin" not in heatmap_kwargs:
heatmap_kwargs["vmin"] = -1
if "vmax" not in heatmap_kwargs:
heatmap_kwargs["vmax"] = 1

correlation_matrix = get_correlation_matrix(ts, segments, method)
fig, ax = plt.subplots(figsize=figsize)
ax = sns.heatmap(correlation_matrix, annot=True, fmt=".1g", square=True, ax=ax, **heatmap_kwargs)
labels = list(ts[:, segments, :].columns.values)
ax.set_xticklabels(labels, rotation=45, horizontalalignment="right")
ax.set_yticklabels(labels, rotation=0, horizontalalignment="right")
ax.set_title("Correlation Heatmap")
if mode not in ["macro", "per-segment"]:
raise ValueError(f"'{mode}' is not a valid method of mode.")

if mode == "macro":
fig, ax = plt.subplots(figsize=figsize)
correlation_matrix = get_correlation_matrix(ts, columns, segments, method)
labels = list(ts[:, segments, columns].columns.values)
ax = sns.heatmap(correlation_matrix, annot=True, fmt=".1g", square=True, ax=ax, **heatmap_kwargs)
ax.set_xticks(np.arange(len(labels)) + 0.5, labels=labels)
ax.set_yticks(np.arange(len(labels)) + 0.5, labels=labels)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.setp(ax.get_yticklabels(), rotation=0, ha="right", rotation_mode="anchor")
ax.set_title("Correlation Heatmap")

if mode == "per-segment":
fig, ax = prepare_axes(len(segments), columns_num=columns_num, figsize=figsize)

for i, segment in enumerate(segments):
correlation_matrix = get_correlation_matrix(ts, columns, [segment], method)
labels = list(ts[:, segment, columns].columns.values)
ax[i] = sns.heatmap(correlation_matrix, annot=True, fmt=".1g", square=True, ax=ax[i], **heatmap_kwargs)
ax[i].set_xticks(np.arange(len(labels)) + 0.5, labels=labels)
ax[i].set_yticks(np.arange(len(labels)) + 0.5, labels=labels)
plt.setp(ax[i].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.setp(ax[i].get_yticklabels(), rotation=0, ha="right", rotation_mode="anchor")
ax[i].set_title("Correlation Heatmap" + " " + segment)


def plot_anomalies_interactive(
Expand Down