Skip to content

Commit

Permalink
feat(plots): add normalized data to rainbow plot
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <[email protected]>
  • Loading branch information
cameronraysmith committed Aug 7, 2024
1 parent 0dfef1e commit 7ddc5f0
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions src/pyrovelocity/plots/_rainbow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def rainbowplot(
negative_correlation: bool = False,
scvelo_colors: bool = False,
save_plot: bool = False,
show_data: bool = True,
rainbow_plot_path: str | Path = "rainbow.pdf",
) -> FigureBase:
set_font_size(7)
Expand All @@ -43,16 +44,22 @@ def rainbowplot(
number_of_genes = len(genes)

subplot_height = 1
subplot_width = subplot_height * 2.0 * 3

if show_data:
horizontal_panels = 4
else:
horizontal_panels = 3

subplot_width = 2.0 * subplot_height * horizontal_panels

if fig is None:
fig, ax = plt.subplots(
number_of_genes,
3,
horizontal_panels,
figsize=(subplot_width, subplot_height * number_of_genes),
)
else:
ax = fig.subplots(number_of_genes, 3)
ax = fig.subplots(number_of_genes, horizontal_panels)

if scvelo_colors:
colors = setup_scvelo_colors(adata, cell_state)
Expand All @@ -66,10 +73,16 @@ def rainbowplot(
ax1 = ax[n, 1]
ax2 = ax[n, 0]
ax3 = ax[n, 2]
if show_data:
ax4 = ax[n, 3]

if n == 0:
ax1.set_title("Rainbow plot", fontsize=7)
ax2.set_title("Phase portrait", fontsize=7)
ax3.set_title("Denoised spliced", fontsize=7)
if show_data:
ax4.set_title("Spliced data", fontsize=7)

plot_gene(ax1, ress, colors, add_line)
scatterplot(ax2, ress, colors)
(index,) = np.where(adata.var_names == gene)
Expand All @@ -82,6 +95,16 @@ def rainbowplot(
)
set_colorbar(im, ax3, labelsize=5, fig=fig, rainbow=True)
ax3.axis("off")
if show_data:
im = ax4.scatter(
adata.obsm[f"X_{basis}"][:, 0],
adata.obsm[f"X_{basis}"][:, 1],
s=3,
c=adata.layers["spliced"][:, index].flatten(),
cmap="RdBu_r",
)
set_colorbar(im, ax4, labelsize=5, fig=fig, rainbow=True)
ax4.axis("off")
set_labels(ax1, ax2, ax3, gene, number_of_genes, ress, n)

sns.despine()
Expand Down

0 comments on commit 7ddc5f0

Please sign in to comment.