Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix resolution of extension classes that have references #1183

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
### Bug fixes
- Fixed issue where scalar datasets with a compound data type were being written as non-scalar datasets @stephprince [#1176](https://github.com/hdmf-dev/hdmf/pull/1176)
- Fixed H5DataIO not exposing `maxshape` on non-dci dsets. @cboulay [#1149](https://github.com/hdmf-dev/hdmf/pull/1149)
- Fixed generation of classes in an extension that contain attributes or datasets storing references to other types defined in the extension.
@rly [#1183](https://github.com/hdmf-dev/hdmf/pull/1183)

## HDMF 3.14.3 (July 29, 2024)

Expand Down
17 changes: 15 additions & 2 deletions src/hdmf/build/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .classgenerator import ClassGenerator, CustomClassGenerator, MCIClassGenerator
from ..container import AbstractContainer, Container, Data
from ..term_set import TypeConfigurator
from ..spec import DatasetSpec, GroupSpec, NamespaceCatalog
from ..spec import DatasetSpec, GroupSpec, NamespaceCatalog, RefSpec
from ..spec.spec import BaseStorageSpec
from ..utils import docval, getargs, ExtenderMeta, get_docval

Expand Down Expand Up @@ -480,6 +480,7 @@ def load_namespaces(self, **kwargs):
load_namespaces here has the advantage of being able to keep track of type dependencies across namespaces.
'''
deps = self.__ns_catalog.load_namespaces(**kwargs)
# register container types for each dependent type in each dependent namespace
for new_ns, ns_deps in deps.items():
for src_ns, types in ns_deps.items():
for dt in types:
Expand Down Expand Up @@ -529,7 +530,7 @@ def get_dt_container_cls(self, **kwargs):
namespace = ns_key
break
if namespace is None:
raise ValueError("Namespace could not be resolved.")
raise ValueError(f"Namespace could not be resolved for data type '{data_type}'.")

cls = self.__get_container_cls(namespace, data_type)

Expand All @@ -549,6 +550,8 @@ def get_dt_container_cls(self, **kwargs):

def __check_dependent_types(self, spec, namespace):
"""Ensure that classes for all types used by this type exist in this namespace and generate them if not.

`spec` should be a GroupSpec or DatasetSpec in the `namespace`
"""
def __check_dependent_types_helper(spec, namespace):
if isinstance(spec, (GroupSpec, DatasetSpec)):
Expand All @@ -564,6 +567,16 @@ def __check_dependent_types_helper(spec, namespace):

if spec.data_type_inc is not None:
self.get_dt_container_cls(spec.data_type_inc, namespace)

# handle attributes that have a reference dtype
for attr_spec in spec.attributes:
if isinstance(attr_spec.dtype, RefSpec):
self.get_dt_container_cls(attr_spec.dtype.target_type, namespace)
# handle datasets that have a reference dtype
if isinstance(spec, DatasetSpec):
if isinstance(spec.dtype, RefSpec):
self.get_dt_container_cls(spec.dtype.target_type, namespace)
# recurse into nested types
if isinstance(spec, GroupSpec):
for child_spec in (spec.groups + spec.datasets + spec.links):
__check_dependent_types_helper(child_spec, namespace)
Expand Down
180 changes: 178 additions & 2 deletions tests/unit/build_tests/test_classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from hdmf.build import TypeMap, CustomClassGenerator
from hdmf.build.classgenerator import ClassGenerator, MCIClassGenerator
from hdmf.container import Container, Data, MultiContainerInterface, AbstractContainer
from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, LinkSpec
from hdmf.spec import (
GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, LinkSpec, RefSpec
)
from hdmf.testing import TestCase
from hdmf.utils import get_docval, docval

Expand Down Expand Up @@ -734,9 +736,18 @@ def _build_separate_namespaces(self):
GroupSpec(data_type_inc='Bar', doc='a bar', quantity='?')
]
)
moo_spec = DatasetSpec(
doc='A test dataset that is a 1D array of object references of Baz',
data_type_def='Moo',
shape=(None,),
dtype=RefSpec(
reftype='object',
target_type='Baz'
)
)
create_load_namespace_yaml(
namespace_name='ndx-test',
specs=[baz_spec],
specs=[baz_spec, moo_spec],
output_dir=self.test_dir,
incl_types={
CORE_NAMESPACE: ['Bar'],
Expand Down Expand Up @@ -828,6 +839,171 @@ def test_get_class_include_from_separate_ns_4(self):

self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2)

class TestGetClassObjectReferences(TestCase):

def setUp(self):
self.test_dir = tempfile.mkdtemp()
if os.path.exists(self.test_dir): # start clean
self.tearDown()
os.mkdir(self.test_dir)
self.type_map = TypeMap()

def tearDown(self):
shutil.rmtree(self.test_dir)

def test_get_class_include_dataset_of_references(self):
"""Test that get_class resolves datasets of object references."""
qux_spec = DatasetSpec(
doc='A test extension',
data_type_def='Qux'
)
moo_spec = DatasetSpec(
doc='A test dataset that is a 1D array of object references of Qux',
data_type_def='Moo',
shape=(None,),
dtype=RefSpec(
reftype='object',
target_type='Qux'
),
)

create_load_namespace_yaml(
namespace_name='ndx-test',
specs=[qux_spec, moo_spec],
output_dir=self.test_dir,
incl_types={},
type_map=self.type_map
)
# no types should be resolved to start
assert self.type_map.get_container_classes('ndx-test') == []

self.type_map.get_dt_container_cls('Moo', 'ndx-test')
# now, Moo and Qux should be resolved
assert len(self.type_map.get_container_classes('ndx-test')) == 2
assert "Moo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]

def test_get_class_include_attribute_object_reference(self):
"""Test that get_class resolves data types with an attribute that is an object reference."""
qux_spec = DatasetSpec(
doc='A test extension',
data_type_def='Qux'
)
woo_spec = DatasetSpec(
doc='A test dataset that has a scalar object reference to a Qux',
data_type_def='Woo',
attributes=[
AttributeSpec(
name='attr1',
doc='a string attribute',
dtype=RefSpec(reftype='object', target_type='Qux')
),
]
)
create_load_namespace_yaml(
namespace_name='ndx-test',
specs=[qux_spec, woo_spec],
output_dir=self.test_dir,
incl_types={},
type_map=self.type_map
)
# no types should be resolved to start
assert self.type_map.get_container_classes('ndx-test') == []

self.type_map.get_dt_container_cls('Woo', 'ndx-test')
# now, Woo and Qux should be resolved
assert len(self.type_map.get_container_classes('ndx-test')) == 2
assert "Woo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]

def test_get_class_include_nested_object_reference(self):
"""Test that get_class resolves nested datasets that are object references."""
qux_spec = DatasetSpec(
doc='A test extension',
data_type_def='Qux'
)
spam_spec = DatasetSpec(
doc='A test extension',
data_type_def='Spam',
shape=(None,),
dtype=RefSpec(
reftype='object',
target_type='Qux'
),
)
goo_spec = GroupSpec(
doc='A test dataset that has a nested dataset (Spam) that has a scalar object reference to a Qux',
data_type_def='Goo',
datasets=[
DatasetSpec(
doc='a dataset',
data_type_inc='Spam',
),
],
)

create_load_namespace_yaml(
namespace_name='ndx-test',
specs=[qux_spec, spam_spec, goo_spec],
output_dir=self.test_dir,
incl_types={},
type_map=self.type_map
)
# no types should be resolved to start
assert self.type_map.get_container_classes('ndx-test') == []

self.type_map.get_dt_container_cls('Goo', 'ndx-test')
# now, Goo, Spam, and Qux should be resolved
assert len(self.type_map.get_container_classes('ndx-test')) == 3
assert "Goo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Spam" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]

def test_get_class_include_nested_attribute_object_reference(self):
"""Test that get_class resolves nested datasets that have an attribute that is an object reference."""
qux_spec = DatasetSpec(
doc='A test extension',
data_type_def='Qux'
)
bam_spec = DatasetSpec(
doc='A test extension',
data_type_def='Bam',
attributes=[
AttributeSpec(
name='attr1',
doc='a string attribute',
dtype=RefSpec(reftype='object', target_type='Qux')
),
],
)
boo_spec = GroupSpec(
doc='A test dataset that has a nested dataset (Spam) that has a scalar object reference to a Qux',
data_type_def='Boo',
datasets=[
DatasetSpec(
doc='a dataset',
data_type_inc='Bam',
),
],
)

create_load_namespace_yaml(
namespace_name='ndx-test',
specs=[qux_spec, bam_spec, boo_spec],
output_dir=self.test_dir,
incl_types={},
type_map=self.type_map
)
# no types should be resolved to start
assert self.type_map.get_container_classes('ndx-test') == []

self.type_map.get_dt_container_cls('Boo', 'ndx-test')
# now, Boo, Bam, and Qux should be resolved
assert len(self.type_map.get_container_classes('ndx-test')) == 3
assert "Boo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Bam" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]


class EmptyBar(Container):
pass
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/build_tests/test_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def test_get_dt_container_cls(self):
self.assertIs(ret, Foo)

def test_get_dt_container_cls_no_namespace(self):
with self.assertRaisesWith(ValueError, "Namespace could not be resolved."):
with self.assertRaisesWith(ValueError, "Namespace could not be resolved for data type 'Unknown'."):
self.type_map.get_dt_container_cls(data_type="Unknown")


Expand Down
Loading