From ee1684a8a4ba8a4d70fb5ba4e78e1998d92c8ba1 Mon Sep 17 00:00:00 2001 From: Ryan Ly Date: Mon, 18 Nov 2019 13:19:50 -0800 Subject: [PATCH] Fix #200. Fix support for scalar np.bool_ (#203) --- src/hdmf/utils.py | 8 ++++++++ tests/unit/build_tests/test_io_map.py | 16 ++++++++++++++++ tests/unit/utils_test/test_docval.py | 14 ++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/src/hdmf/utils.py b/src/hdmf/utils.py index 289c18210..fd800c872 100644 --- a/src/hdmf/utils.py +++ b/src/hdmf/utils.py @@ -51,6 +51,8 @@ def __type_okay(value, argtype, allow_none=False): return __is_int(value) elif argtype == 'float': return __is_float(value) + elif argtype == 'bool': + return __is_bool(value) return argtype in [cls.__name__ for cls in value.__class__.__mro__] elif isinstance(argtype, type): if argtype == six.text_type: @@ -61,6 +63,8 @@ def __type_okay(value, argtype, allow_none=False): return __is_int(value) elif argtype is float: return __is_float(value) + elif argtype is bool: + return __is_bool(value) return isinstance(value, argtype) elif isinstance(argtype, tuple) or isinstance(argtype, list): return any(__type_okay(value, i) for i in argtype) @@ -100,6 +104,10 @@ def __is_float(value): return any(isinstance(value, i) for i in SUPPORTED_FLOAT_TYPES) +def __is_bool(value): + return isinstance(value, bool) or isinstance(value, np.bool_) + + def __format_type(argtype): if isinstance(argtype, str): return argtype diff --git a/tests/unit/build_tests/test_io_map.py b/tests/unit/build_tests/test_io_map.py index ae839dd73..1405364a2 100644 --- a/tests/unit/build_tests/test_io_map.py +++ b/tests/unit/build_tests/test_io_map.py @@ -701,6 +701,22 @@ def test_numeric_spec(self): self.assertTupleEqual(ret, match) self.assertIs(ret[0].dtype.type, match[1]) + def test_bool_spec(self): + spec_type = 'bool' + spec = DatasetSpec('an example dataset', spec_type, name='data') + + value = np.bool_(True) + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, np.bool_) + self.assertTupleEqual(ret, match) + self.assertIs(type(ret[0]), match[1]) + + value = True + ret = ObjectMapper.convert_dtype(spec, value) + match = (value, np.bool_) + self.assertTupleEqual(ret, match) + self.assertIs(type(ret[0]), match[1]) + if __name__ == '__main__': unittest.main() diff --git a/tests/unit/utils_test/test_docval.py b/tests/unit/utils_test/test_docval.py index ebe23d13c..5b711de6b 100644 --- a/tests/unit/utils_test/test_docval.py +++ b/tests/unit/utils_test/test_docval.py @@ -1,5 +1,6 @@ import unittest from six import text_type +import numpy as np from hdmf.utils import docval, fmt_docval_args, get_docval, popargs @@ -478,6 +479,19 @@ def test_get_docval_none_arg(self): with self.assertRaisesRegex(ValueError, r'Function __init__ has no docval arguments'): get_docval(self.test_obj.__init__, 'arg3') + def test_bool_type(self): + @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool'}) + def method(self, **kwargs): + return popargs('arg1', kwargs) + + res = method(self, arg1=True) + self.assertEqual(res, True) + self.assertIsInstance(res, bool) + + res = method(self, arg1=np.bool_(True)) + self.assertEqual(res, np.bool_(True)) + self.assertIsInstance(res, np.bool_) + class TestDocValidatorChain(unittest.TestCase):