Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix up pint serialisation #9

Merged
merged 7 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@ version: 2
build:
os: ubuntu-22.04
tools:
python: "3.9"
python: "3.11"
jobs:
post_create_environment:
- pip install poetry
- poetry config virtualenvs.create false
post_install:
- poetry install --with docs --all-extras
# Not sure why this is needed.
# RtD seems to be not happy with poetry installs.
- poetry export -f requirements.txt --output requirements.txt --with docs
- python -m pip install -r requirements.txt
- python -m pip install .
- python -m pip list

# Build documentation in the docs/ directory with Sphinx
# Set sphinx configuration
sphinx:
configuration: docs/source/conf.py
4 changes: 4 additions & 0 deletions changelog/9.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fixed handling of serialisation of Pint quantities.

Previously, they were mistakenly being identified as iterable,
which was causing things to explode.
6 changes: 6 additions & 0 deletions docs/source/api/pydoit_nb.config_handling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ insert\_path\_prefix
.. autofunction:: insert_path_prefix


iterable\_values\_are\_updatable
================================

.. autofunction:: iterable_values_are_updatable


update\_attr\_value
===================

Expand Down
37 changes: 36 additions & 1 deletion src/pydoit_nb/config_handling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tools for working with configuration
"""

from __future__ import annotations

from collections.abc import Iterable
Expand All @@ -14,6 +15,13 @@
from pydoit_nb.serialization import write_config_in_config_bundle_to_disk
from pydoit_nb.typing import ConfigBundleCreator, ConfigBundleLike, Converter, NotebookConfigLike

try:
import pint

HAS_PINT = True
except ImportError: # pragma: no cover
HAS_PINT = False

T = TypeVar("T")


Expand Down Expand Up @@ -48,7 +56,7 @@ def insert_path_prefix(config: AI, prefix: Path) -> AI:
update_attr_value(k, prefix): update_attr_value(v, prefix) for k, v in attr_value.items()
}

elif not isinstance(attr_value, (str, np.ndarray)) and isinstance(attr_value, Iterable):
elif isinstance(attr_value, Iterable) and iterable_values_are_updatable(attr_value):
evolutions[attr_name] = [update_attr_value(v, prefix) for v in attr_value]

else:
Expand All @@ -57,6 +65,33 @@ def insert_path_prefix(config: AI, prefix: Path) -> AI:
return evolve(config, **evolutions) # type: ignore # no idea why this fails


# TODO: test this by testing that a value
# which has a pint quantity as an attribute
# doesn't cause insert_path_prefix to explode.
def iterable_values_are_updatable(value: Iterable[Any]) -> bool:
    """
    Check whether an iterable's elements should be updated by :func:`insert_path_prefix`.

    Strings, numpy arrays and (when pint is installed) pint quantities
    are treated as single values rather than as containers,
    so their elements are not updated one by one.

    Parameters
    ----------
    value
        Value to check.

    Returns
    -------
        ``False`` if ``value`` is a string, a numpy array or
        (when pint is available) a pint quantity, ``True`` otherwise.
    """
    # Types that must be left untouched rather than walked element by element.
    non_updatable_types = (str, np.ndarray)
    if HAS_PINT:
        # Only reference pint when the optional import succeeded.
        non_updatable_types = (*non_updatable_types, pint.UnitRegistry.Quantity)

    return not isinstance(value, non_updatable_types)


@overload
def update_attr_value(value: AttrsInstance, prefix: Path) -> AttrsInstance:
... # pragma: no cover
Expand Down
20 changes: 18 additions & 2 deletions src/pydoit_nb/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
the general case, although specific use cases should be far more tractable and
easy to test). If you'd like to discuss this more, please raise an issue.
"""

from __future__ import annotations

import warnings
from collections.abc import Sequence
from pathlib import Path
from typing import Any, TypeVar, Union, cast
Expand Down Expand Up @@ -243,7 +245,7 @@ def unstructure_pint(inp: pint.UnitRegistry.Quantity) -> UnstructuredPint:
if _is_np_scalar(type(inp.magnitude)):
return (unstructure_np_scalar(inp.magnitude), str(inp.units))

if isinstance(inp.magnitude, float):
if isinstance(inp.magnitude, (float, int)):
return (inp.magnitude, str(inp.units))

return (unstructure_np_array(inp.magnitude), str(inp.units))
Expand All @@ -257,7 +259,9 @@ def structure_pint(
Parameters
----------
inp
Unstructured data
Unstructured data. If this is a string containing a slash,
we try and convert it to a fraction but this isn't super safe
so we also raise a warning.

target_type
Type to create
Expand All @@ -269,6 +273,18 @@ def structure_pint(
# pint not playing nice with mypy
ur = pint.get_application_registry() # type: ignore

if isinstance(inp[0], str) and "/" in inp[0]:
msg = (
f"Received {inp[0]=}. "
"We are assuming that this is meant to be interpreted as a float64. "
"It would be safer to put a decimal value into your config, "
"or make a merge request to pydoit-nb to make this handling safer."
)
warnings.warn(msg)
toks = inp[0].split("/")
mag = np.float64(toks[0]) / float(toks[1])
return ur.Quantity(mag, inp[1]) # type: ignore

# Can't do dtype control until pint allows it again with e.g.
# pint.Quantity[np.array[np.float64]]
return ur.Quantity(np.array(inp[0]), inp[1]) # type: ignore
Expand Down
14 changes: 11 additions & 3 deletions tests/unit/test_config_handling.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""
Test config_handling
"""

from __future__ import annotations

import re
from functools import partial
from pathlib import Path

import pint
import pytest
from attrs import define

Expand All @@ -18,6 +20,8 @@
)
from pydoit_nb.serialization import converter_yaml, load_config_from_file

Q = pint.get_application_registry().Quantity


@define
class StepConfigA:
Expand Down Expand Up @@ -48,6 +52,7 @@ class ConfigA:
@define
class ConfigB:
step_c: list[StepConfigC]
pint_value: pint.UnitRegistry.Quantity


@define
Expand Down Expand Up @@ -159,7 +164,8 @@ def test_get_config_for_step_id_attribute_error():
output=Path("location") / "somewhere" / "else.txt",
config={"a": "b"},
),
]
],
Q(1, "kg"),
),
ConfigB(
[
Expand All @@ -175,7 +181,8 @@ def test_get_config_for_step_id_attribute_error():
output=Path("/some/prefix") / Path("location") / "somewhere" / "else.txt",
config={"a": "b"},
),
]
],
Q(1, "kg"),
),
Path("/some/prefix"),
id="nested_attrs_object",
Expand Down Expand Up @@ -221,7 +228,8 @@ def create_cb(
output=Path("to") / "somewhere.txt",
config={"something": "here"},
)
]
],
pint_value=Q(32, "K"),
)

with open(configuration_file, "w") as fh:
Expand Down
70 changes: 55 additions & 15 deletions tests/unit/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Test the serialization module
"""

from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -158,33 +159,72 @@ class ConfigBundle:
# names and dtypes, multi-dimensional, multi-index etc.
),
)
def test_structure_non_primatives(inp, exp, restructure_type):
def test_roundtrip_non_primatives(inp, exp, restructure_type):
res = converter_yaml.dumps(inp)
assert res == exp

roundtrip = converter_yaml.loads(res, restructure_type)
assert_roundtrip_success(roundtrip, inp)
assert_serialisation_success(roundtrip, inp)


@pytest.mark.parametrize(
"inp, exp, restructure_type",
(
pytest.param(
"key:\n- 30.333\n- kilogram\n",
{"key": UR.Quantity(np.float64(30.333), "kg")},
dict[str, UR.Quantity],
id="pint_float64_scalar",
),
pytest.param(
"key:\n- 30 / 400\n- kilogram\n",
{"key": UR.Quantity(np.float64(30 / 400), "kg")},
dict[str, UR.Quantity],
id="pint_float64_scalar_with_value_provided_as_fraction",
),
),
)
def test_structure_non_primatives(inp, exp, restructure_type):
    """
    Check that serialised input is structured into the expected value.

    ``inp`` is loaded with ``converter_yaml`` as ``restructure_type``
    and the result is compared against ``exp``.
    """
    structured = converter_yaml.loads(inp, restructure_type)
    assert_serialisation_success(structured, exp)

def assert_roundtrip_success(roundtrip_res, inp):
if isinstance(inp, dict):
for k, value in inp.items():
assert_roundtrip_success(roundtrip_res[k], value)

def assert_serialisation_success(res, exp):
"""
Assert that serialisation was a success

Parameters
----------
res
The result

exp
The expected result

Raises
------
AssertionError
The serialisation was not a success
"""
if isinstance(exp, dict):
for k, value in exp.items():
assert_serialisation_success(res[k], value)

return

if isinstance(inp, np.ndarray):
nptesting.assert_equal(inp, roundtrip_res)
assert inp.dtype == roundtrip_res.dtype
if isinstance(exp, np.ndarray):
nptesting.assert_equal(exp, res)
assert exp.dtype == res.dtype
return

if isinstance(inp, pint.UnitRegistry.Quantity):
pinttesting.assert_equal(inp, roundtrip_res)
if hasattr(inp.m, "dtype"):
assert inp.m.dtype == roundtrip_res.m.dtype
if isinstance(exp, pint.UnitRegistry.Quantity):
pinttesting.assert_equal(exp, res)
if hasattr(exp.m, "dtype"):
assert exp.m.dtype == res.m.dtype
else:
isinstance(inp.m, roundtrip_res.m)
isinstance(exp.m, res.m)

return

assert roundtrip_res == inp
assert res == exp
Loading