Skip to content

Commit

Permalink
Fix #200. Fix support for scalar np.bool_ (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
rly authored Nov 18, 2019
1 parent 42026c5 commit ee1684a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/hdmf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def __type_okay(value, argtype, allow_none=False):
return __is_int(value)
elif argtype == 'float':
return __is_float(value)
elif argtype == 'bool':
return __is_bool(value)
return argtype in [cls.__name__ for cls in value.__class__.__mro__]
elif isinstance(argtype, type):
if argtype == six.text_type:
Expand All @@ -61,6 +63,8 @@ def __type_okay(value, argtype, allow_none=False):
return __is_int(value)
elif argtype is float:
return __is_float(value)
elif argtype is bool:
return __is_bool(value)
return isinstance(value, argtype)
elif isinstance(argtype, tuple) or isinstance(argtype, list):
return any(__type_okay(value, i) for i in argtype)
Expand Down Expand Up @@ -100,6 +104,10 @@ def __is_float(value):
return any(isinstance(value, i) for i in SUPPORTED_FLOAT_TYPES)


def __is_bool(value):
return isinstance(value, bool) or isinstance(value, np.bool_)


def __format_type(argtype):
if isinstance(argtype, str):
return argtype
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/build_tests/test_io_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,22 @@ def test_numeric_spec(self):
self.assertTupleEqual(ret, match)
self.assertIs(ret[0].dtype.type, match[1])

def test_bool_spec(self):
spec_type = 'bool'
spec = DatasetSpec('an example dataset', spec_type, name='data')

value = np.bool_(True)
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, np.bool_)
self.assertTupleEqual(ret, match)
self.assertIs(type(ret[0]), match[1])

value = True
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, np.bool_)
self.assertTupleEqual(ret, match)
self.assertIs(type(ret[0]), match[1])


if __name__ == '__main__':
unittest.main()
14 changes: 14 additions & 0 deletions tests/unit/utils_test/test_docval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from six import text_type
import numpy as np

from hdmf.utils import docval, fmt_docval_args, get_docval, popargs

Expand Down Expand Up @@ -478,6 +479,19 @@ def test_get_docval_none_arg(self):
with self.assertRaisesRegex(ValueError, r'Function __init__ has no docval arguments'):
get_docval(self.test_obj.__init__, 'arg3')

def test_bool_type(self):
@docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool'})
def method(self, **kwargs):
return popargs('arg1', kwargs)

res = method(self, arg1=True)
self.assertEqual(res, True)
self.assertIsInstance(res, bool)

res = method(self, arg1=np.bool_(True))
self.assertEqual(res, np.bool_(True))
self.assertIsInstance(res, np.bool_)


class TestDocValidatorChain(unittest.TestCase):

Expand Down

0 comments on commit ee1684a

Please sign in to comment.