Skip to content

Commit

Permalink
Improved testsuite
Browse files Browse the repository at this point in the history
- Access to internal attributes of registry
  is wrap in a function for future identification.
- More usage of pytest fixtures instead of default registries
  • Loading branch information
hgrecco committed Dec 2, 2023
1 parent d8dd22d commit 98fbda4
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 106 deletions.
4 changes: 4 additions & 0 deletions pint/testsuite/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
_unit_re = re.compile(r"<Unit\((.*)\)>")


def internal(ureg):
return ureg


class PintOutputChecker(doctest.OutputChecker):
def check_output(self, want, got, optionflags):
check = super().check_output(want, got, optionflags)
Expand Down
135 changes: 69 additions & 66 deletions pint/testsuite/test_contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from pint.util import UnitsContainer


def add_ctxs(ureg):
from .helpers import internal


def add_ctxs(ureg: UnitRegistry):
a, b = UnitsContainer({"[length]": 1}), UnitsContainer({"[time]": -1})
d = Context("lc")
d.add_transformation(a, b, lambda ureg, x: ureg.speed_of_light / x)
Expand All @@ -33,7 +36,7 @@ def add_ctxs(ureg):
ureg.add_context(d)


def add_arg_ctxs(ureg):
def add_arg_ctxs(ureg: UnitRegistry):
a, b = UnitsContainer({"[length]": 1}), UnitsContainer({"[time]": -1})
d = Context("lc")
d.add_transformation(a, b, lambda ureg, x, n: ureg.speed_of_light / x / n)
Expand All @@ -49,7 +52,7 @@ def add_arg_ctxs(ureg):
ureg.add_context(d)


def add_argdef_ctxs(ureg):
def add_argdef_ctxs(ureg: UnitRegistry):
a, b = UnitsContainer({"[length]": 1}), UnitsContainer({"[time]": -1})
d = Context("lc", defaults=dict(n=1))
assert d.defaults == dict(n=1)
Expand All @@ -67,7 +70,7 @@ def add_argdef_ctxs(ureg):
ureg.add_context(d)


def add_sharedargdef_ctxs(ureg):
def add_sharedargdef_ctxs(ureg: UnitRegistry):
a, b = UnitsContainer({"[length]": 1}), UnitsContainer({"[time]": -1})
d = Context("lc", defaults=dict(n=1))
assert d.defaults == dict(n=1)
Expand All @@ -90,37 +93,37 @@ def test_known_context(self, func_registry):
ureg = func_registry
add_ctxs(ureg)
with ureg.context("lc"):
assert ureg._active_ctx
assert ureg._active_ctx.graph
assert internal(ureg)._active_ctx
assert internal(ureg)._active_ctx.graph

assert not ureg._active_ctx
assert not ureg._active_ctx.graph
assert not internal(ureg)._active_ctx
assert not internal(ureg)._active_ctx.graph

with ureg.context("lc", n=1):
assert ureg._active_ctx
assert ureg._active_ctx.graph
assert internal(ureg)._active_ctx
assert internal(ureg)._active_ctx.graph

assert not ureg._active_ctx
assert not ureg._active_ctx.graph
assert not internal(ureg)._active_ctx
assert not internal(ureg)._active_ctx.graph

def test_known_context_enable(self, func_registry):
ureg = func_registry
add_ctxs(ureg)
ureg.enable_contexts("lc")
assert ureg._active_ctx
assert ureg._active_ctx.graph
assert internal(ureg)._active_ctx
assert internal(ureg)._active_ctx.graph
ureg.disable_contexts(1)

assert not ureg._active_ctx
assert not ureg._active_ctx.graph
assert not internal(ureg)._active_ctx
assert not internal(ureg)._active_ctx.graph

ureg.enable_contexts("lc", n=1)
assert ureg._active_ctx
assert ureg._active_ctx.graph
assert internal(ureg)._active_ctx
assert internal(ureg)._active_ctx.graph
ureg.disable_contexts(1)

assert not ureg._active_ctx
assert not ureg._active_ctx.graph
assert not internal(ureg)._active_ctx
assert not internal(ureg)._active_ctx.graph

def test_graph(self, func_registry):
ureg = func_registry
Expand All @@ -139,27 +142,27 @@ def test_graph(self, func_registry):
g.update({l: {t, c}, t: {l}, c: {l}})

with ureg.context("lc"):
assert ureg._active_ctx.graph == g_sp
assert internal(ureg)._active_ctx.graph == g_sp

with ureg.context("lc", n=1):
assert ureg._active_ctx.graph == g_sp
assert internal(ureg)._active_ctx.graph == g_sp

with ureg.context("ab"):
assert ureg._active_ctx.graph == g_ab
assert internal(ureg)._active_ctx.graph == g_ab

with ureg.context("lc"):
with ureg.context("ab"):
assert ureg._active_ctx.graph == g
assert internal(ureg)._active_ctx.graph == g

with ureg.context("ab"):
with ureg.context("lc"):
assert ureg._active_ctx.graph == g
assert internal(ureg)._active_ctx.graph == g

with ureg.context("lc", "ab"):
assert ureg._active_ctx.graph == g
assert internal(ureg)._active_ctx.graph == g

with ureg.context("ab", "lc"):
assert ureg._active_ctx.graph == g
assert internal(ureg)._active_ctx.graph == g

def test_graph_enable(self, func_registry):
ureg = func_registry
Expand All @@ -178,82 +181,82 @@ def test_graph_enable(self, func_registry):
g.update({l: {t, c}, t: {l}, c: {l}})

ureg.enable_contexts("lc")
assert ureg._active_ctx.graph == g_sp
assert internal(ureg)._active_ctx.graph == g_sp
ureg.disable_contexts(1)

ureg.enable_contexts("lc", n=1)
assert ureg._active_ctx.graph == g_sp
assert internal(ureg)._active_ctx.graph == g_sp
ureg.disable_contexts(1)

ureg.enable_contexts("ab")
assert ureg._active_ctx.graph == g_ab
assert internal(ureg)._active_ctx.graph == g_ab
ureg.disable_contexts(1)

ureg.enable_contexts("lc")
ureg.enable_contexts("ab")
assert ureg._active_ctx.graph == g
assert internal(ureg)._active_ctx.graph == g
ureg.disable_contexts(2)

ureg.enable_contexts("ab")
ureg.enable_contexts("lc")
assert ureg._active_ctx.graph == g
assert internal(ureg)._active_ctx.graph == g
ureg.disable_contexts(2)

ureg.enable_contexts("lc", "ab")
assert ureg._active_ctx.graph == g
assert internal(ureg)._active_ctx.graph == g
ureg.disable_contexts(2)

ureg.enable_contexts("ab", "lc")
assert ureg._active_ctx.graph == g
assert internal(ureg)._active_ctx.graph == g
ureg.disable_contexts(2)

def test_known_nested_context(self, func_registry):
ureg = func_registry
add_ctxs(ureg)

with ureg.context("lc"):
x = dict(ureg._active_ctx)
y = dict(ureg._active_ctx.graph)
assert ureg._active_ctx
assert ureg._active_ctx.graph
x = dict(internal(ureg)._active_ctx)
y = dict(internal(ureg)._active_ctx.graph)
assert internal(ureg)._active_ctx
assert internal(ureg)._active_ctx.graph

with ureg.context("ab"):
assert ureg._active_ctx
assert ureg._active_ctx.graph
assert x != ureg._active_ctx
assert y != ureg._active_ctx.graph
assert internal(ureg)._active_ctx
assert internal(ureg)._active_ctx.graph
assert x != internal(ureg)._active_ctx
assert y != internal(ureg)._active_ctx.graph

assert x == ureg._active_ctx
assert y == ureg._active_ctx.graph
assert x == internal(ureg)._active_ctx
assert y == internal(ureg)._active_ctx.graph

assert not ureg._active_ctx
assert not ureg._active_ctx.graph
assert not internal(ureg)._active_ctx
assert not internal(ureg)._active_ctx.graph

def test_unknown_context(self, func_registry):
ureg = func_registry
add_ctxs(ureg)
with pytest.raises(KeyError):
with ureg.context("la"):
pass
assert not ureg._active_ctx
assert not ureg._active_ctx.graph
assert not internal(ureg)._active_ctx
assert not internal(ureg)._active_ctx.graph

def test_unknown_nested_context(self, func_registry):
ureg = func_registry
add_ctxs(ureg)

with ureg.context("lc"):
x = dict(ureg._active_ctx)
y = dict(ureg._active_ctx.graph)
x = dict(internal(ureg)._active_ctx)
y = dict(internal(ureg)._active_ctx.graph)
with pytest.raises(KeyError):
with ureg.context("la"):
pass

assert x == ureg._active_ctx
assert y == ureg._active_ctx.graph
assert x == internal(ureg)._active_ctx
assert y == internal(ureg)._active_ctx.graph

assert not ureg._active_ctx
assert not ureg._active_ctx.graph
assert not internal(ureg)._active_ctx
assert not internal(ureg)._active_ctx.graph

def test_one_context(self, func_registry):
ureg = func_registry
Expand Down Expand Up @@ -498,21 +501,21 @@ def _test_ctx(self, ctx, ureg):
q = 500 * ureg.meter
s = (ureg.speed_of_light / q).to("Hz")

nctx = len(ureg._contexts)
nctx = len(internal(ureg)._contexts)

assert ctx.name not in ureg._contexts
assert ctx.name not in internal(ureg)._contexts
ureg.add_context(ctx)

assert ctx.name in ureg._contexts
assert len(ureg._contexts) == nctx + 1 + len(ctx.aliases)
assert ctx.name in internal(ureg)._contexts
assert len(internal(ureg)._contexts) == nctx + 1 + len(ctx.aliases)

with ureg.context(ctx.name):
assert q.to("Hz") == s
assert s.to("meter") == q

ureg.remove_context(ctx.name)
assert ctx.name not in ureg._contexts
assert len(ureg._contexts) == nctx
assert ctx.name not in internal(ureg)._contexts
assert len(internal(ureg)._contexts) == nctx

@pytest.mark.parametrize(
"badrow",
Expand Down Expand Up @@ -661,11 +664,11 @@ def test_defined(self, class_registry):
b = Context.__keytransform__(
UnitsContainer({"[length]": 1.0}), UnitsContainer({"[time]": -1.0})
)
assert a in ureg._contexts["sp"].funcs
assert b in ureg._contexts["sp"].funcs
assert a in internal(ureg)._contexts["sp"].funcs
assert b in internal(ureg)._contexts["sp"].funcs
with ureg.context("sp"):
assert a in ureg._active_ctx
assert b in ureg._active_ctx
assert a in internal(ureg)._active_ctx
assert b in internal(ureg)._active_ctx

def test_spectroscopy(self, class_registry):
ureg = class_registry
Expand All @@ -681,7 +684,7 @@ def test_spectroscopy(self, class_registry):
da, db = Context.__keytransform__(
a.dimensionality, b.dimensionality
)
p = find_shortest_path(ureg._active_ctx.graph, da, db)
p = find_shortest_path(internal(ureg)._active_ctx.graph, da, db)
assert p
msg = f"{a} <-> {b}"
# assertAlmostEqualRelError converts second to first
Expand All @@ -703,7 +706,7 @@ def test_textile(self, class_registry):
a = qty_direct.to_base_units()
b = qty_indirect.to_base_units()
da, db = Context.__keytransform__(a.dimensionality, b.dimensionality)
p = find_shortest_path(ureg._active_ctx.graph, da, db)
p = find_shortest_path(internal(ureg)._active_ctx.graph, da, db)
assert p
msg = f"{a} <-> {b}"
helpers.assert_quantity_almost_equal(b, a, rtol=0.01, msg=msg)
Expand Down
29 changes: 17 additions & 12 deletions pint/testsuite/test_diskcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
FS_SLEEP = 0.010


from .helpers import internal


@pytest.fixture
def float_cache_filename(tmp_path):
ureg = pint.UnitRegistry(cache_folder=tmp_path / "cache_with_float")
assert ureg._diskcache
assert ureg._diskcache.cache_folder
assert internal(ureg)._diskcache
assert internal(ureg)._diskcache.cache_folder

return tuple(ureg._diskcache.cache_folder.glob("*.pickle"))
return tuple(internal(ureg)._diskcache.cache_folder.glob("*.pickle"))


def test_must_be_three_files(float_cache_filename):
Expand All @@ -30,19 +33,19 @@ def test_must_be_three_files(float_cache_filename):

def test_no_cache():
ureg = pint.UnitRegistry(cache_folder=None)
assert ureg._diskcache is None
assert internal(ureg)._diskcache is None
assert ureg.cache_folder is None


def test_decimal(tmp_path, float_cache_filename):
ureg = pint.UnitRegistry(
cache_folder=tmp_path / "cache_with_decimal", non_int_type=decimal.Decimal
)
assert ureg._diskcache
assert ureg._diskcache.cache_folder == tmp_path / "cache_with_decimal"
assert internal(ureg)._diskcache
assert internal(ureg)._diskcache.cache_folder == tmp_path / "cache_with_decimal"
assert ureg.cache_folder == tmp_path / "cache_with_decimal"

files = tuple(ureg._diskcache.cache_folder.glob("*.pickle"))
files = tuple(internal(ureg)._diskcache.cache_folder.glob("*.pickle"))
assert len(files) == 3

# check that the filenames with decimal are different to the ones with float
Expand All @@ -66,9 +69,11 @@ def test_auto(float_cache_filename):
float_filenames = tuple(p.name for p in float_cache_filename)

ureg = pint.UnitRegistry(cache_folder=":auto:")
assert ureg._diskcache
assert ureg._diskcache.cache_folder
auto_files = tuple(p.name for p in ureg._diskcache.cache_folder.glob("*.pickle"))
assert internal(ureg)._diskcache
assert internal(ureg)._diskcache.cache_folder
auto_files = tuple(
p.name for p in internal(ureg)._diskcache.cache_folder.glob("*.pickle")
)
for file in float_filenames:
assert file in auto_files

Expand All @@ -82,7 +87,7 @@ def test_change_file(tmp_path):
# (this will create two cache files, one for the file another for RegistryCache)
ureg = pint.UnitRegistry(dfile, cache_folder=tmp_path)
assert ureg.x == 1234
files = tuple(ureg._diskcache.cache_folder.glob("*.pickle"))
files = tuple(internal(ureg)._diskcache.cache_folder.glob("*.pickle"))
assert len(files) == 2

# Modify the definition file
Expand All @@ -93,5 +98,5 @@ def test_change_file(tmp_path):
# Verify that the definiton file was loaded (the cache was invalidated).
ureg = pint.UnitRegistry(dfile, cache_folder=tmp_path)
assert ureg.x == 1235
files = tuple(ureg._diskcache.cache_folder.glob("*.pickle"))
files = tuple(internal(ureg)._diskcache.cache_folder.glob("*.pickle"))
assert len(files) == 4
Loading

0 comments on commit 98fbda4

Please sign in to comment.