-
Notifications
You must be signed in to change notification settings - Fork 6
/
plotting_utils.py
64 lines (52 loc) · 2.1 KB
/
plotting_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# -*- coding: utf-8 -*-
"""
@author: Makan Arastuie
"""
import numpy as np
import matplotlib.pyplot as plt
def heatmap(data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", color_bar_format='%.1e',
            grid_color='#e7dadb', font_size=20, **kwargs):
    """
    Create a heatmap from a 2D array and two lists of labels.

    Based on https://matplotlib.org/3.1.1/gallery/images_contours_and_fields/image_annotated_heatmap.html

    Parameters
    ----------
    data
        A 2D array-like of shape (N, M). Anything that `numpy.shape` and
        `imshow` accept works (numpy array, nested lists, ...).
    row_labels
        A list or array of length N with the labels for the rows.
    col_labels
        A list or array of length M with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
        not provided, use current axes or create a new one. Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.figure.Figure.colorbar`.
        Optional. A ``"format"`` entry here takes precedence over
        `color_bar_format`. The dictionary is not modified.
    cbarlabel
        The label for the colorbar. Optional.
    color_bar_format
        Format string for the colorbar tick labels (default ``'%.1e'``).
    grid_color
        Color of the grid lines drawn between cells.
    font_size
        Font size used for the tick labels and the colorbar label.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    im
        The image object returned by `ax.imshow`.
    cbar
        The colorbar returned by `ax.figure.colorbar`.
    """
    if ax is None:
        ax = plt.gca()

    # Copy so the caller's dict is never mutated, and let an explicit "format"
    # in cbar_kw win instead of raising "got multiple values for keyword
    # argument 'format'".
    colorbar_kwargs = dict(cbar_kw) if cbar_kw is not None else {}
    colorbar_kwargs.setdefault("format", color_bar_format)

    # Plot the heatmap (imshow validates the data shape first).
    im = ax.imshow(data, **kwargs)

    # np.shape works for nested lists / DataFrames too, not only ndarrays.
    n_rows, n_cols = np.shape(data)[:2]

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **colorbar_kwargs)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom", fontsize=font_size)
    cbar.ax.tick_params(labelsize=font_size)

    # Show every tick and label it with the respective list entry.
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(col_labels, fontsize=font_size)
    ax.set_yticklabels(row_labels, fontsize=font_size)

    # Hide the spines and draw a grid in grid_color on the minor ticks, which
    # sit on the cell boundaries at half-integer positions.
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color=grid_color, linestyle='-', linewidth=2)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar