Skip to content

Commit

Permalink
Generalize check figures equal to work with pytest.marks (#600)
Browse files Browse the repository at this point in the history
* Allow check_figures_equal to work with pytest parametrized

Following cue of matplotlib's check_figures_equal decorator.

* Silence pylint complaints on ALLOWED_CHARS & KEYWORD_ONLY variable names
* Fix doctest failures on helpers/testing.py
* Update documentation on check_figures_equal
  • Loading branch information
weiji14 authored Sep 11, 2020
1 parent 6f32a2e commit 85c08ef
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
3 changes: 1 addition & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,8 @@ Here's an example:
@check_figures_equal()
def test_my_plotting_case():
"Test that my plotting function works"
fig_ref = Figure()
fig_ref, fig_test = Figure(), Figure()
fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
fig_test = Figure()
fig_test.grdimage(grid, projection="W120/15c", cmap="geo")
return fig_ref, fig_test
```
Expand Down
42 changes: 31 additions & 11 deletions pygmt/helpers/testing.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
"""
Helper functions for testing.
"""

import inspect
import os
import string

from matplotlib.testing.compare import compare_images

from ..exceptions import GMTImageComparisonFailure


def check_figures_equal(*, tol=0.0, result_dir="result_images"):
def check_figures_equal(*, extensions=("png",), tol=0.0, result_dir="result_images"):
"""
Decorator for test cases that generate and compare two figures.
The decorated function must take two arguments, *fig_ref* and *fig_test*,
and draw the reference and test images on them. After the function
returns, the figures are saved and compared.
The decorated function must return two arguments, *fig_ref* and *fig_test*,
these two figures will then be saved and compared against each other.
This decorator is practically identical to matplotlib's check_figures_equal
function, but adapted for PyGMT figures. See also the original code at
Expand All @@ -25,6 +23,8 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"):
Parameters
----------
extensions : list
The extensions to test. Default is ["png"].
tol : float
The RMS threshold above which the test is considered failed.
result_dir : str
Expand Down Expand Up @@ -66,19 +66,30 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"):
... )
>>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass
"""
# pylint: disable=invalid-name
ALLOWED_CHARS = set(string.digits + string.ascii_letters + "_-[]()")
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY

def decorator(func):
import pytest

os.makedirs(result_dir, exist_ok=True)
old_sig = inspect.signature(func)

def wrapper(*args, **kwargs):
@pytest.mark.parametrize("ext", extensions)
def wrapper(*args, ext="png", request=None, **kwargs):
if "ext" in old_sig.parameters:
kwargs["ext"] = ext
if "request" in old_sig.parameters:
kwargs["request"] = request
try:
file_name = "".join(c for c in request.node.name if c in ALLOWED_CHARS)
except AttributeError: # 'NoneType' object has no attribute 'node'
file_name = func.__name__
try:
fig_ref, fig_test = func(*args, **kwargs)
ref_image_path = os.path.join(
result_dir, func.__name__ + "-expected.png"
)
test_image_path = os.path.join(result_dir, func.__name__ + ".png")
ref_image_path = os.path.join(result_dir, f"{file_name}-expected.{ext}")
test_image_path = os.path.join(result_dir, f"{file_name}.{ext}")
fig_ref.savefig(ref_image_path)
fig_test.savefig(test_image_path)

Expand Down Expand Up @@ -109,9 +120,18 @@ def wrapper(*args, **kwargs):
for param in old_sig.parameters.values()
if param.name not in {"fig_test", "fig_ref"}
]
if "ext" not in old_sig.parameters:
parameters += [inspect.Parameter("ext", KEYWORD_ONLY)]
if "request" not in old_sig.parameters:
parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
new_sig = old_sig.replace(parameters=parameters)
wrapper.__signature__ = new_sig

# reach a bit into pytest internals to hoist the marks from
# our wrapped function
new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark
wrapper.pytestmark = new_marks

return wrapper

return decorator

0 comments on commit 85c08ef

Please sign in to comment.