Skip to content

Commit

Permalink
Merge pull request #810 from AFM-SPM/ns-rse/806-mpl-test-image-no-axe…
Browse files Browse the repository at this point in the history
…s-no-colorbar

Tidies up Images.plot_and_save()
  • Loading branch information
ns-rse authored Mar 12, 2024
2 parents 7ca2e58 + f2d0b78 commit 75e6ec4
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 62 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 9 additions & 25 deletions tests/test_plottingfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from skimage import io

from topostats.grains import Grains
from topostats.io import LoadScans
Expand Down Expand Up @@ -108,29 +107,17 @@ def test_save_figure(
assert isinstance(ax, Axes)


def test_save_array_figure(tmp_path: Path):
"""Tests that the image array is saved."""
rng2 = np.random.default_rng()
Images(
data=rng2.random((10, 10)),
output_dir=tmp_path,
filename="result",
).save_array_figure()
assert Path(tmp_path / "result.png").exists()


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
def test_plot_and_save_no_colorbar(load_scan_data: LoadScans, tmp_path: Path) -> None:
def test_plot_and_save_no_colorbar(load_scan_data: LoadScans, plotting_config: dict, tmp_path: Path) -> None:
"""Test plotting without colorbar."""
plotting_config["colorbar"] = False
fig, _ = Images(
data=load_scan_data.image,
output_dir=tmp_path,
filename="01-raw_heightmap",
pixel_to_nm_scaling=load_scan_data.pixel_to_nm_scaling,
title="Raw Height",
colorbar=False,
axes=True,
image_set="all",
**plotting_config,
).plot_and_save()
return fig

Expand All @@ -145,17 +132,15 @@ def test_plot_histogram_and_save(load_scan_data: LoadScans, tmp_path: Path) -> N


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
def test_plot_and_save_colorbar(load_scan_data: LoadScans, tmp_path: Path) -> None:
"""Test plotting with colorbar."""
def test_plot_and_save_colorbar_and_axes(load_scan_data: LoadScans, plotting_config: dict, tmp_path: Path) -> None:
"""Test plotting with colorbar and axes (True in default_config.yaml)."""
fig, _ = Images(
data=load_scan_data.image,
output_dir=tmp_path,
filename="01-raw_heightmap",
pixel_to_nm_scaling=load_scan_data.pixel_to_nm_scaling,
title="Raw Height",
colorbar=True,
axes=True,
image_set="all",
**plotting_config,
).plot_and_save()
return fig

Expand All @@ -174,20 +159,19 @@ def test_plot_and_save_no_axes(load_scan_data: LoadScans, plotting_config: dict,
return fig


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
def test_plot_and_save_no_axes_no_colorbar(load_scan_data: LoadScans, plotting_config: dict, tmp_path: Path) -> None:
"""Test plotting without axes and without the colourbar."""
plotting_config["axes"] = False
plotting_config["colorbar"] = False
Images(
fig, _ = Images(
data=load_scan_data.image,
output_dir=tmp_path,
filename="01-raw_heightmap",
title="Raw Height",
**plotting_config,
).plot_and_save()
img = io.imread(tmp_path / "01-raw_heightmap.png")
assert np.sum(img) == 1535334
assert img.shape == (64, 64, 4)
return fig


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
Expand Down
19 changes: 12 additions & 7 deletions tests/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,45 +288,50 @@ def test_check_run_steps(
@pytest.mark.parametrize(
("filter_run", "grains_run", "grainstats_run", "dnatracing_run", "log_msg1", "log_msg2"),
[
(
pytest.param(
False,
False,
False,
False,
"You have not included running the initial filter stage.",
"Please check your configuration file.",
id="All stages are disabled",
),
(
pytest.param(
True,
False,
False,
False,
"Detection of grains disabled, returning empty data frame.",
"16-gaussian_filtered",
"minicircle_small.png",
id="Only filtering enabled",
),
(
pytest.param(
True,
True,
False,
False,
"Calculation of grainstats disabled, returning empty dataframe.",
"25-labelled_image_bboxes",
"minicircle_small_above_masked.png",
id="Filtering and Grain enabled",
),
(
pytest.param(
True,
True,
True,
False,
"Processing grain",
"Calculation of DNA Tracing disabled, returning grainstats data frame.",
id="Filtering, Grain and GrainStats enabled",
),
(
pytest.param(
True,
True,
True,
True,
"Traced grain 3 of 3",
"Combining ['above'] grain statistics and dnatracing statistics",
id="Filtering, Grain, GrainStats and DNA Tracing enabled",
),
],
)
Expand Down
40 changes: 16 additions & 24 deletions topostats/plottingfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ def __init__(
"""
Initialise the class.
There are two key parameters that ensure whether and image is plotted that are passed in from the update
plotting dictionary. These are the `image_set` which defines whether to plot 'all' images or just the `core`
set. There is then the 'core_set' which defines whether an individual images belongs to the 'core_set' or not.
If it doesn't then it is not plotted when `image_set == "core"`.
Parameters
----------
data : np.array
Expand Down Expand Up @@ -240,7 +245,7 @@ def plot_histogram_and_save(self):

def plot_and_save(self):
"""
Plot and save the images with savefig or imsave depending on config file parameters.
Plot and save the image.
Returns
-------
Expand All @@ -251,18 +256,15 @@ def plot_and_save(self):
"""
fig, ax = None, None
if self.save:
# Only plot if image_set is "all" (i.e. user wants all images) or an image is in the core_set
if self.image_set == "all" or self.core_set:
if self.axes or self.colorbar:
fig, ax = self.save_figure()
else:
if isinstance(self.masked_array, np.ndarray) or self.region_properties:
fig, ax = self.save_figure()
else:
self.save_array_figure()
LOGGER.info(
f"[{self.filename}] : Image saved to : {str(self.output_dir / self.filename)}.{self.savefig_format}\
| DPI: {self.savefig_dpi}"
)
fig, ax = self.save_figure()
LOGGER.info(
f"[{self.filename}] : Image saved to : {str(self.output_dir / self.filename)}.{self.savefig_format}"
" | DPI: {self.savefig_dpi}"
)
plt.close()
return fig, ax
return fig, ax

def save_figure(self):
Expand Down Expand Up @@ -325,6 +327,8 @@ def save_figure(self):
if not self.axes and not self.colorbar:
plt.title("")
fig.frameon = False
plt.box(False)
plt.tight_layout()
plt.savefig(
(self.output_dir / f"{self.filename}.{self.savefig_format}"),
bbox_inches="tight",
Expand All @@ -345,18 +349,6 @@ def save_figure(self):
plt.close()
return fig, ax

def save_array_figure(self) -> None:
"""Save the image array as an image using plt.imsave()."""
plt.imsave(
(self.output_dir / f"{self.filename}.{self.savefig_format}"),
self.data,
cmap=self.cmap,
vmin=self.zrange[0],
vmax=self.zrange[1],
format=self.savefig_format,
)
plt.close()


def add_bounding_boxes_to_plot(fig, ax, shape, region_properties: list, pixel_to_nm_scaling: float) -> None:
"""Add the bounding boxes to a plot.
Expand Down
21 changes: 15 additions & 6 deletions topostats/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,21 @@ def run_filters(
**filter_config,
)
filters.filter_image()

# Optionally plot filter stage
if plotting_config["run"]:
plotting_config.pop("run")
LOGGER.info(f"[{filename}] : Plotting Filtering Images")
if plotting_config["image_set"] == "all":
filter_out_path.mkdir(parents=True, exist_ok=True)
LOGGER.debug(f"[{filename}] : Target filter directory created : {filter_out_path}")
# Generate plots
for plot_name, array in filters.images.items():
if plot_name not in ["scan_raw"]:
if plot_name == "extracted_channel":
array = np.flipud(array.pixels)
plotting_config["plot_dict"][plot_name]["output_dir"] = filter_out_path
plotting_config["plot_dict"][plot_name]["output_dir"] = (
core_out_path if plotting_config["plot_dict"][plot_name]["core_set"] else filter_out_path
)
try:
Images(array, **plotting_config["plot_dict"][plot_name]).plot_and_save()
Images(array, **plotting_config["plot_dict"][plot_name]).plot_histogram_and_save()
Expand Down Expand Up @@ -184,18 +188,22 @@ def run_grains( # noqa: C901
LOGGER.info(f"[{filename}] : Plotting Grain Finding Images")
for direction, image_arrays in grains.directions.items():
LOGGER.info(f"[{filename}] : Plotting {direction} Grain Finding Images")
grain_out_path_direction = grain_out_path / f"{direction}"
if plotting_config["image_set"] == "all":
grain_out_path_direction.mkdir(parents=True, exist_ok=True)
LOGGER.debug(f"[{filename}] : Target grain directory created : {grain_out_path_direction}")
for plot_name, array in image_arrays.items():
LOGGER.info(f"[{filename}] : Plotting {plot_name} image")
plotting_config["plot_dict"][plot_name]["output_dir"] = grain_out_path / f"{direction}"
plotting_config["plot_dict"][plot_name]["output_dir"] = grain_out_path_direction
Images(array, **plotting_config["plot_dict"][plot_name]).plot_and_save()
# Make a plot of coloured regions with bounding boxes
plotting_config["plot_dict"]["bounding_boxes"]["output_dir"] = grain_out_path / f"{direction}"
plotting_config["plot_dict"]["bounding_boxes"]["output_dir"] = grain_out_path_direction
Images(
grains.directions[direction]["coloured_regions"],
**plotting_config["plot_dict"]["bounding_boxes"],
region_properties=grains.region_properties[direction],
).plot_and_save()
plotting_config["plot_dict"]["coloured_boxes"]["output_dir"] = grain_out_path / f"{direction}"
plotting_config["plot_dict"]["coloured_boxes"]["output_dir"] = grain_out_path_direction
Images(
grains.directions[direction]["labelled_regions_02"],
**plotting_config["plot_dict"]["coloured_boxes"],
Expand Down Expand Up @@ -534,7 +542,8 @@ def process_scan(
Parameters
----------
img_path_px2nm : Dict[str, Union[np.ndarray, Path, float]]
A dictionary with keys 'image', 'img_path' and 'px_2_nm' containing a file or frames' image, it's path and it's pixel to namometre scaling value.
A dictionary with keys 'image', 'img_path' and 'px_2_nm' containing a file or frames' image, it's path and it's
pixel to namometre scaling value.
base_dir : Union[str, Path]
Directory to recursively search for files, if not specified the current directory is scanned.
filter_config : dict
Expand Down

0 comments on commit 75e6ec4

Please sign in to comment.