diff --git a/CHANGES.rst b/CHANGES.rst
index cbe5a09b0..26cb1b935 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -17,6 +17,9 @@
 - Add support for ASDF Standard 1.5.0, which includes several new
   transform schemas. [#776]
 
+- Enable validation and serialization of previously unhandled numpy
+  scalar types. [#778]
+
 2.5.2 (2020-02-28)
 ------------------
 
diff --git a/asdf/compat/numpycompat.py b/asdf/compat/numpycompat.py
index 79e38320d..45e679777 100644
--- a/asdf/compat/numpycompat.py
+++ b/asdf/compat/numpycompat.py
@@ -1,7 +1,8 @@
 from ..util import minversion
 
 
-__all__ = ['NUMPY_LT_1_7']
+__all__ = ['NUMPY_LT_1_7', 'NUMPY_LT_1_14']
 
 NUMPY_LT_1_7 = not minversion('numpy', '1.7.0')
+NUMPY_LT_1_14 = not minversion('numpy', '1.14.0')
 
diff --git a/asdf/schema.py b/asdf/schema.py
index 6665620c7..baa350b7a 100644
--- a/asdf/schema.py
+++ b/asdf/schema.py
@@ -16,6 +16,7 @@
 from jsonschema.exceptions import ValidationError
 
 import yaml
+import numpy as np
 
 from . import constants
 from . import generic_io
@@ -213,10 +214,11 @@ def _create_validator(validators=YAML_VALIDATORS):
     if JSONSCHEMA_LT_3:
         base_cls = mvalidators.create(meta_schema=meta_schema, validators=validators)
     else:
-        type_checker = mvalidators.Draft4Validator.TYPE_CHECKER.redefine(
-            'array',
-            lambda checker, instance: isinstance(instance, list) or isinstance(instance, tuple)
-        )
+        type_checker = mvalidators.Draft4Validator.TYPE_CHECKER.redefine_many({
+            'array': lambda checker, instance: isinstance(instance, list) or isinstance(instance, tuple),
+            'integer': lambda checker, instance: not isinstance(instance, bool) and isinstance(instance, Integral),
+            'string': lambda checker, instance: isinstance(instance, (str, np.str_)),
+        })
         id_of = mvalidators.Draft4Validator.ID_OF
         base_cls = mvalidators.create(
             meta_schema=meta_schema,
@@ -229,6 +231,8 @@ class ASDFValidator(base_cls):
         if JSONSCHEMA_LT_3:
             DEFAULT_TYPES = base_cls.DEFAULT_TYPES.copy()
             DEFAULT_TYPES['array'] = (list, tuple)
+            DEFAULT_TYPES['integer'] = (Integral)
+            DEFAULT_TYPES['string'] = (str, np.str_)
 
         def iter_errors(self, instance, _schema=None, _seen=set()):
             # We can't validate anything that looks like an external reference,
diff --git a/asdf/tests/test_schema.py b/asdf/tests/test_schema.py
index 1a265efe1..98a2f9b26 100644
--- a/asdf/tests/test_schema.py
+++ b/asdf/tests/test_schema.py
@@ -878,3 +878,48 @@ def test_nonexistent_tag(tmpdir):
     assert str(w[0].message).startswith("Unable to locate schema file")
     assert str(w[1].message).startswith("Unable to locate schema file")
     assert str(w[2].message).startswith(af['a']._tag)
+
+
+@pytest.mark.parametrize("numpy_value,valid_types", [
+    (np.str_("foo"), {"string"}),
+    (np.bytes_("foo"), set()),
+    (np.float16(3.14), {"number"}),
+    (np.float32(3.14159), {"number"}),
+    (np.float64(3.14159), {"number"}),
+    # Evidently float128 is not available on Windows:
+    (getattr(np, "float128", np.float64)(3.14159), {"number"}),
+    (np.int8(42), {"number", "integer"}),
+    (np.int16(42), {"number", "integer"}),
+    (np.int32(42), {"number", "integer"}),
+    (np.longlong(42), {"number", "integer"}),
+    (np.uint8(42), {"number", "integer"}),
+    (np.uint16(42), {"number", "integer"}),
+    (np.uint32(42), {"number", "integer"}),
+    (np.uint64(42), {"number", "integer"}),
+    (np.ulonglong(42), {"number", "integer"}),
+])
+def test_numpy_scalar_type_validation(numpy_value, valid_types):
+    def _assert_validation(jsonschema_type, expected_valid):
+        validator = schema.get_validator()
+        try:
+            validator.validate(numpy_value, _schema={"type": jsonschema_type})
+        except ValidationError:
+            valid = False
+        else:
+            valid = True
+
+        if valid is not expected_valid:
+            if expected_valid:
+                description = "valid"
+            else:
+                description = "invalid"
+            assert False, "Expected numpy.{} to be {} against jsonschema type '{}'".format(
+                type(numpy_value).__name__, description, jsonschema_type
+            )
+
+    for jsonschema_type in valid_types:
+        _assert_validation(jsonschema_type, True)
+
+    invalid_types = {"string", "number", "integer", "boolean", "null", "object"} - valid_types
+    for jsonschema_type in invalid_types:
+        _assert_validation(jsonschema_type, False)
diff --git a/asdf/tests/test_yaml.py b/asdf/tests/test_yaml.py
index 05e728da9..cc7d4a0c9 100644
--- a/asdf/tests/test_yaml.py
+++ b/asdf/tests/test_yaml.py
@@ -14,6 +14,8 @@
 import asdf
 from asdf import tagged
 from asdf import treeutil
+from asdf import yamlutil
+from asdf.compat.numpycompat import NUMPY_LT_1_14
 
 from . import helpers
 
@@ -275,3 +277,36 @@ class SomeObject:
     tag = 'tag:nowhere.org:none/some/thing'
     instance = tagged.tag_object(tag, SomeObject())
     assert instance._tag == tag
+
+
+@pytest.mark.parametrize("numpy_value,expected_value", [
+    (np.str_("foo"), "foo"),
+    (np.bytes_("foo"), b"foo"),
+    (np.float16(3.14), 3.14),
+    (np.float32(3.14159), 3.14159),
+    (np.float64(3.14159), 3.14159),
+    # Evidently float128 is not available on Windows:
+    (getattr(np, "float128", np.float64)(3.14159), 3.14159),
+    (np.int8(42), 42),
+    (np.int16(42), 42),
+    (np.int32(42), 42),
+    (np.int64(42), 42),
+    (np.longlong(42), 42),
+    (np.uint8(42), 42),
+    (np.uint16(42), 42),
+    (np.uint32(42), 42),
+    (np.uint64(42), 42),
+    (np.ulonglong(42), 42),
+])
+def test_numpy_scalar(numpy_value, expected_value):
+    ctx = asdf.AsdfFile()
+    tree = {"value": numpy_value}
+    buffer = io.BytesIO()
+
+    yamlutil.dump_tree(tree, buffer, ctx)
+    buffer.seek(0)
+
+    if isinstance(expected_value, float) and NUMPY_LT_1_14:
+        assert yamlutil.load_tree(buffer, ctx)["value"] == pytest.approx(expected_value, rel=0.001)
+    else:
+        assert yamlutil.load_tree(buffer, ctx)["value"] == expected_value
diff --git a/asdf/yamlutil.py b/asdf/yamlutil.py
index 552b7852a..7248a06ea 100644
--- a/asdf/yamlutil.py
+++ b/asdf/yamlutil.py
@@ -218,6 +218,15 @@ def represent_ordereddict(dumper, data):
 for scalar_type in util.iter_subclasses(np.integer):
     AsdfDumper.add_representer(scalar_type, AsdfDumper.represent_int)
 
+def represent_numpy_str(dumper, data):
+    # The CSafeDumper implementation will raise an error if it
+    # doesn't recognize data as a string. The Python SafeDumper
+    # has no problem with np.str_.
+    return dumper.represent_str(str(data))
+
+AsdfDumper.add_representer(np.str_, represent_numpy_str)
+AsdfDumper.add_representer(np.bytes_, AsdfDumper.represent_binary)
+
 
 def custom_tree_to_tagged_tree(tree, ctx):
     """