From 45e2229a6fd999d7f3ed34d00cf6de6210edd8f0 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 7 Jan 2024 18:48:30 +0100 Subject: [PATCH] Pass showlegend/template kwargs to plotly (#122) * Pass showlegend/template kwargs to plotly * Update plotly.py - minor fix * fix dropped " * Fix showlegend for plotly * Fix showlegend for plotly * Add handling of continuous labels * Add tests for plotly integration * fix typo --------- Co-authored-by: Mackenzie Mathis Co-authored-by: Anastasiia Filippova --- cebra/integrations/plotly.py | 77 +++++++++++++++++++++++++----------- tests/test_plotly.py | 35 ++++++++++++++++ 2 files changed, 89 insertions(+), 23 deletions(-) diff --git a/cebra/integrations/plotly.py b/cebra/integrations/plotly.py index 9dd4fa02..08450062 100644 --- a/cebra/integrations/plotly.py +++ b/cebra/integrations/plotly.py @@ -94,33 +94,55 @@ def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure: Returns: The axis :py:meth:`plotly.graph_objs._figure.Figure` of the plot. """ - - idx1, idx2, idx3 = self.idx_order - data = [ - plotly.graph_objects.Scatter3d( - x=self.embedding[:, idx1], - y=self.embedding[:, idx2], - z=self.embedding[:, idx3], - mode="markers", - marker=dict( - size=self.markersize, - opacity=self.alpha, - color=self.embedding_labels, - colorscale=self.colorscale, - ), - ) - ] + showlegend = kwargs.get("showlegend", False) + discrete = kwargs.get("discrete", False) col = kwargs.get("col", None) row = kwargs.get("row", None) + template = kwargs.get("template", "plotly_white") + data = [] - if col is None or row is None: - self.axis.add_trace(data[0]) + if not discrete and showlegend: + raise ValueError("Cannot show legend with continuous labels.") + + idx1, idx2, idx3 = self.idx_order + + if discrete: + unique_labels = np.unique(self.embedding_labels) else: - self.axis.add_trace(data[0], row=row, col=col) + unique_labels = [self.embedding_labels] + + for label in unique_labels: + if discrete: + filtered_idx = [ + i for i, x in enumerate(self.embedding_labels) if x == label + ] + else: + filtered_idx = np.arange(self.embedding.shape[0]) + data.append( + plotly.graph_objects.Scatter3d(x=self.embedding[filtered_idx, + idx1], + y=self.embedding[filtered_idx, + idx2], + z=self.embedding[filtered_idx, + idx3], + mode="markers", + marker=dict( + size=self.markersize, + opacity=self.alpha, + color=label, + colorscale=self.colorscale, + ), + name=str(label))) + + for trace in data: + if col is None or row is None: + self.axis.add_trace(trace) + else: + self.axis.add_trace(trace, row=row, col=col) self.axis.update_layout( - template="plotly_white", - showlegend=False, + template=template, + showlegend=showlegend, title=self.title, ) @@ -166,8 +188,17 @@ def plot_embedding_interactive( title: The title on top of the embedding. figsize: Figure width and height in inches. dpi: Figure resolution. - kwargs: Optional arguments to customize the plots. See :py:class:`plotly.graph_objects.Scatter` documentation for more - details on which arguments to use. + kwargs: Optional arguments to customize the plots. This dictionary includes the following optional arguments: + -- showlegend: Whether to show the legend or not. + -- discrete: Whether the labels are discrete or not. + -- col: The column of the subplot to plot the embedding on. + -- row: The row of the subplot to plot the embedding on. + -- template: The template to use for the plot. + + Note: showlegend can be True only if discrete is True. + + See :py:class:`plotly.graph_objects.Scatter` documentation for more + details on which arguments to use. Returns: The plotly figure. diff --git a/tests/test_plotly.py b/tests/test_plotly.py index 90fb64ed..187d234c 100644 --- a/tests/test_plotly.py +++ b/tests/test_plotly.py @@ -84,3 +84,38 @@ def test_plot_embedding(output_dimension, idx_order): fig_subplots.data = [] fig_subplots.layout = {} + + +def test_discrete_with_legend(): + embedding = np.random.uniform(0, 1, (1000, 3)) + labels = np.random.randint(0, 10, (1000,)) + + fig = cebra_plotly.plot_embedding_interactive(embedding, + labels, + discrete=True, + showlegend=True) + + assert len(fig._data_objs) == np.unique(labels).shape[0] + assert isinstance(fig, go.Figure) + + +def test_continuous_no_legend(): + embedding = np.random.uniform(0, 1, (1000, 3)) + labels = np.random.uniform(0, 1, (1000,)) + + fig = cebra_plotly.plot_embedding_interactive(embedding, labels) + + assert len(fig._data_objs) == 1 + + assert isinstance(fig, go.Figure) + + +def test_continuous_with_legend_raises_error(): + embedding = np.random.uniform(0, 1, (1000, 3)) + labels = np.random.uniform(0, 1, (1000,)) + + with pytest.raises(ValueError): + cebra_plotly.plot_embedding_interactive(embedding, + labels, + discrete=False, + showlegend=True)