def plot_clusters(
    ts: "TSDataset", segment2cluster: Dict[str, int], centroids_df: Optional[pd.DataFrame] = None, columns_num: int = 2
):
    """Plot clusters [with centroids].

    One subplot is drawn per cluster: every segment assigned to the cluster is plotted in blue,
    and, if ``centroids_df`` is given, the centroid of the cluster is drawn on top in red.

    Parameters
    ----------
    ts:
        TSDataset with timeseries
    segment2cluster:
        mapping from segment to cluster in format {segment: cluster}
    centroids_df:
        dataframe with centroids; it is indexed as ``centroids_df[cluster, "target"]``
    columns_num:
        number of columns in subplots

    Raises
    ------
    ValueError:
        if ``segment2cluster`` is empty or ``columns_num`` is less than 1
    """
    if not segment2cluster:
        raise ValueError("segment2cluster should contain at least one segment")
    if columns_num < 1:
        raise ValueError("columns_num should be a positive integer")

    # Group segments by cluster in a single pass; segment order inside a cluster
    # follows the iteration order of ``segment2cluster``.
    cluster2segments: Dict[int, list] = {}
    for segment, cluster in segment2cluster.items():
        cluster2segments.setdefault(cluster, []).append(segment)

    unique_clusters = sorted(cluster2segments)
    rows_num = math.ceil(len(unique_clusters) / columns_num)
    # squeeze=False keeps ``axs`` two-dimensional even when the grid has a single row or column;
    # with the default squeezing ``axs[h][w]`` fails for such grids.
    _, axs = plt.subplots(
        rows_num,
        columns_num,
        constrained_layout=True,
        figsize=(20, 5 * rows_num),
        squeeze=False,
    )
    for i, cluster in enumerate(unique_clusters):
        segments = cluster2segments[cluster]
        h, w = divmod(i, columns_num)
        ax = axs[h][w]
        # The more segments share a subplot, the more transparent each line is drawn.
        alpha = 1 / math.sqrt(len(segments))
        for segment in segments:
            segment_slice = ts[:, segment, "target"]
            ax.plot(
                segment_slice.index.values,
                segment_slice.values,
                alpha=alpha,
                c="blue",
            )
        ax.set_title(f"cluster={cluster}\n{len(segments)} segments in cluster")
        if centroids_df is not None:
            centroid = centroids_df[cluster, "target"]
            ax.plot(centroid.index.values, centroid.values, c="red", label="centroid")
            # The centroid is the only labelled artist, so a legend is only useful here.
            ax.legend()