Skip to content

Commit

Permalink
Merge pull request #848 from AFM-SPM/ns-rse/818-tests-pruning-toposta…
Browse files Browse the repository at this point in the history
…ts-conv

tests(pruning): Tests TopoStatsPrune and convPrune
  • Loading branch information
MaxGamill-Sheffield authored Jul 2, 2024
2 parents ccf292d + 68fb19a commit 4d5bad0
Show file tree
Hide file tree
Showing 11 changed files with 1,225 additions and 838 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ write_to = "topostats/_version.py"
[tool.pytest.ini_options]
minversion = "7.0"
addopts = ["--cov", "--mpl", "-ra", "--strict-config", "--strict-markers"]
log_cli_level = "Info"
log_level = "INFO"
log_cli = true
log_cli_level = "INFO"
testpaths = [
"tests",
]
Expand Down
40 changes: 28 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,11 +805,13 @@ def _generate_random_skeleton(**extra_kwargs):
"shape": None,
"allow_overlap": True,
}
# kwargs.update
heights = {"scale": 100, "sigma": 5.0, "cval": 20.0}
random_image, _ = draw.random_shapes(**kwargs, **extra_kwargs)
kwargs = {**kwargs, **extra_kwargs}
random_image, _ = draw.random_shapes(**kwargs)
mask = random_image != 255
skeleton = skeletonize(mask)
return {"img": _generate_heights(skeleton, **heights), "skeleton": skeleton}
return {"original": mask, "img": _generate_heights(skeleton, **heights), "skeleton": skeleton}


@pytest.fixture()
Expand Down Expand Up @@ -842,13 +844,21 @@ def skeleton_linear3() -> dict:
return _generate_random_skeleton(rng=894632511, min_size=20)


@pytest.fixture()
def pruning_skeleton() -> dict:
"""Smaller skeleton for testing parameters of prune_all_skeletons(). Has a T-junction."""
return _generate_random_skeleton(rng=69432138, min_size=15, image_shape=(30, 30))


## Helper function visualising for generating skeletons and heights


# import matplotlib.pyplot as plt
# def pruned_plot(gen_shape: dict) -> None:
# """Plot the original skeleton, its derived height and the pruned skeleton."""
# img_skeleton = gen_shape()
# img_skeleton = gen_shape
# pruned = topostatsPrune(
# img_skeleton["heights"],
# img_skeleton["img"],
# img_skeleton["skeleton"],
# max_length=-1,
# height_threshold=90,
Expand All @@ -857,14 +867,20 @@ def skeleton_linear3() -> dict:
# )
# pruned_skeleton = pruned._prune_by_length(pruned.skeleton, pruned.max_length)
# fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
# ax1.imshow(img_skeleton["skeleton"])
# ax2.imshow(img_skeleton["heights"])
# ax3.imshow(pruned_skeleton)
# ax1.imshow(img_skeleton["original"])
# ax1.set_title("Original mask")
# ax2.imshow(img_skeleton["skeleton"])
# ax2.set_title("Skeleton")
# ax3.imshow(img_skeleton["img"])
# ax3.set_title("Gaussian Blurring")
# ax4.imshow(pruned_skeleton)
# ax4.set_title("Pruned Skeleton")
# plt.show()


# pruned_plot(pruning_skeleton_loop1)
# pruned_plot(pruning_skeleton_loop2)
# pruned_plot(pruning_skeleton_linear1)
# pruned_plot(pruning_skeleton_linear2)
# pruned_plot(pruning_skeleton_linear3)
# pruned_plot(pruning_skeleton_loop1())
# pruned_plot(pruning_skeleton_loop2())
# pruned_plot(pruning_skeleton_linear1())
# pruned_plot(pruning_skeleton_linear2())
# pruned_plot(pruning_skeleton_linear3())
# pruned_plot(pruning_skeleton())
Loading

0 comments on commit 4d5bad0

Please sign in to comment.