diff --git a/src/hdmf/backends/hdf5/h5tools.py b/src/hdmf/backends/hdf5/h5tools.py index 0c050e258..64f93e398 100644 --- a/src/hdmf/backends/hdf5/h5tools.py +++ b/src/hdmf/backends/hdf5/h5tools.py @@ -967,10 +967,11 @@ def __get_ref(self, **kwargs): def __is_ref(self, dtype): if isinstance(dtype, DtypeSpec): return self.__is_ref(dtype.dtype) - elif isinstance(dtype, RefSpec): + if isinstance(dtype, RefSpec): return True - else: + if isinstance(dtype, str): return dtype == DatasetBuilder.OBJECT_REF_TYPE or dtype == DatasetBuilder.REGION_REF_TYPE + return False def __queue_ref(self, func): '''Set aside filling dset with references diff --git a/src/hdmf/build/map.py b/src/hdmf/build/map.py index 51f0f6d1f..0dff9db98 100644 --- a/src/hdmf/build/map.py +++ b/src/hdmf/build/map.py @@ -471,12 +471,21 @@ def __check_edgecases(cls, spec, value): return value, spec.dtype if isinstance(value, DataIO): return value, cls.convert_dtype(spec, value.data)[1] - if spec.dtype is None: - return value, None - if spec.dtype == 'numeric': - return value, None - if type(value) in cls.__no_convert: - return value, None + if spec.dtype is None or spec.dtype == 'numeric' or type(value) in cls.__no_convert: + # infer type from value + if hasattr(value, 'dtype'): # covers numpy types, AbstractDataChunkIterator + return value, value.dtype + if isinstance(value, (list, tuple)): + if len(value) == 0: + msg = "cannot infer dtype of empty list or tuple. Please use numpy array with specified dtype." + raise ValueError(msg) + return value, cls.__check_edgecases(spec, value[0])[1] # infer dtype from first element + ret_dtype = type(value) + if ret_dtype is str: + ret_dtype = 'utf8' + elif ret_dtype is bytes: + ret_dtype = 'ascii' + return value, ret_dtype if isinstance(spec.dtype, RefSpec): if not isinstance(value, ReferenceBuilder): msg = "got RefSpec for value of type %s" % type(value) diff --git a/src/hdmf/spec/spec.py b/src/hdmf/spec/spec.py index 088cfcbc8..8d93f9a96 100644 --- a/src/hdmf/spec/spec.py +++ b/src/hdmf/spec/spec.py @@ -305,8 +305,8 @@ def __init__(self, **kwargs): getargs('name', 'doc', 'parent', 'quantity', 'attributes', 'linkable', 'data_type_def', 'data_type_inc', kwargs) if name == NAME_WILDCARD and data_type_def is None and data_type_inc is None: - raise ValueError("Cannot create Group or Dataset spec with wildcard name \ - without specifying 'data_type_def' and/or 'data_type_inc'") + raise ValueError("Cannot create Group or Dataset spec with wildcard name " + "without specifying 'data_type_def' and/or 'data_type_inc'") super(BaseStorageSpec, self).__init__(doc, name=name, parent=parent) default_name = getargs('default_name', kwargs) if default_name: diff --git a/tests/unit/build_tests/test_io_map.py b/tests/unit/build_tests/test_io_map.py index 724a661dd..04c10771d 100644 --- a/tests/unit/build_tests/test_io_map.py +++ b/tests/unit/build_tests/test_io_map.py @@ -1,13 +1,16 @@ import unittest2 as unittest import re -from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog +from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, RefSpec from hdmf.build import GroupBuilder, DatasetBuilder, ObjectMapper, BuildManager, TypeMap, LinkBuilder from hdmf import Container from hdmf.utils import docval, getargs, get_docval +from hdmf.data_utils import DataChunkIterator +from hdmf.backends.hdf5 import H5DataIO from abc import ABCMeta from six import with_metaclass +import numpy as np from tests.unit.test_utils import CORE_NAMESPACE @@ -496,5 +499,149 @@ def test_build_child_link(self): self.assertDictEqual(bar2_builder, bar2_expected) +class TestConvertDtype(unittest.TestCase): + + def test_value_none(self): + spec = DatasetSpec('an example dataset', 'int', name='data') + self.assertTupleEqual(ObjectMapper.convert_dtype(spec, None), (None, 'int')) + + spec = DatasetSpec('an example dataset', RefSpec(reftype='object', target_type='int'), name='data') + self.assertTupleEqual(ObjectMapper.convert_dtype(spec, None), (None, 'object')) + + def test_convert_higher_precision(self): + """Test that passing a data type with a precision <= specified returns the higher precision type""" + spec_type = 'float64' + value_types = ['float', 'float32', 'double', 'float64'] + self.convert_higher_precision_helper(spec_type, value_types) + + spec_type = 'int64' + value_types = ['long', 'int64', 'uint64', 'int', 'int32', 'int16', 'int8'] + self.convert_higher_precision_helper(spec_type, value_types) + + spec_type = 'int32' + value_types = ['int32', 'int16', 'int8'] + self.convert_higher_precision_helper(spec_type, value_types) + + spec_type = 'int16' + value_types = ['int16', 'int8'] + self.convert_higher_precision_helper(spec_type, value_types) + + spec_type = 'uint32' + value_types = ['uint32', 'uint16', 'uint8'] + self.convert_higher_precision_helper(spec_type, value_types) + + def convert_higher_precision_helper(self, spec_type, value_types): + data = 2 + spec = DatasetSpec('an example dataset', spec_type, name='data') + match = (np.dtype(spec_type).type(data), np.dtype(spec_type)) + for dtype in value_types: + value = np.dtype(dtype).type(data) + with self.subTest(dtype=dtype): + ret = ObjectMapper.convert_dtype(spec, value) + self.assertTupleEqual(ret, match) + self.assertEqual(ret[0].dtype, match[1]) + + def test_keep_higher_precision(self): + """Test that passing a data type with a precision >= specified return the given type""" + spec_type = 'float' + value_types = ['double', 'float64'] + self.keep_higher_precision_helper(spec_type, value_types) + + spec_type = 'int' + value_types = ['int64'] + self.keep_higher_precision_helper(spec_type, value_types) + + spec_type = 'int8' + value_types = ['long', 'int64', 'int', 'int32', 'int16'] + self.keep_higher_precision_helper(spec_type, value_types) + + spec_type = 'uint' + value_types = ['uint64'] + self.keep_higher_precision_helper(spec_type, value_types) + + spec_type = 'uint8' + value_types = ['uint64', 'uint32', 'uint', 'uint16'] + self.keep_higher_precision_helper(spec_type, value_types) + + def keep_higher_precision_helper(self, spec_type, value_types): + data = 2 + spec = DatasetSpec('an example dataset', spec_type, name='data') + for dtype in value_types: + value = np.dtype(dtype).type(data) + match = (value, np.dtype(dtype)) + with self.subTest(dtype=dtype): + ret = ObjectMapper.convert_dtype(spec, value) + self.assertTupleEqual(ret, match) + self.assertEqual(ret[0].dtype, match[1]) + + def test_no_spec(self): + spec_type = None + spec = DatasetSpec('an example dataset', spec_type, name='data') + + value = [1, 2, 3] + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, int) + self.assertTupleEqual(ret, match) + self.assertEqual(type(ret[0][0]), match[1]) + + value = np.uint64(4) + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, np.uint64) + self.assertTupleEqual(ret, match) + self.assertEqual(type(ret[0]), match[1]) + + value = 'hello' + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, 'utf8') + self.assertTupleEqual(ret, match) + self.assertEqual(type(ret[0]), str) + + value = bytes('hello', encoding='utf-8') + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, 'ascii') + self.assertTupleEqual(ret, match) + self.assertEqual(type(ret[0]), bytes) + + value = DataChunkIterator(data=[1, 2, 3]) + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, np.dtype(int)) + self.assertTupleEqual(ret, match) + self.assertEqual(ret[0].dtype, match[1]) + + value = DataChunkIterator(data=[1., 2., 3.]) + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, np.dtype(float)) + self.assertTupleEqual(ret, match) + self.assertEqual(ret[0].dtype, match[1]) + + value = H5DataIO(np.arange(30).reshape(5, 2, 3)) + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, np.dtype(int)) + self.assertTupleEqual(ret, match) + self.assertEqual(ret[0].dtype, match[1]) + + value = H5DataIO(['foo' 'bar']) + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, 'utf8') + self.assertTupleEqual(ret, match) + self.assertEqual(type(ret[0].data[0]), str) + + def test_numeric_spec(self): + spec_type = 'numeric' + spec = DatasetSpec('an example dataset', spec_type, name='data') + + value = np.uint64(4) + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, np.uint64) + self.assertTupleEqual(ret, match) + self.assertEqual(type(ret[0]), match[1]) + + value = DataChunkIterator(data=[1, 2, 3]) + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, np.dtype(int)) + self.assertTupleEqual(ret, match) + self.assertEqual(ret[0].dtype, match[1]) + + if __name__ == '__main__': unittest.main() diff --git a/tests/unit/test_io_hdf5_h5tools.py b/tests/unit/test_io_hdf5_h5tools.py index efd3e9903..53305b28d 100644 --- a/tests/unit/test_io_hdf5_h5tools.py +++ b/tests/unit/test_io_hdf5_h5tools.py @@ -258,9 +258,10 @@ def test_write_dataset_iterable_multidimensional_array_compression(self): ############################################# def test_write_dataset_data_chunk_iterator(self): dci = DataChunkIterator(data=np.arange(10), buffer_size=2) - self.io.write_dataset(self.f, DatasetBuilder('test_dataset', dci, attributes={})) + self.io.write_dataset(self.f, DatasetBuilder('test_dataset', dci, attributes={}, dtype=dci.dtype)) dset = self.f['test_dataset'] self.assertListEqual(dset[:].tolist(), list(range(10))) + self.assertEqual(dset[:].dtype, dci.dtype) def test_write_dataset_data_chunk_iterator_with_compression(self): dci = DataChunkIterator(data=np.arange(10), buffer_size=2)