Skip to content

Commit

Permalink
Add plot_clusters (#169)
Browse files Browse the repository at this point in the history
* Add plot_clusters

* Update CHANGELOG
  • Loading branch information
alex-hse-repository authored Oct 11, 2021
1 parent 50f2074 commit 716da3a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- MyPy checks in CI/CD and lint commands ([#39](https://github.com/tinkoff-ai/etna-ts/issues/39))
- TrendTransform ([#139](https://github.com/tinkoff-ai/etna-ts/pull/139))
- Running notebooks in ci ([#134](https://github.com/tinkoff-ai/etna-ts/issues/134))
- Cluster plotter to EDA ([#169](https://github.com/tinkoff-ai/etna-ts/pull/169))

### Changed
- Delete offset from WindowStatisticsTransform ([#111](https://github.com/tinkoff-ai/etna-ts/pull/111))
Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
from etna.analysis.plotters import plot_anomalies
from etna.analysis.plotters import plot_anomalies_interactive
from etna.analysis.plotters import plot_backtest
from etna.analysis.plotters import plot_clusters
from etna.analysis.plotters import plot_correlation_matrix
from etna.analysis.plotters import plot_forecast
37 changes: 37 additions & 0 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,40 @@ def update(**kwargs):
plt.show()

interact(update, **sliders)


def plot_clusters(
    ts: "TSDataset", segment2cluster: Dict[str, int], centroids_df: Optional[pd.DataFrame] = None, columns_num: int = 2
):
    """Plot clusters [with centroids].

    Each cluster is drawn on its own subplot: every segment of the cluster is
    plotted in blue, and the cluster centroid (if given) is drawn in red on top.

    Parameters
    ----------
    ts:
        TSDataset with timeseries; the "target" column of each segment is
        plotted (accessed as ``ts[:, segment, "target"]``)
    segment2cluster:
        mapping from segment to cluster in format {segment: cluster}
    centroids_df:
        dataframe with centroids; the centroid of a cluster is read as
        ``centroids_df[cluster, "target"]``. If None, no centroids are drawn
    columns_num:
        number of columns in subplots
    """
    # Group segments by cluster in a single pass instead of re-scanning the
    # whole mapping once per cluster. Insertion order of segments is preserved.
    cluster2segments: Dict[int, list] = {}
    for segment, cluster in segment2cluster.items():
        cluster2segments.setdefault(cluster, []).append(segment)

    unique_clusters = sorted(cluster2segments)
    rows_num = math.ceil(len(unique_clusters) / columns_num)
    # squeeze=False keeps ``axs`` two-dimensional even for a single row or a
    # single column; with the default squeezing ``axs[h][w]`` fails in those cases.
    _, axs = plt.subplots(
        rows_num, columns_num, constrained_layout=True, figsize=(20, 5 * rows_num), squeeze=False
    )
    for i, cluster in enumerate(unique_clusters):
        segments = cluster2segments[cluster]
        ax = axs[i // columns_num][i % columns_num]
        # The more segments share the subplot, the more transparent each line is.
        alpha = 1 / math.sqrt(len(segments))
        for segment in segments:
            segment_slice = ts[:, segment, "target"]
            ax.plot(
                segment_slice.index.values,
                segment_slice.values,
                alpha=alpha,
                c="blue",
            )
        ax.set_title(f"cluster={cluster}\n{len(segments)} segments in cluster")
        if centroids_df is not None:
            centroid = centroids_df[cluster, "target"]
            ax.plot(centroid.index.values, centroid.values, c="red", label="centroid")
            # Only the centroid line carries a label, so a legend is meaningful
            # only when it has been drawn.
            ax.legend()

0 comments on commit 716da3a

Please sign in to comment.