diff --git a/holoviews/element/comparison.py b/holoviews/element/comparison.py index a8c622ac00..5fadb04a21 100644 --- a/holoviews/element/comparison.py +++ b/holoviews/element/comparison.py @@ -17,6 +17,7 @@ thus would not supply any information regarding *why* two elements are considered different. """ +from functools import partial import numpy as np from unittest.util import safe_repr from unittest import TestCase @@ -92,6 +93,9 @@ class Comparison(ComparisonInterface): Comparison.assertEqual(matrix1, matrix2) """ + # someone might prefer to use a different function, e.g. assert_all_close + assert_array_almost_equal_fn = partial(assert_array_almost_equal, decimal=6) + @classmethod def register(cls): @@ -246,7 +250,7 @@ def compare_arrays(cls, arr1, arr2, msg='Arrays'): assert_array_equal(arr1, arr2) except: try: - assert_array_almost_equal(arr1, arr2) + cls.assert_array_almost_equal_fn(arr1, arr2) except AssertionError as e: raise cls.failureException(msg + str(e)[11:])