diff --git a/src/hdmf/spec/write.py b/src/hdmf/spec/write.py index 0e28dcf84..ed945679d 100644 --- a/src/hdmf/spec/write.py +++ b/src/hdmf/spec/write.py @@ -2,6 +2,7 @@ import json import ruamel.yaml as yaml import os.path +import warnings from collections import OrderedDict from six import with_metaclass from abc import ABCMeta, abstractmethod @@ -69,7 +70,8 @@ def my_represent_none(self, data): return self.represent_scalar(u'tag:yaml.org,2002:null', u'null') yaml.representer.RoundTripRepresenter.add_representer(type(None), my_represent_none) - order = ['neurodata_type_def', 'neurodata_type_inc', 'name', 'default_name', + order = ['neurodata_type_def', 'neurodata_type_inc', 'data_type_def', 'data_type_inc', + 'name', 'default_name', 'dtype', 'target_type', 'dims', 'shape', 'default_value', 'value', 'doc', 'required', 'quantity', 'attributes', 'datasets', 'groups', 'links'] if isinstance(obj, dict): @@ -218,3 +220,42 @@ def add_spec(self, **kwargs): self.setdefault('groups', list()).append(spec) elif isinstance(spec, DatasetSpec): self.setdefault('datasets', list()).append(spec) + + +def export_spec(ns_builder, new_data_types, output_dir): + """ + Create YAML specification files for a new namespace and extensions with + the given data type specs. + + Args: + ns_builder - NamespaceBuilder instance used to build the + namespace and extension + new_data_types - Iterable of specs that represent new data types + to be added + """ + + if len(new_data_types) == 0: + warnings.warn('No data types specified. Exiting.') + return + + if not ns_builder.name: + raise RuntimeError('Namespace name is required to export specs') + + ns_path = ns_builder.name + '.namespace.yaml' + ext_path = ns_builder.name + '.extensions.yaml' + + if len(new_data_types) > 1: + pluralize = 's' + else: + pluralize = '' + + print('Creating file {output_dir}/{ext_path} with {new_data_types_count} data type{pluralize}'.format( + pluralize=pluralize, output_dir=output_dir, ext_path=ext_path, + new_data_types_count=len(new_data_types))) + + for data_type in new_data_types: + ns_builder.add_spec(ext_path, data_type) + + print('Creating file {output_dir}/{ns_path}'.format(output_dir=output_dir, ns_path=ns_path)) + + ns_builder.export(ns_path, outdir=output_dir) diff --git a/tests/unit/spec_tests/test_spec_write.py b/tests/unit/spec_tests/test_spec_write.py index 01b9608eb..859c61140 100644 --- a/tests/unit/spec_tests/test_spec_write.py +++ b/tests/unit/spec_tests/test_spec_write.py @@ -2,13 +2,12 @@ import os import datetime -from hdmf.spec.write import NamespaceBuilder, YAMLSpecWriter +from hdmf.spec.write import NamespaceBuilder, YAMLSpecWriter, export_spec from hdmf.spec.namespace import SpecNamespace, NamespaceCatalog from hdmf.spec.spec import GroupSpec -class TestNamespaceBuilder(unittest.TestCase): - NS_NAME = 'test_ns' +class TestSpec(unittest.TestCase): def setUp(self): # create a builder for the namespace @@ -43,15 +42,40 @@ def setUp(self): dtype='float', name='testdata') + self.data_types = [ext1, ext2] + # add the extension - self.ext_source_path = 'mylab.specs.yaml' - self.ns_builder.add_spec(source=self.ext_source_path, spec=ext1) - self.ns_builder.add_spec(source=self.ext_source_path, spec=ext2) + self.ext_source_path = 'mylab.extensions.yaml' + self.namespace_path = 'mylab.namespace.yaml' + + def _test_extensions_file(self): + with open(self.ext_source_path, 'r') as file: + match_str = \ +"""groups: +- data_type_def: MyDataSeries + doc: A custom DataSeries interface +- data_type_def: MyExtendedMyDataSeries + data_type_inc: MyDataSeries + doc: An extension of a DataSeries interface + datasets: + - name: testdata + dtype: float + doc: test +""" # noqa: E128 + nsstr = file.read() + self.assertEqual(nsstr, match_str) + + +class TestNamespaceBuilder(TestSpec): + NS_NAME = 'test_ns' + + def setUp(self): + super(TestNamespaceBuilder, self).setUp() + for data_type in self.data_types: + self.ns_builder.add_spec(source=self.ext_source_path, spec=data_type) self.ns_builder.add_source(source=self.ext_source_path, doc='Extensions for my lab', title='My lab extensions') - - self.namespace_path = 'mylab.namespace.yaml' self.ns_builder.export(self.namespace_path) def tearDown(self): @@ -59,23 +83,29 @@ def tearDown(self): os.remove(self.ext_source_path) if os.path.exists(self.namespace_path): os.remove(self.namespace_path) - pass def test_export_namespace(self): - with open(self.namespace_path, 'r') as nsfile: - nsstr = nsfile.read() - self.assertTrue(nsstr.startswith("namespaces:\n")) - self.assertTrue("author: foo\n" in nsstr) - self.assertTrue("contact: foo@bar.com\n" in nsstr) - self.assertTrue("date: '%s'\n" % self.date.isoformat() in nsstr) - self.assertTrue("doc: mydoc\n" in nsstr) - self.assertTrue("full_name: My Laboratory\n" in nsstr) - self.assertTrue("name: mylab\n" in nsstr) - self.assertTrue("schema:\n" in nsstr) - self.assertTrue("doc: Extensions for my lab\n" in nsstr) - self.assertTrue("source: mylab.specs.yaml\n" in nsstr) - self.assertTrue("title: Extensions for my lab\n" in nsstr) - self.assertTrue("version: 0.0.1\n" in nsstr) + self._test_namespace_file() + self._test_extensions_file() + + def _test_namespace_file(self): + with open(self.namespace_path, 'r') as file: + match_str = \ +"""namespaces: +- author: foo + contact: foo@bar.com + date: '%s' + doc: mydoc + full_name: My Laboratory + name: mylab + schema: + - doc: Extensions for my lab + source: mylab.extensions.yaml + title: Extensions for my lab + version: 0.0.1 +""" % self.date.isoformat() # noqa: E128 + nsstr = file.read() + self.assertEqual(nsstr, match_str) def test_read_namespace(self): ns_catalog = NamespaceCatalog() @@ -88,7 +118,7 @@ def test_read_namespace(self): self.assertEqual(loaded_ns.name, "mylab") self.assertEqual(loaded_ns.date, self.date.isoformat()) self.assertDictEqual(loaded_ns.schema[0], {'doc': 'Extensions for my lab', - 'source': 'mylab.specs.yaml', + 'source': 'mylab.extensions.yaml', 'title': 'Extensions for my lab'}) self.assertEqual(loaded_ns.version, "0.0.1") @@ -96,69 +126,33 @@ def test_get_source_files(self): ns_catalog = NamespaceCatalog() ns_catalog.load_namespaces(self.namespace_path, resolve=True) loaded_ns = ns_catalog.get_namespace(self.ns_name) - self.assertListEqual(loaded_ns.get_source_files(), ['mylab.specs.yaml']) + self.assertListEqual(loaded_ns.get_source_files(), ['mylab.extensions.yaml']) def test_get_source_description(self): ns_catalog = NamespaceCatalog() ns_catalog.load_namespaces(self.namespace_path, resolve=True) loaded_ns = ns_catalog.get_namespace(self.ns_name) - descr = loaded_ns.get_source_description('mylab.specs.yaml') + descr = loaded_ns.get_source_description('mylab.extensions.yaml') self.assertDictEqual(descr, {'doc': 'Extensions for my lab', - 'source': 'mylab.specs.yaml', + 'source': 'mylab.extensions.yaml', 'title': 'Extensions for my lab'}) -class TestYAMLSpecWrite(unittest.TestCase): +class TestYAMLSpecWrite(TestSpec): def setUp(self): - # create a builder for the namespace - self.ns_name = "mylab" - self.date = datetime.datetime.now() - - self.ns_builder = NamespaceBuilder(doc="mydoc", - name=self.ns_name, - full_name="My Laboratory", - version="0.0.1", - author="foo", - contact="foo@bar.com", - namespace_cls=SpecNamespace, - date=self.date) - - # create extensions - ext1 = GroupSpec('A custom DataSeries interface', - attributes=[], - datasets=[], - groups=[], - data_type_inc=None, - data_type_def='MyDataSeries') - - ext2 = GroupSpec('An extension of a DataSeries interface', - attributes=[], - datasets=[], - groups=[], - data_type_inc='MyDataSeries', - data_type_def='MyExtendedMyDataSeries') - - ext2.add_dataset(doc='test', - dtype='float', - name='testdata') - - # add the extension - self.ext_source_path = 'mylab.specs.yaml' - self.ns_builder.add_spec(source=self.ext_source_path, spec=ext1) - self.ns_builder.add_spec(source=self.ext_source_path, spec=ext2) + super(TestYAMLSpecWrite, self).setUp() + for data_type in self.data_types: + self.ns_builder.add_spec(source=self.ext_source_path, spec=data_type) self.ns_builder.add_source(source=self.ext_source_path, doc='Extensions for my lab', title='My lab extensions') - self.namespace_path = 'mylab.namespace.yaml' - def tearDown(self): if os.path.exists(self.ext_source_path): os.remove(self.ext_source_path) if os.path.exists(self.namespace_path): os.remove(self.namespace_path) - pass def test_init(self): temp = YAMLSpecWriter('.') @@ -167,20 +161,67 @@ def test_init(self): def test_write_namespace(self): temp = YAMLSpecWriter() self.ns_builder.export(self.namespace_path, writer=temp) - with open(self.namespace_path, 'r') as nsfile: - nsstr = nsfile.read() - self.assertTrue(nsstr.startswith("namespaces:\n")) - self.assertTrue("author: foo\n" in nsstr) - self.assertTrue("contact: foo@bar.com\n" in nsstr) - self.assertTrue("date: '%s'\n" % self.date.isoformat() in nsstr) - self.assertTrue("doc: mydoc\n" in nsstr) - self.assertTrue("full_name: My Laboratory\n" in nsstr) - self.assertTrue("name: mylab\n" in nsstr) - self.assertTrue("schema:\n" in nsstr) - self.assertTrue("doc: Extensions for my lab\n" in nsstr) - self.assertTrue("source: mylab.specs.yaml\n" in nsstr) - self.assertTrue("title: Extensions for my lab\n" in nsstr) - self.assertTrue("version: 0.0.1\n" in nsstr) + self._test_namespace_file() + self._test_extensions_file() def test_get_name(self): self.assertEqual(self.ns_name, self.ns_builder.name) + + def _test_namespace_file(self): + with open(self.namespace_path, 'r') as file: + match_str = \ +"""namespaces: +- author: foo + contact: foo@bar.com + date: '%s' + doc: mydoc + full_name: My Laboratory + name: mylab + schema: + - doc: Extensions for my lab + source: mylab.extensions.yaml + title: Extensions for my lab + version: 0.0.1 +""" % self.date.isoformat() # noqa: E128 + nsstr = file.read() + self.assertEqual(nsstr, match_str) + + +class TestExportSpec(TestSpec): + + def test_export(self): + export_spec(self.ns_builder, self.data_types, '.') + self._test_namespace_file() + self._test_extensions_file() + + def tearDown(self): + if os.path.exists(self.ext_source_path): + os.remove(self.ext_source_path) + if os.path.exists(self.namespace_path): + os.remove(self.namespace_path) + + def _test_namespace_file(self): + with open(self.namespace_path, 'r') as nsfile: + nsstr = nsfile.read() + match_str = \ +"""namespaces: +- author: foo + contact: foo@bar.com + date: '%s' + doc: mydoc + full_name: My Laboratory + name: mylab + schema: + - source: mylab.extensions.yaml + version: 0.0.1 +""" % self.date.isoformat() # noqa: E128 + self.assertEqual(nsstr, match_str) + + def test_missing_data_types(self): + with self.assertWarnsRegex(UserWarning, 'No data types specified. Exiting.'): + export_spec(self.ns_builder, [], '.') + + def test_missing_name(self): + self.ns_builder._NamespaceBuilder__ns_args['name'] = None + with self.assertRaisesRegex(RuntimeError, 'Namespace name is required to export specs'): + export_spec(self.ns_builder, self.data_types, '.')