Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tidies up Images.plot_and_save() #810

Merged
merged 4 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 9 additions & 25 deletions tests/test_plottingfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from skimage import io

from topostats.grains import Grains
from topostats.io import LoadScans
Expand Down Expand Up @@ -108,29 +107,17 @@ def test_save_figure(
assert isinstance(ax, Axes)


def test_save_array_figure(tmp_path: Path):
"""Tests that the image array is saved."""
rng2 = np.random.default_rng()
Images(
data=rng2.random((10, 10)),
output_dir=tmp_path,
filename="result",
).save_array_figure()
assert Path(tmp_path / "result.png").exists()


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
def test_plot_and_save_no_colorbar(load_scan_data: LoadScans, tmp_path: Path) -> None:
def test_plot_and_save_no_colorbar(load_scan_data: LoadScans, plotting_config: dict, tmp_path: Path) -> None:
"""Test plotting without colorbar."""
plotting_config["colorbar"] = False
fig, _ = Images(
data=load_scan_data.image,
output_dir=tmp_path,
filename="01-raw_heightmap",
pixel_to_nm_scaling=load_scan_data.pixel_to_nm_scaling,
title="Raw Height",
colorbar=False,
axes=True,
image_set="all",
**plotting_config,
).plot_and_save()
return fig

Expand All @@ -145,17 +132,15 @@ def test_plot_histogram_and_save(load_scan_data: LoadScans, tmp_path: Path) -> N


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
def test_plot_and_save_colorbar(load_scan_data: LoadScans, tmp_path: Path) -> None:
"""Test plotting with colorbar."""
def test_plot_and_save_colorbar_and_axes(load_scan_data: LoadScans, plotting_config: dict, tmp_path: Path) -> None:
"""Test plotting with colorbar and axes (True in default_config.yaml)."""
fig, _ = Images(
data=load_scan_data.image,
output_dir=tmp_path,
filename="01-raw_heightmap",
pixel_to_nm_scaling=load_scan_data.pixel_to_nm_scaling,
title="Raw Height",
colorbar=True,
axes=True,
image_set="all",
**plotting_config,
).plot_and_save()
return fig

Expand All @@ -174,20 +159,19 @@ def test_plot_and_save_no_axes(load_scan_data: LoadScans, plotting_config: dict,
return fig


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
def test_plot_and_save_no_axes_no_colorbar(load_scan_data: LoadScans, plotting_config: dict, tmp_path: Path) -> None:
"""Test plotting without axes and without the colourbar."""
plotting_config["axes"] = False
plotting_config["colorbar"] = False
Images(
fig, _ = Images(
data=load_scan_data.image,
output_dir=tmp_path,
filename="01-raw_heightmap",
title="Raw Height",
**plotting_config,
).plot_and_save()
img = io.imread(tmp_path / "01-raw_heightmap.png")
assert np.sum(img) == 1535334
assert img.shape == (64, 64, 4)
return fig


@pytest.mark.mpl_image_compare(baseline_dir="resources/img/")
Expand Down
19 changes: 12 additions & 7 deletions tests/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,45 +288,50 @@ def test_check_run_steps(
@pytest.mark.parametrize(
("filter_run", "grains_run", "grainstats_run", "dnatracing_run", "log_msg1", "log_msg2"),
[
(
pytest.param(
False,
False,
False,
False,
"You have not included running the initial filter stage.",
"Please check your configuration file.",
id="All stages are disabled",
),
(
pytest.param(
True,
False,
False,
False,
"Detection of grains disabled, returning empty data frame.",
"16-gaussian_filtered",
"minicircle_small.png",
id="Only filtering enabled",
),
(
pytest.param(
True,
True,
False,
False,
"Calculation of grainstats disabled, returning empty dataframe.",
"25-labelled_image_bboxes",
"minicircle_small_above_masked.png",
id="Filtering and Grain enabled",
),
(
pytest.param(
True,
True,
True,
False,
"Processing grain",
"Calculation of DNA Tracing disabled, returning grainstats data frame.",
id="Filtering, Grain and GrainStats enabled",
),
(
pytest.param(
True,
True,
True,
True,
"Traced grain 3 of 3",
"Combining ['above'] grain statistics and dnatracing statistics",
id="Filtering, Grain, GrainStats and DNA Tracing enabled",
),
],
)
Expand Down
40 changes: 16 additions & 24 deletions topostats/plottingfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ def __init__(
"""
Initialise the class.

There are two key parameters that ensure whether and image is plotted that are passed in from the update
plotting dictionary. These are the `image_set` which defines whether to plot 'all' images or just the `core`
set. There is then the 'core_set' which defines whether an individual images belongs to the 'core_set' or not.
If it doesn't then it is not plotted when `image_set == "core"`.

Parameters
----------
data : np.array
Expand Down Expand Up @@ -240,7 +245,7 @@ def plot_histogram_and_save(self):

def plot_and_save(self):
"""
Plot and save the images with savefig or imsave depending on config file parameters.
Plot and save the image.

Returns
-------
Expand All @@ -251,18 +256,15 @@ def plot_and_save(self):
"""
fig, ax = None, None
if self.save:
# Only plot if image_set is "all" (i.e. user wants all images) or an image is in the core_set
if self.image_set == "all" or self.core_set:
if self.axes or self.colorbar:
fig, ax = self.save_figure()
else:
if isinstance(self.masked_array, np.ndarray) or self.region_properties:
fig, ax = self.save_figure()
else:
self.save_array_figure()
LOGGER.info(
f"[{self.filename}] : Image saved to : {str(self.output_dir / self.filename)}.{self.savefig_format}\
| DPI: {self.savefig_dpi}"
)
fig, ax = self.save_figure()
LOGGER.info(
f"[{self.filename}] : Image saved to : {str(self.output_dir / self.filename)}.{self.savefig_format}"
" | DPI: {self.savefig_dpi}"
)
plt.close()
return fig, ax
return fig, ax

def save_figure(self):
Expand Down Expand Up @@ -325,6 +327,8 @@ def save_figure(self):
if not self.axes and not self.colorbar:
plt.title("")
fig.frameon = False
plt.box(False)
plt.tight_layout()
plt.savefig(
(self.output_dir / f"{self.filename}.{self.savefig_format}"),
bbox_inches="tight",
Expand All @@ -345,18 +349,6 @@ def save_figure(self):
plt.close()
return fig, ax

def save_array_figure(self) -> None:
"""Save the image array as an image using plt.imsave()."""
plt.imsave(
(self.output_dir / f"{self.filename}.{self.savefig_format}"),
self.data,
cmap=self.cmap,
vmin=self.zrange[0],
vmax=self.zrange[1],
format=self.savefig_format,
)
plt.close()


def add_bounding_boxes_to_plot(fig, ax, shape, region_properties: list, pixel_to_nm_scaling: float) -> None:
"""Add the bounding boxes to a plot.
Expand Down
21 changes: 15 additions & 6 deletions topostats/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,21 @@ def run_filters(
**filter_config,
)
filters.filter_image()

# Optionally plot filter stage
if plotting_config["run"]:
plotting_config.pop("run")
LOGGER.info(f"[{filename}] : Plotting Filtering Images")
if plotting_config["image_set"] == "all":
filter_out_path.mkdir(parents=True, exist_ok=True)
LOGGER.debug(f"[{filename}] : Target filter directory created : {filter_out_path}")
# Generate plots
for plot_name, array in filters.images.items():
if plot_name not in ["scan_raw"]:
if plot_name == "extracted_channel":
array = np.flipud(array.pixels)
plotting_config["plot_dict"][plot_name]["output_dir"] = filter_out_path
plotting_config["plot_dict"][plot_name]["output_dir"] = (
core_out_path if plotting_config["plot_dict"][plot_name]["core_set"] else filter_out_path
)
try:
Images(array, **plotting_config["plot_dict"][plot_name]).plot_and_save()
Images(array, **plotting_config["plot_dict"][plot_name]).plot_histogram_and_save()
Comment on lines +86 to 99
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only other thing is that you've got an if image_set=='all' statement which is then used again within the plot_and_save() function, and I'm thinking would it make sense just to use it the once in processing?

We could pop it and re-add it like the plotting.run param and keep all the non-core images within this if statement, and the core image outside like it is? What do you think?

Copy link
Collaborator Author

@ns-rse ns-rse Mar 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic here though is to create the sub-directory filter as held in filter_out_path (the same is true for grains). We then need to set the parameter in the current images plot_dict to whatever that is, hence the if/else.

When I simplified the logic within plot_and_save() for some reason I couldn't understand I got error messages about target directories not being present so I figured I needed to create them.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also having plot_and_save() know about image_set might be useful should anyone use the class/method interactively e.g. in a Notebook.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok makes sense, thanks for explaining :)

Expand Down Expand Up @@ -184,18 +188,22 @@ def run_grains( # noqa: C901
LOGGER.info(f"[{filename}] : Plotting Grain Finding Images")
for direction, image_arrays in grains.directions.items():
LOGGER.info(f"[{filename}] : Plotting {direction} Grain Finding Images")
grain_out_path_direction = grain_out_path / f"{direction}"
if plotting_config["image_set"] == "all":
grain_out_path_direction.mkdir(parents=True, exist_ok=True)
LOGGER.debug(f"[{filename}] : Target grain directory created : {grain_out_path_direction}")
for plot_name, array in image_arrays.items():
LOGGER.info(f"[{filename}] : Plotting {plot_name} image")
plotting_config["plot_dict"][plot_name]["output_dir"] = grain_out_path / f"{direction}"
plotting_config["plot_dict"][plot_name]["output_dir"] = grain_out_path_direction
Images(array, **plotting_config["plot_dict"][plot_name]).plot_and_save()
# Make a plot of coloured regions with bounding boxes
plotting_config["plot_dict"]["bounding_boxes"]["output_dir"] = grain_out_path / f"{direction}"
plotting_config["plot_dict"]["bounding_boxes"]["output_dir"] = grain_out_path_direction
Images(
grains.directions[direction]["coloured_regions"],
**plotting_config["plot_dict"]["bounding_boxes"],
region_properties=grains.region_properties[direction],
).plot_and_save()
plotting_config["plot_dict"]["coloured_boxes"]["output_dir"] = grain_out_path / f"{direction}"
plotting_config["plot_dict"]["coloured_boxes"]["output_dir"] = grain_out_path_direction
Images(
grains.directions[direction]["labelled_regions_02"],
**plotting_config["plot_dict"]["coloured_boxes"],
Expand Down Expand Up @@ -534,7 +542,8 @@ def process_scan(
Parameters
----------
img_path_px2nm : Dict[str, Union[np.ndarray, Path, float]]
A dictionary with keys 'image', 'img_path' and 'px_2_nm' containing a file or frames' image, it's path and it's pixel to namometre scaling value.
A dictionary with keys 'image', 'img_path' and 'px_2_nm' containing a file or frames' image, it's path and it's
pixel to namometre scaling value.
base_dir : Union[str, Path]
Directory to recursively search for files, if not specified the current directory is scanned.
filter_config : dict
Expand Down
Loading