diff --git a/asdf/__init__.py b/asdf/__init__.py index cc062f95c..1a4d14e64 100644 --- a/asdf/__init__.py +++ b/asdf/__init__.py @@ -15,7 +15,7 @@ # ---------------------------------------------------------------------------- if _ASDF_SETUP_ is False: - __all__ = ['AsdfFile', 'AsdfType', 'AsdfExtension', + __all__ = ['AsdfFile', 'CustomType', 'AsdfExtension', 'Stream', 'open', 'test', 'commands', 'ValidationError'] @@ -35,7 +35,7 @@ raise ImportError("asdf requires numpy") from .asdf import AsdfFile - from .asdftypes import AsdfType + from .asdftypes import CustomType from .extension import AsdfExtension from .stream import Stream from . import commands diff --git a/asdf/asdf.py b/asdf/asdf.py index 5123b0202..962fef47d 100644 --- a/asdf/asdf.py +++ b/asdf/asdf.py @@ -448,7 +448,6 @@ def _open_asdf(cls, self, fd, uri=None, mode='r', self.version = version yaml_token = fd.read(4) - yaml_content = b'' tree = {} has_blocks = False if yaml_token == b'%YAM': @@ -456,8 +455,11 @@ def _open_asdf(cls, self, fd, uri=None, mode='r', constants.YAML_END_MARKER_REGEX, 7, 'End of YAML marker', include=True, initial_content=yaml_token) + # For testing: just return the raw YAML content if _get_yaml_content: yaml_content = reader.read() + fd.close() + return yaml_content else: # We parse the YAML content into basic data structures # now, but we don't do anything special with it until @@ -469,11 +471,6 @@ def _open_asdf(cls, self, fd, uri=None, mode='r', elif yaml_token != b'': raise IOError("ASDF file appears to contain garbage after header.") - # For testing: just return the raw YAML content - if _get_yaml_content: - fd.close() - return yaml_content - if has_blocks: self._blocks.read_internal_blocks( fd, past_magic=True, validate_checksums=validate_checksums) diff --git a/asdf/asdftypes.py b/asdf/asdftypes.py index bbc791914..eda37ab7c 100644 --- a/asdf/asdftypes.py +++ b/asdf/asdftypes.py @@ -9,6 +9,7 @@ import re import six +from copy import copy from .compat import lru_cache @@ -16,7 +17,7 @@ from . import tagged from . import util -from . import versioning +from .versioning import get_version_map, version_to_string __all__ = ['format_tag', 'AsdfTypeIndex', 'AsdfType'] @@ -51,7 +52,7 @@ def join_tag_version(name, version): """ Join the root and version of a tag back together. """ - return '{0}-{1}'.format(name, versioning.version_to_string(version)) + return '{0}-{1}'.format(name, version_to_string(version)) class _AsdfWriteTypeIndex(object): @@ -106,7 +107,7 @@ def add_by_tag(name, version): add_by_tag(name, version) else: try: - version_map = versioning.get_version_map(version) + version_map = get_version_map(version) except ValueError: raise ValueError( "Don't know how to write out ASDF version {0}".format( @@ -160,6 +161,7 @@ def __init__(self): self._type_by_tag = {} self._versions_by_type_name = {} self._best_matches = {} + self._real_tag = {} self._unnamed_types = set() self._hooks_by_type = {} self._all_types = set() @@ -225,6 +227,8 @@ def fix_yaml_tag(self, tag): Raises a warning if it could not find a match where the major and minor numbers are the same. """ + warning_string = None + if tag in self._type_by_tag: return tag @@ -243,30 +247,32 @@ def fix_yaml_tag(self, tag): # The versions list is kept sorted, so bisect can be used to # quickly find the best option. - i = bisect.bisect_left(versions, version) i = max(0, i - 1) best_version = versions[i] - if best_version[:2] == version[:2]: - # Major and minor match, so only patch and devel differs - # -- no need for alarm - warning_string = None - else: - warning_string = ( - "'{0}' with version {1} found in file, but asdf only " - "understands version {2}.".format( + if best_version[:2] != version[:2]: + warning_string = \ + "'{}' with version {} found in file, but asdf only supports " \ + "version {}".format( name, semver.format_version(*version), - semver.format_version(*best_version))) - - if warning_string: + semver.format_version(*best_version)) warnings.warn(warning_string) best_tag = join_tag_version(name, best_version) self._best_matches[tag] = best_tag, warning_string + if tag != best_tag: + self._real_tag[best_tag] = tag return best_tag + def get_real_tag(self, tag): + if tag in self._real_tag: + return self._real_tag[tag] + elif tag in self._type_by_tag: + return tag + return None + def from_yaml_tag(self, tag): """ From a given YAML tag string, return the corresponding @@ -320,10 +326,9 @@ def _from_tree_tagged_missing_requirements(cls, tree, ctx): return tree -class AsdfTypeMeta(type): +class ExtensionTypeMeta(type): """ - Keeps track of `AsdfType` subclasses that are created, and stores - them in `AsdfTypeIndex`. + Custom class constructor for extension types. """ _import_cache = {} @@ -358,6 +363,10 @@ def _find_in_bases(cls, attrs, bases, name, default=None): return getattr(base, name) return default + @property + def versioned_siblings(mcls): + return getattr(mcls, '__versioned_siblings') or [] + def __new__(mcls, name, bases, attrs): requires = mcls._find_in_bases(attrs, bases, 'requires', []) if not mcls._has_required_modules(requires): @@ -375,7 +384,7 @@ def __new__(mcls, name, bases, attrs): new_types.append(typ) attrs['types'] = new_types - cls = super(AsdfTypeMeta, mcls).__new__(mcls, name, bases, attrs) + cls = super(ExtensionTypeMeta, mcls).__new__(mcls, name, bases, attrs) if hasattr(cls, 'name'): if isinstance(cls.name, six.string_types): @@ -386,14 +395,42 @@ def __new__(mcls, name, bases, attrs): elif cls.name is not None: raise TypeError("name must be string or list") + if hasattr(cls, 'supported_versions'): + if not isinstance(cls.supported_versions, (list, set)): + raise TypeError( + "supported_versions attribute must be list or set") + supported_versions = set() + for version in cls.supported_versions: + supported_versions.add(version_to_string(version)) + cls.supported_versions = supported_versions + siblings = list() + for version in cls.supported_versions: + if version != version_to_string(cls.version): + new_attrs = copy(attrs) + new_attrs['version'] = version + new_attrs['supported_versions'] = set() + siblings.append( + ExtensionTypeMeta. __new__(mcls, name, bases, new_attrs)) + setattr(cls, '__versioned_siblings', siblings) + + return cls + + +class AsdfTypeMeta(ExtensionTypeMeta): + """ + Keeps track of `AsdfType` subclasses that are created, and stores them in + `AsdfTypeIndex`. + """ + def __new__(mcls, name, bases, attrs): + cls = super(AsdfTypeMeta, mcls).__new__(mcls, name, bases, attrs) + # Classes using this metaclass get added to the list of built-in + # extensions _all_asdftypes.add(cls) return cls -@six.add_metaclass(AsdfTypeMeta) -@six.add_metaclass(util.InheritDocstrings) -class AsdfType(object): +class ExtensionType(object): """ The base class of all custom types in the tree. @@ -417,6 +454,11 @@ class AsdfType(object): version : 3-tuple of int The version of the standard the type is defined in. + supported_versions : set + If provided, indicates explicit compatibility with the given set of + versions. Other versions of the same schema that are not included in + this set will not be converted to custom types with this class. + yaml_tag : str The YAML tag to use for the type. If not provided, it will be automatically generated from name, organization, standard and @@ -439,6 +481,7 @@ class AsdfType(object): organization = 'stsci.edu' standard = 'asdf' version = (1, 0, 0) + supported_versions = set() types = [] handle_dynamic_subclasses = False validators = {} @@ -450,7 +493,7 @@ def make_yaml_tag(cls, name): return format_tag( cls.organization, cls.standard, - versioning.version_to_string(cls.version), + version_to_string(cls.version), name) @classmethod @@ -488,3 +531,37 @@ def from_tree_tagged(cls, tree, ctx): with the tag directly. """ return cls.from_tree(tree.data, ctx) + + @classmethod + def incompatible_version(cls, version): + """ + If this tag class explicitly identifies compatible versions then this + checks whether a given version is compatible or not. Otherwise, all + versions are assumed to be compatible. + + Child classes can override this method to affect how version + compatiblity for this type is determined. + """ + if cls.supported_versions: + if version_to_string(version) not in cls.supported_versions: + return True + return False + + +@six.add_metaclass(AsdfTypeMeta) +@six.add_metaclass(util.InheritDocstrings) +class AsdfType(ExtensionType): + """ + Base class for all built-in ASDF types. Types that inherit this class will + be automatically added to the list of built-ins. This should *not* be used + for user-defined extensions. + """ + +@six.add_metaclass(ExtensionTypeMeta) +@six.add_metaclass(util.InheritDocstrings) +class CustomType(ExtensionType): + """ + Base class for all user-defined types. Unlike classes that inherit + AsdfType, classes that inherit this class will *not* automatically be added + to the list of built-ins. This should be used for user-defined extensions. + """ diff --git a/asdf/extension.py b/asdf/extension.py index baefd1d9d..2f287370a 100644 --- a/asdf/extension.py +++ b/asdf/extension.py @@ -120,6 +120,9 @@ def __init__(self, extensions): for typ in extension.types: self._type_index.add_type(typ) validators.update(typ.validators) + for sibling in typ.versioned_siblings: + self._type_index.add_type(sibling) + validators.update(sibling.validators) self._tag_mapping = resolver.Resolver(tag_mapping, 'tag') self._url_mapping = resolver.Resolver(url_mapping, 'url') self._validators = validators diff --git a/asdf/tags/wcs/tests/test_wcs.py b/asdf/tags/wcs/tests/test_wcs.py index 5266fb94c..20f70feeb 100644 --- a/asdf/tags/wcs/tests/test_wcs.py +++ b/asdf/tags/wcs/tests/test_wcs.py @@ -17,6 +17,7 @@ from gwcs import coordinate_frames as cf from gwcs import wcs +from .... import AsdfFile from ....tests import helpers @@ -120,6 +121,7 @@ def create_test_frames(): return frames + def test_frames(tmpdir): tree = { @@ -127,3 +129,144 @@ def test_frames(tmpdir): } helpers.assert_roundtrip_tree(tree, tmpdir) + + +def test_backwards_compat_galcen(): + # Hold these fields constant so that we can compare them + declination = 1.0208 # in degrees + right_ascension = 45.729 # in degrees + galcen_distance = 3.14 + roll = 4.0 + z_sun = 0.2084 + old_frame_yaml = """ +frames: + - !wcs/celestial_frame-1.0.0 + axes_names: [x, y, z] + axes_order: [0, 1, 2] + name: CelestialFrame + reference_frame: + type: galactocentric + galcen_dec: + - %f + - deg + galcen_ra: + - %f + - deg + galcen_distance: + - %f + - m + roll: + - %f + - deg + z_sun: + - %f + - pc + unit: [!unit/unit-1.0.0 deg, !unit/unit-1.0.0 deg, !unit/unit-1.0.0 deg] +""" % (declination, right_ascension, galcen_distance, roll, z_sun) + + new_frame_yaml = """ +frames: + - !wcs/celestial_frame-1.1.0 + axes_names: [x, y, z] + axes_order: [0, 1, 2] + name: CelestialFrame + reference_frame: + type: galactocentric + galcen_coord: !wcs/icrs_coord-1.1.0 + dec: {value: %f} + ra: + value: %f + wrap_angle: + !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 deg, value: 360.0} + galcen_distance: + !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 m, value: %f} + galcen_v_sun: + - !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 km s-1, value: 11.1} + - !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 km s-1, value: 232.24} + - !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 km s-1, value: 7.25} + roll: !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 deg, value: %f} + z_sun: !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 pc, value: %f} + unit: [!unit/unit-1.0.0 deg, !unit/unit-1.0.0 deg, !unit/unit-1.0.0 deg] +""" % (declination, right_ascension, galcen_distance, roll, z_sun) + + old_buff = helpers.yaml_to_asdf(old_frame_yaml) + old_asdf = AsdfFile.open(old_buff) + old_frame = old_asdf.tree['frames'][0] + new_buff = helpers.yaml_to_asdf(new_frame_yaml) + new_asdf = AsdfFile.open(new_buff) + new_frame = new_asdf.tree['frames'][0] + + # Poor man's frame comparison since it's not implemented by astropy + assert old_frame.axes_names == new_frame.axes_names + assert old_frame.axes_order == new_frame.axes_order + assert old_frame.unit == new_frame.unit + + old_refframe = old_frame.reference_frame + new_refframe = new_frame.reference_frame + + # v1.0.0 frames have no representation of galcen_v_center, so do not compare + assert old_refframe.galcen_distance == new_refframe.galcen_distance + assert old_refframe.galcen_coord.dec == new_refframe.galcen_coord.dec + assert old_refframe.galcen_coord.ra == new_refframe.galcen_coord.ra + + +def test_backwards_compat_gcrs(): + obsgeoloc = ( + 3.0856775814671916e+16, + 9.257032744401574e+16, + 6.1713551629343834e+19 + ) + obsgeovel = (2.0, 1.0, 8.0) + + old_frame_yaml = """ +frames: + - !wcs/celestial_frame-1.0.0 + axes_names: [lon, lat] + name: CelestialFrame + reference_frame: + type: GCRS + obsgeoloc: + - [%f, %f, %f] + - !unit/unit-1.0.0 m + obsgeovel: + - [%f, %f, %f] + - !unit/unit-1.0.0 m s-1 + obstime: !time/time-1.0.0 2010-01-01 00:00:00.000 + unit: [!unit/unit-1.0.0 deg, !unit/unit-1.0.0 deg] +""" % (obsgeovel + obsgeoloc) + + new_frame_yaml = """ +frames: + - !wcs/celestial_frame-1.1.0 + axes_names: [lon, lat] + name: CelestialFrame + reference_frame: + type: GCRS + obsgeoloc: + - !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 m, value: %f} + - !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 m, value: %f} + - !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 m, value: %f} + obsgeovel: + - !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 m s-1, value: %f} + - !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 m s-1, value: %f} + - !unit/quantity-1.1.0 {unit: !unit/unit-1.0.0 m s-1, value: %f} + obstime: !time/time-1.0.0 2010-01-01 00:00:00.000 + unit: [!unit/unit-1.0.0 deg, !unit/unit-1.0.0 deg] +""" % (obsgeovel + obsgeoloc) + + old_buff = helpers.yaml_to_asdf(old_frame_yaml) + old_asdf = AsdfFile.open(old_buff) + old_frame = old_asdf.tree['frames'][0] + old_loc = old_frame.reference_frame.obsgeoloc + old_vel = old_frame.reference_frame.obsgeovel + + new_buff = helpers.yaml_to_asdf(new_frame_yaml) + new_asdf = AsdfFile.open(new_buff) + new_frame = new_asdf.tree['frames'][0] + new_loc = new_frame.reference_frame.obsgeoloc + new_vel = new_frame.reference_frame.obsgeovel + + assert (old_loc.x == new_loc.x and old_loc.y == new_loc.y and + old_loc.z == new_loc.z) + assert (old_vel.x == new_vel.x and old_vel.y == new_vel.y and + old_vel.z == new_vel.z) diff --git a/asdf/tags/wcs/wcs.py b/asdf/tags/wcs/wcs.py index 7bc38f6fb..1bee2c50d 100644 --- a/asdf/tags/wcs/wcs.py +++ b/asdf/tags/wcs/wcs.py @@ -9,6 +9,7 @@ from ...asdftypes import AsdfType from ... import yamlutil +from ...versioning import version_to_string @@ -112,44 +113,71 @@ def _get_inverse_reference_frame_mapping(cls): return cls._inverse_reference_frame_mapping @classmethod - def _from_tree(cls, node, ctx): + def _reference_frame_from_tree(cls, node, ctx): from ..unit import QuantityType - from astropy.coordinates import (CartesianRepresentation, + from astropy.units import Quantity + from astropy.coordinates import (ICRS, CartesianRepresentation, CartesianDifferential) - kwargs = {} + version = version_to_string(cls.version) + reference_frame = node['reference_frame'] + reference_frame_name = reference_frame['type'] + + frame_cls = cls._get_reference_frame_mapping()[reference_frame_name] + + frame_kwargs = {} + for name in frame_cls.get_frame_attr_names().keys(): + val = reference_frame.get(name) + if val is not None: + # These are deprecated fields that must be handled as a special + # case for older versions of the schema + if name in ['galcen_ra', 'galcen_dec']: + continue + # There was no schema for quantities in v1.0.0 + if name in ['galcen_distance', 'roll', 'z_sun'] and version == '1.0.0': + val = Quantity(val[0], unit=val[1]) + # These fields are known to be CartesianRepresentations + if name in ['obsgeoloc', 'obsgeovel']: + if version == '1.0.0': + unit = val[1] + x = Quantity(val[0][0], unit=unit) + y = Quantity(val[0][1], unit=unit) + z = Quantity(val[0][2], unit=unit) + else: + x = QuantityType.from_tree(val[0], ctx) + y = QuantityType.from_tree(val[1], ctx) + z = QuantityType.from_tree(val[2], ctx) + val = CartesianRepresentation(x, y, z) + elif name == 'galcen_v_sun': + # This field only exists since v1.1.0 + d_x = QuantityType.from_tree(val[0], ctx) + d_y = QuantityType.from_tree(val[1], ctx) + d_z = QuantityType.from_tree(val[2], ctx) + val = CartesianDifferential(d_x, d_y, d_z) + else: + val = yamlutil.tagged_tree_to_custom_tree(val, ctx) + frame_kwargs[name] = val + has_ra_and_dec = reference_frame.get('galcen_dec') and \ + reference_frame.get('galcen_ra') + if version == '1.0.0' and has_ra_and_dec: + # Convert deprecated ra and dec fields into galcen_coord + galcen_dec = reference_frame['galcen_dec'] + galcen_ra = reference_frame['galcen_ra'] + dec = Quantity(galcen_dec[0], unit=galcen_dec[1]) + ra = Quantity(galcen_ra[0], unit=galcen_ra[1]) + frame_kwargs['galcen_coord'] = ICRS(dec=dec, ra=ra) + return frame_cls(**frame_kwargs) - kwargs['name'] = node['name'] + @classmethod + def _from_tree(cls, node, ctx): + kwargs = {'name': node['name']} if 'axes_names' in node: kwargs['axes_names'] = node['axes_names'] if 'reference_frame' in node: - reference_frame = node['reference_frame'] - reference_frame_name = reference_frame['type'] - - frame_cls = cls._get_reference_frame_mapping()[reference_frame_name] - - frame_kwargs = {} - for name in frame_cls.get_frame_attr_names().keys(): - val = reference_frame.get(name) - if val is not None: - # These fields are known to be CartesianRepresentations - if name in ['obsgeoloc', 'obsgeovel']: - x = QuantityType.from_tree(val[0], ctx) - y = QuantityType.from_tree(val[1], ctx) - z = QuantityType.from_tree(val[2], ctx) - val = CartesianRepresentation(x, y, z) - elif name == 'galcen_v_sun': - d_x = QuantityType.from_tree(val[0], ctx) - d_y = QuantityType.from_tree(val[1], ctx) - d_z = QuantityType.from_tree(val[2], ctx) - val = CartesianDifferential(d_x, d_y, d_z) - else: - val = yamlutil.tagged_tree_to_custom_tree(val, ctx) - frame_kwargs[name] = val - - kwargs['reference_frame'] = frame_cls(**frame_kwargs) + kwargs['reference_frame'] = \ + cls._reference_frame_from_tree(node, ctx) if 'axes_order' in node: kwargs['axes_order'] = tuple(node['axes_order']) @@ -235,10 +263,10 @@ def from_tree(cls, node, ctx): def to_tree(cls, frame, ctx): return cls._to_tree(frame, ctx) - class CelestialFrameType(FrameType): name = "wcs/celestial_frame" types = ['gwcs.CelestialFrame'] + supported_versions = [(1,0,0), (1,1,0)] @classmethod def from_tree(cls, node, ctx): diff --git a/asdf/tests/data/custom_flow-1.1.0.yaml b/asdf/tests/data/custom_flow-1.1.0.yaml new file mode 100644 index 000000000..c932dcb46 --- /dev/null +++ b/asdf/tests/data/custom_flow-1.1.0.yaml @@ -0,0 +1,11 @@ +%YAML 1.1 +--- +$schema: "http://stsci.edu/schemas/yaml-schema/draft-01" +id: "http://nowhere.org/schemas/custom/custom_flow-1.1.0" +type: object +properties: + c: + type: number + d: + type: number +flowStyle: block diff --git a/asdf/tests/test_asdftypes.py b/asdf/tests/test_asdftypes.py index 1355b451a..8f54c8d3a 100644 --- a/asdf/tests/test_asdftypes.py +++ b/asdf/tests/test_asdftypes.py @@ -103,7 +103,7 @@ def test_version_mismatch(): assert len(w) == 1 assert str(w[0].message) == ( "'tag:stsci.edu:asdf/core/complex' with version 42.0.0 found in file, " - "but asdf only understands version 1.0.0.") + "but asdf only supports version 1.0.0") # Make sure warning is repeatable buff.seek(0) @@ -114,7 +114,7 @@ def test_version_mismatch(): assert len(w) == 1 assert str(w[0].message) == ( "'tag:stsci.edu:asdf/core/complex' with version 42.0.0 found in file, " - "but asdf only understands version 1.0.0.") + "but asdf only supports version 1.0.0") # If the major and minor match, there should be no warning. yaml = """ @@ -220,3 +220,174 @@ class DoesntHaveCorrectPytest(asdftypes.AsdfType): assert nmt.has_required_modules == False assert hcp.has_required_modules == True assert dhcp.has_required_modules == False + + +def test_undefined_tag(): + # This tests makes sure that ASDF still returns meaningful structured data + # even when it encounters a schema tag that it does not specifically + # implement as an extension + from numpy import array + + yaml = """ +undefined_data: + ! + - 5 + - {'message': 'there is no tag'} + - !core/ndarray-1.0.0 + [[1, 2, 3], [4, 5, 6]] + - ! + - !core/ndarray-1.0.0 [[7],[8],[9],[10]] + - !core/complex-1.0.0 3.14j +""" + buff = helpers.yaml_to_asdf(yaml) + afile = asdf.AsdfFile.open(buff) + missing = afile.tree['undefined_data'] + + assert missing[0] == 5 + assert missing[1] == {'message': 'there is no tag'} + assert (missing[2] == array([[1, 2, 3], [4, 5, 6]])).all() + assert (missing[3][0] == array([[7],[8],[9],[10]])).all() + assert missing[3][1] == 3.14j + + +def test_newer_tag(): + from astropy.tests.helper import catch_warnings + # This test simulates a scenario where newer versions of CustomFlow + # provides different keyword parameters that the older schema and tag class + # do not account for. We want to test whether ASDF can handle this problem + # gracefully and still provide meaningful data as output. The test case is + # fairly contrived but we want to test whether ASDF can handle backwards + # compatibility even when an explicit tag class for different versions of a + # schema is not available. + class CustomFlow(object): + def __init__(self, c=None, d=None): + self.c = c + self.d = d + + class CustomFlowType(asdftypes.CustomType): + version = '1.1.0' + name = 'custom_flow' + organization = 'nowhere.org' + standard = 'custom' + types = [CustomFlow] + + @classmethod + def from_tree(cls, tree, ctx): + kwargs = {} + for name in tree: + kwargs[name] = tree[name] + return CustomFlow(**kwargs) + + @classmethod + def to_tree(cls, data, ctx): + tree = dict(c=data.c, d=data.d) + + class CustomFlowExtension(object): + @property + def types(self): + return [CustomFlowType] + + @property + def tag_mapping(self): + return [('tag:nowhere.org:custom', + 'http://nowhere.org/schemas/custom{tag_suffix}')] + + @property + def url_mapping(self): + return [('http://nowhere.org/schemas/custom/', + util.filepath_to_url(TEST_DATA_PATH) + + '/{url_suffix}.yaml')] + + new_yaml = """ +flow_thing: + ! + c: 100 + d: 3.14 +""" + new_buff = helpers.yaml_to_asdf(new_yaml) + new_data = asdf.AsdfFile.open(new_buff, extensions=CustomFlowExtension()) + assert type(new_data.tree['flow_thing']) == CustomFlow + + old_yaml = """ +flow_thing: + ! + a: 100 + b: 3.14 +""" + old_buff = helpers.yaml_to_asdf(old_yaml) + with catch_warnings() as w: + asdf.AsdfFile.open(old_buff, extensions=CustomFlowExtension()) + + assert len(w) == 1 + assert str(w[0].message) == ( + "'tag:nowhere.org:custom/custom_flow' with version 1.0.0 found " + "in file, but asdf only supports version 1.1.0") + + +def test_supported_versions(): + from astropy.tests.helper import catch_warnings + class CustomFlow(object): + def __init__(self, c=None, d=None): + self.c = c + self.d = d + + class CustomFlowType(asdftypes.CustomType): + version = '1.1.0' + supported_versions = [(1,0,0), (1,1,0)] + name = 'custom_flow' + organization = 'nowhere.org' + standard = 'custom' + types = [CustomFlow] + + @classmethod + def from_tree(cls, tree, ctx): + # Convert old schema to new CustomFlow type + if versioning.version_to_string(cls.version) == '1.0.0': + return CustomFlow(c=tree['a'], d=tree['b']) + else: + return CustomFlow(**tree) + return CustomFlow(**kwargs) + + @classmethod + def to_tree(cls, data, ctx): + if versioning.version_to_string(cls.version) == '1.0.0': + tree = dict(a=data.c, b=data.d) + else: + tree = dict(c=data.c, d=data.d) + + class CustomFlowExtension(object): + @property + def types(self): + return [CustomFlowType] + + @property + def tag_mapping(self): + return [('tag:nowhere.org:custom', + 'http://nowhere.org/schemas/custom{tag_suffix}')] + + @property + def url_mapping(self): + return [('http://nowhere.org/schemas/custom/', + util.filepath_to_url(TEST_DATA_PATH) + + '/{url_suffix}.yaml')] + + new_yaml = """ +flow_thing: + ! + c: 100 + d: 3.14 +""" + old_yaml = """ +flow_thing: + ! + a: 100 + b: 3.14 +""" + new_buff = helpers.yaml_to_asdf(new_yaml) + new_data = asdf.AsdfFile.open(new_buff, extensions=CustomFlowExtension()) + assert type(new_data.tree['flow_thing']) == CustomFlow + + old_buff = helpers.yaml_to_asdf(old_yaml) + with catch_warnings() as w: + old_data = asdf.AsdfFile.open(old_buff, extensions=CustomFlowExtension()) + assert type(old_data.tree['flow_thing']) == CustomFlow diff --git a/asdf/yamlutil.py b/asdf/yamlutil.py index 251f8cb1d..009502e21 100644 --- a/asdf/yamlutil.py +++ b/asdf/yamlutil.py @@ -6,14 +6,16 @@ import numpy as np import six - import yaml +import warnings from . compat.odict import OrderedDict from . constants import YAML_TAG_PREFIX from . import schema from . import tagged from . import treeutil +from . import asdftypes +from . import versioning from . import util @@ -243,12 +245,25 @@ def tagged_tree_to_custom_tree(tree, ctx): Convert a tree containing only basic data types, annotated with tags, to a tree containing custom data types. """ + def walker(node): tag_name = getattr(node, '_tag', None) if tag_name is not None: tag_type = ctx.type_index.from_yaml_tag(tag_name) if tag_type is not None: - return tag_type.from_tree_tagged(node, ctx) + real_tag = ctx.type_index.get_real_tag(tag_name) + _, real_tag_version = asdftypes.split_tag_version(real_tag) + if not tag_type.incompatible_version(real_tag_version): + # If a tag class does not explicitly list compatible + # versions, then all versions of the corresponding schema + # are assumed to be compatible. Therefore we need to check + # to make sure whether the conversion is actually + # successful, and just return a raw Python data type if it + # is not. + try: + return tag_type.from_tree_tagged(node, ctx) + except TypeError: + pass return node return treeutil.walk_and_modify(tree, walker) diff --git a/docs/asdf/extensions.rst b/docs/asdf/extensions.rst index 3f2e0ed65..280378b93 100644 --- a/docs/asdf/extensions.rst +++ b/docs/asdf/extensions.rst @@ -5,8 +5,8 @@ Supporting new types in asdf is easy. There are three pieces needed: 1. A YAML Schema file for each new type. -2. A Python class (inheriting from `asdf.AsdfType`) for each new - type. +2. A Python class (inheriting from `asdf.CustomType`) for each new + user-defined type. 3. A Python class to define an "extension" to ASDF, which is a set of related types. This class must implement the @@ -33,7 +33,7 @@ First, the YAML Schema, defining the type as a pair of integers: maxItems: 2 ... -Then, the Python implementation. See the `asdf.AsdfType` and +Then, the Python implementation. See the `asdf.CustomType` and `asdf.AsdfExtension` documentation for more information:: import os @@ -43,7 +43,7 @@ Then, the Python implementation. See the `asdf.AsdfType` and import fractions - class FractionType(asdf.AsdfType): + class FractionType(asdf.CustomType): name = 'fraction' organization = 'nowhere.org' version = (1, 0, 0) @@ -83,7 +83,7 @@ values in an ASDF file. This feature is used internally so a schema can specify the required datatype of an array. To support custom validation keywords, set the ``validators`` member -of an ``AsdfType`` subclass to a dictionary where the keys are the +of a ``CustomType`` subclass to a dictionary where the keys are the validation keyword name and the values are validation functions. The validation functions are of the same form as the validation functions in the underlying ``jsonschema`` library, and are passed the following