Skip to content

Commit

Permalink
BUG plot_clusters (#707)
Browse files Browse the repository at this point in the history
Co-authored-by: Artem Makhin <[email protected]>
  • Loading branch information
Ama16 and Artem Makhin authored May 25, 2022
1 parent c54dbb3 commit eebf273
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Fix bug in plot_clusters ([#675](https://github.com/tinkoff-ai/etna/pull/675))
-
- Fix bugs and documentation for cross_corr_plot ([#691](https://github.com/tinkoff-ai/etna/pull/691))
-
Expand Down
13 changes: 5 additions & 8 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,25 +788,22 @@ def plot_clusters(
size of the figure per subplot with one segment in inches
"""
unique_clusters = sorted(set(segment2cluster.values()))
rows_num = math.ceil(len(unique_clusters) / columns_num)
figsize = (figsize[0] * columns_num, figsize[1] * rows_num)
fig, axs = plt.subplots(rows_num, columns_num, constrained_layout=True, figsize=figsize)
_, ax = prepare_axes(num_plots=len(unique_clusters), columns_num=columns_num, figsize=figsize)
for i, cluster in enumerate(unique_clusters):
segments = [segment for segment in segment2cluster if segment2cluster[segment] == cluster]
h, w = i // columns_num, i % columns_num
for segment in segments:
segment_slice = ts[:, segment, "target"]
axs[h][w].plot(
ax[i].plot(
segment_slice.index.values,
segment_slice.values,
alpha=1 / math.sqrt(len(segments)),
c="blue",
)
axs[h][w].set_title(f"cluster={cluster}\n{len(segments)} segments in cluster")
ax[i].set_title(f"cluster={cluster}\n{len(segments)} segments in cluster")
if centroids_df is not None:
centroid = centroids_df[cluster, "target"]
axs[h][w].plot(centroid.index.values, centroid.values, c="red", label="centroid")
axs[h][w].legend()
ax[i].plot(centroid.index.values, centroid.values, c="red", label="centroid")
ax[i].legend()


def plot_time_series_with_change_points(
Expand Down

1 comment on commit eebf273

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.