Skip to content

Commit

Permalink
Add export_spec utility function for use by extensions template (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
rly authored Oct 9, 2019
1 parent 8a6c650 commit 3e63b13
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 82 deletions.
43 changes: 42 additions & 1 deletion src/hdmf/spec/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import ruamel.yaml as yaml
import os.path
import warnings
from collections import OrderedDict
from six import with_metaclass
from abc import ABCMeta, abstractmethod
Expand Down Expand Up @@ -69,7 +70,8 @@ def my_represent_none(self, data):
return self.represent_scalar(u'tag:yaml.org,2002:null', u'null')
yaml.representer.RoundTripRepresenter.add_representer(type(None), my_represent_none)

order = ['neurodata_type_def', 'neurodata_type_inc', 'name', 'default_name',
order = ['neurodata_type_def', 'neurodata_type_inc', 'data_type_def', 'data_type_inc',
'name', 'default_name',
'dtype', 'target_type', 'dims', 'shape', 'default_value', 'value', 'doc',
'required', 'quantity', 'attributes', 'datasets', 'groups', 'links']
if isinstance(obj, dict):
Expand Down Expand Up @@ -218,3 +220,42 @@ def add_spec(self, **kwargs):
self.setdefault('groups', list()).append(spec)
elif isinstance(spec, DatasetSpec):
self.setdefault('datasets', list()).append(spec)


def export_spec(ns_builder, new_data_types, output_dir):
"""
Create YAML specification files for a new namespace and extensions with
the given data type specs.
Args:
ns_builder - NamespaceBuilder instance used to build the
namespace and extension
new_data_types - Iterable of specs that represent new data types
to be added
"""

if len(new_data_types) == 0:
warnings.warn('No data types specified. Exiting.')
return

if not ns_builder.name:
raise RuntimeError('Namespace name is required to export specs')

ns_path = ns_builder.name + '.namespace.yaml'
ext_path = ns_builder.name + '.extensions.yaml'

if len(new_data_types) > 1:
pluralize = 's'
else:
pluralize = ''

print('Creating file {output_dir}/{ext_path} with {new_data_types_count} data type{pluralize}'.format(
pluralize=pluralize, output_dir=output_dir, ext_path=ext_path,
new_data_types_count=len(new_data_types)))

for data_type in new_data_types:
ns_builder.add_spec(ext_path, data_type)

print('Creating file {output_dir}/{ns_path}'.format(output_dir=output_dir, ns_path=ns_path))

ns_builder.export(ns_path, outdir=output_dir)
203 changes: 122 additions & 81 deletions tests/unit/spec_tests/test_spec_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import os
import datetime

from hdmf.spec.write import NamespaceBuilder, YAMLSpecWriter
from hdmf.spec.write import NamespaceBuilder, YAMLSpecWriter, export_spec
from hdmf.spec.namespace import SpecNamespace, NamespaceCatalog
from hdmf.spec.spec import GroupSpec


class TestNamespaceBuilder(unittest.TestCase):
NS_NAME = 'test_ns'
class TestSpec(unittest.TestCase):

def setUp(self):
# create a builder for the namespace
Expand Down Expand Up @@ -43,39 +42,70 @@ def setUp(self):
dtype='float',
name='testdata')

self.data_types = [ext1, ext2]

# add the extension
self.ext_source_path = 'mylab.specs.yaml'
self.ns_builder.add_spec(source=self.ext_source_path, spec=ext1)
self.ns_builder.add_spec(source=self.ext_source_path, spec=ext2)
self.ext_source_path = 'mylab.extensions.yaml'
self.namespace_path = 'mylab.namespace.yaml'

def _test_extensions_file(self):
with open(self.ext_source_path, 'r') as file:
match_str = \
"""groups:
- data_type_def: MyDataSeries
doc: A custom DataSeries interface
- data_type_def: MyExtendedMyDataSeries
data_type_inc: MyDataSeries
doc: An extension of a DataSeries interface
datasets:
- name: testdata
dtype: float
doc: test
""" # noqa: E128
nsstr = file.read()
self.assertEqual(nsstr, match_str)


class TestNamespaceBuilder(TestSpec):
NS_NAME = 'test_ns'

def setUp(self):
super(TestNamespaceBuilder, self).setUp()
for data_type in self.data_types:
self.ns_builder.add_spec(source=self.ext_source_path, spec=data_type)
self.ns_builder.add_source(source=self.ext_source_path,
doc='Extensions for my lab',
title='My lab extensions')

self.namespace_path = 'mylab.namespace.yaml'
self.ns_builder.export(self.namespace_path)

def tearDown(self):
if os.path.exists(self.ext_source_path):
os.remove(self.ext_source_path)
if os.path.exists(self.namespace_path):
os.remove(self.namespace_path)
pass

def test_export_namespace(self):
with open(self.namespace_path, 'r') as nsfile:
nsstr = nsfile.read()
self.assertTrue(nsstr.startswith("namespaces:\n"))
self.assertTrue("author: foo\n" in nsstr)
self.assertTrue("contact: [email protected]\n" in nsstr)
self.assertTrue("date: '%s'\n" % self.date.isoformat() in nsstr)
self.assertTrue("doc: mydoc\n" in nsstr)
self.assertTrue("full_name: My Laboratory\n" in nsstr)
self.assertTrue("name: mylab\n" in nsstr)
self.assertTrue("schema:\n" in nsstr)
self.assertTrue("doc: Extensions for my lab\n" in nsstr)
self.assertTrue("source: mylab.specs.yaml\n" in nsstr)
self.assertTrue("title: Extensions for my lab\n" in nsstr)
self.assertTrue("version: 0.0.1\n" in nsstr)
self._test_namespace_file()
self._test_extensions_file()

def _test_namespace_file(self):
with open(self.namespace_path, 'r') as file:
match_str = \
"""namespaces:
- author: foo
contact: [email protected]
date: '%s'
doc: mydoc
full_name: My Laboratory
name: mylab
schema:
- doc: Extensions for my lab
source: mylab.extensions.yaml
title: Extensions for my lab
version: 0.0.1
""" % self.date.isoformat() # noqa: E128
nsstr = file.read()
self.assertEqual(nsstr, match_str)

def test_read_namespace(self):
ns_catalog = NamespaceCatalog()
Expand All @@ -88,77 +118,41 @@ def test_read_namespace(self):
self.assertEqual(loaded_ns.name, "mylab")
self.assertEqual(loaded_ns.date, self.date.isoformat())
self.assertDictEqual(loaded_ns.schema[0], {'doc': 'Extensions for my lab',
'source': 'mylab.specs.yaml',
'source': 'mylab.extensions.yaml',
'title': 'Extensions for my lab'})
self.assertEqual(loaded_ns.version, "0.0.1")

def test_get_source_files(self):
ns_catalog = NamespaceCatalog()
ns_catalog.load_namespaces(self.namespace_path, resolve=True)
loaded_ns = ns_catalog.get_namespace(self.ns_name)
self.assertListEqual(loaded_ns.get_source_files(), ['mylab.specs.yaml'])
self.assertListEqual(loaded_ns.get_source_files(), ['mylab.extensions.yaml'])

def test_get_source_description(self):
ns_catalog = NamespaceCatalog()
ns_catalog.load_namespaces(self.namespace_path, resolve=True)
loaded_ns = ns_catalog.get_namespace(self.ns_name)
descr = loaded_ns.get_source_description('mylab.specs.yaml')
descr = loaded_ns.get_source_description('mylab.extensions.yaml')
self.assertDictEqual(descr, {'doc': 'Extensions for my lab',
'source': 'mylab.specs.yaml',
'source': 'mylab.extensions.yaml',
'title': 'Extensions for my lab'})


class TestYAMLSpecWrite(unittest.TestCase):
class TestYAMLSpecWrite(TestSpec):

def setUp(self):
# create a builder for the namespace
self.ns_name = "mylab"
self.date = datetime.datetime.now()

self.ns_builder = NamespaceBuilder(doc="mydoc",
name=self.ns_name,
full_name="My Laboratory",
version="0.0.1",
author="foo",
contact="[email protected]",
namespace_cls=SpecNamespace,
date=self.date)

# create extensions
ext1 = GroupSpec('A custom DataSeries interface',
attributes=[],
datasets=[],
groups=[],
data_type_inc=None,
data_type_def='MyDataSeries')

ext2 = GroupSpec('An extension of a DataSeries interface',
attributes=[],
datasets=[],
groups=[],
data_type_inc='MyDataSeries',
data_type_def='MyExtendedMyDataSeries')

ext2.add_dataset(doc='test',
dtype='float',
name='testdata')

# add the extension
self.ext_source_path = 'mylab.specs.yaml'
self.ns_builder.add_spec(source=self.ext_source_path, spec=ext1)
self.ns_builder.add_spec(source=self.ext_source_path, spec=ext2)
super(TestYAMLSpecWrite, self).setUp()
for data_type in self.data_types:
self.ns_builder.add_spec(source=self.ext_source_path, spec=data_type)
self.ns_builder.add_source(source=self.ext_source_path,
doc='Extensions for my lab',
title='My lab extensions')

self.namespace_path = 'mylab.namespace.yaml'

def tearDown(self):
if os.path.exists(self.ext_source_path):
os.remove(self.ext_source_path)
if os.path.exists(self.namespace_path):
os.remove(self.namespace_path)
pass

def test_init(self):
temp = YAMLSpecWriter('.')
Expand All @@ -167,20 +161,67 @@ def test_init(self):
def test_write_namespace(self):
temp = YAMLSpecWriter()
self.ns_builder.export(self.namespace_path, writer=temp)
with open(self.namespace_path, 'r') as nsfile:
nsstr = nsfile.read()
self.assertTrue(nsstr.startswith("namespaces:\n"))
self.assertTrue("author: foo\n" in nsstr)
self.assertTrue("contact: [email protected]\n" in nsstr)
self.assertTrue("date: '%s'\n" % self.date.isoformat() in nsstr)
self.assertTrue("doc: mydoc\n" in nsstr)
self.assertTrue("full_name: My Laboratory\n" in nsstr)
self.assertTrue("name: mylab\n" in nsstr)
self.assertTrue("schema:\n" in nsstr)
self.assertTrue("doc: Extensions for my lab\n" in nsstr)
self.assertTrue("source: mylab.specs.yaml\n" in nsstr)
self.assertTrue("title: Extensions for my lab\n" in nsstr)
self.assertTrue("version: 0.0.1\n" in nsstr)
self._test_namespace_file()
self._test_extensions_file()

def test_get_name(self):
self.assertEqual(self.ns_name, self.ns_builder.name)

def _test_namespace_file(self):
with open(self.namespace_path, 'r') as file:
match_str = \
"""namespaces:
- author: foo
contact: [email protected]
date: '%s'
doc: mydoc
full_name: My Laboratory
name: mylab
schema:
- doc: Extensions for my lab
source: mylab.extensions.yaml
title: Extensions for my lab
version: 0.0.1
""" % self.date.isoformat() # noqa: E128
nsstr = file.read()
self.assertEqual(nsstr, match_str)


class TestExportSpec(TestSpec):

def test_export(self):
export_spec(self.ns_builder, self.data_types, '.')
self._test_namespace_file()
self._test_extensions_file()

def tearDown(self):
if os.path.exists(self.ext_source_path):
os.remove(self.ext_source_path)
if os.path.exists(self.namespace_path):
os.remove(self.namespace_path)

def _test_namespace_file(self):
with open(self.namespace_path, 'r') as nsfile:
nsstr = nsfile.read()
match_str = \
"""namespaces:
- author: foo
contact: [email protected]
date: '%s'
doc: mydoc
full_name: My Laboratory
name: mylab
schema:
- source: mylab.extensions.yaml
version: 0.0.1
""" % self.date.isoformat() # noqa: E128
self.assertEqual(nsstr, match_str)

def test_missing_data_types(self):
with self.assertWarnsRegex(UserWarning, 'No data types specified. Exiting.'):
export_spec(self.ns_builder, [], '.')

def test_missing_name(self):
self.ns_builder._NamespaceBuilder__ns_args['name'] = None
with self.assertRaisesRegex(RuntimeError, 'Namespace name is required to export specs'):
export_spec(self.ns_builder, self.data_types, '.')

0 comments on commit 3e63b13

Please sign in to comment.