Skip to content

Commit

Permalink
Handle extra or duplicate values passed to docval (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
rly authored Nov 18, 2019
1 parent 9f9edf8 commit 8ea0416
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 42 deletions.
4 changes: 2 additions & 2 deletions src/hdmf/build/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,9 @@ def __set_builder(self, builder, obj_type):
returns='the DatasetBuilder object for the dataset', rtype='DatasetBuilder')
def add_dataset(self, **kwargs):
''' Create a dataset and add it to this group '''
kwargs['parent'] = self
kwargs['source'] = self.source
pargs, pkwargs = fmt_docval_args(DatasetBuilder.__init__, kwargs)
pkwargs['parent'] = self
pkwargs['source'] = self.source
builder = DatasetBuilder(*pargs, **pkwargs)
self.set_dataset(builder)
return builder
Expand Down
12 changes: 7 additions & 5 deletions src/hdmf/build/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datetime import datetime
from six import with_metaclass, raise_from, text_type, binary_type, integer_types

from ..utils import docval, getargs, ExtenderMeta, get_docval, fmt_docval_args, call_docval_func
from ..utils import docval, getargs, ExtenderMeta, get_docval, call_docval_func, fmt_docval_args
from ..container import AbstractContainer, Container, Data, DataRegion
from ..spec import Spec, AttributeSpec, DatasetSpec, GroupSpec, LinkSpec, NAME_WILDCARD, NamespaceCatalog, RefSpec,\
SpecReader
Expand Down Expand Up @@ -1448,15 +1448,17 @@ def __get_cls_dict(self, base, addl_fields, name=None, default_name=None):
fields.append({'name': f, 'child': True})
else:
fields.append(f)
if name is not None:

if name is not None: # fixed name is specified in spec, remove it from docval args
docval_args = filter(lambda x: x['name'] != 'name', docval_args)

@docval(*docval_args)
def __init__(self, **kwargs):
pargs, pkwargs = fmt_docval_args(base.__init__, kwargs)
if name is not None:
pkwargs.update(name=name)
base.__init__(self, *pargs, **pkwargs)
kwargs.update(name=name)
pargs, pkwargs = fmt_docval_args(base.__init__, kwargs)
base.__init__(self, *pargs, **pkwargs) # special case: need to pass self to __init__

for f in new_args:
arg_val = kwargs.get(f, None)
if arg_val is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/hdmf/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def parent(self, parent_container):
parent_container.__children.append(self)
parent_container.set_modified()
else:
self.__parent.add_candidate(parent_container, self)
self.__parent.add_candidate(parent_container)
else:
self.__parent = parent_container
if isinstance(parent_container, Container):
Expand Down
5 changes: 2 additions & 3 deletions src/hdmf/monitor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
import six

from .utils import docval, getargs, fmt_docval_args
from .utils import docval, getargs, call_docval_func
from .data_utils import AbstractDataChunkIterator, DataChunkIterator, DataChunk


Expand Down Expand Up @@ -62,8 +62,7 @@ def compute_final_result(self, **kwargs):
class NumSampleCounter(DataChunkProcessor):

def __init__(self, **kwargs):
args, kwargs = fmt_docval_args(DataChunkProcessor.__init__, kwargs)
super(NumSampleCounter, self).__init__(*args, **kwargs)
call_docval_func(super(NumSampleCounter, self).__init__, kwargs)
self.__sample_count = 0

@docval({'name': 'data_chunk', 'type': DataChunk, 'doc': 'a chunk to process'})
Expand Down
73 changes: 44 additions & 29 deletions src/hdmf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True,
:param enforce_type: Boolean indicating whether the type of arguments should be enforced
:param enforce_shape: Boolean indicating whether the dimensions of array arguments
should be enforced if possible.
:param allow_extra: Boolean indicating whether extra keyword arguments are allowed (if False and extra keyword
arguments are specified, then an error is raised).
:return: Dict with:
* 'args' : Dict all arguments where keys are the names and values are the values of the arguments.
Expand All @@ -145,22 +147,36 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True,
if duplicated:
raise ValueError('The following names are duplicated: {}'.format(duplicated))
try:
if allow_extra: # extra keyword arguments are allowed so do not consider them when checking number of args
# verify only that the number of positional args is <= number of docval specified args
if len(args) > len(validator):
raise TypeError('Expected at most %s arguments, got %s' % (len(validator), len(args)))
else: # verify that the number of positional args + keyword args is <= number of docval specified args
if (len(args) + len(kwargs)) > len(validator):
raise TypeError('Expected at most %s arguments, got %s' % (len(validator), len(args) + len(kwargs)))

# iterate through the docval specification and find a matching value in args / kwargs
it = iter(validator)
arg = next(it)

# catch unsupported keys
allowable_terms = ('name', 'doc', 'type', 'shape', 'default', 'help')
unsupported_terms = set(arg.keys()) - set(allowable_terms)
if unsupported_terms:
raise ValueError('docval for {}: {} are not supported by docval'.format(arg['name'],
list(unsupported_terms)))
# process positional arguments
# process positional arguments of the docval specification (no default value)
while True:
#
if 'default' in arg:
break
argname = arg['name']
argval_set = False
if argname in kwargs:
# if this positional arg is specified by a keyword arg and there are remaining positional args that
# have not yet been matched, then it is undetermined what those positional args match to. thus, raise
# an error
if argsi < len(args):
type_errors.append("got multiple values for argument '%s'" % argname)
argval = kwargs.get(argname)
extras.pop(argname, None)
argval_set = True
Expand All @@ -171,36 +187,35 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True,
if not argval_set:
type_errors.append("missing argument '%s'" % argname)
else:
if argname in ret:
type_errors.append("'got multiple arguments for '%s" % argname)
else:
if enforce_type:
if not __type_okay(argval, arg['type']):
if argval is None:
fmt_val = (argname, __format_type(arg['type']))
type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val)
else:
fmt_val = (argname, type(argval).__name__, __format_type(arg['type']))
type_errors.append("incorrect type for '%s' (got '%s', expected '%s')" % fmt_val)
if enforce_shape and 'shape' in arg:
if enforce_type:
if not __type_okay(argval, arg['type']):
if argval is None:
fmt_val = (argname, __format_type(arg['type']))
type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val)
else:
fmt_val = (argname, type(argval).__name__, __format_type(arg['type']))
type_errors.append("incorrect type for '%s' (got '%s', expected '%s')" % fmt_val)
if enforce_shape and 'shape' in arg:
valshape = get_data_shape(argval)
while valshape is None:
if argval is None:
break
if not hasattr(argval, argname):
fmt_val = (argval, argname, arg['shape'])
value_errors.append("cannot check shape of object '%s' for argument '%s' "
"(expected shape '%s')" % fmt_val)
break
# unpack, e.g. if TimeSeries is passed for arg 'data', then TimeSeries.data is checked
argval = getattr(argval, argname)
valshape = get_data_shape(argval)
while valshape is None:
if argval is None:
break
if not hasattr(argval, argname):
fmt_val = (argval, argname, arg['shape'])
value_errors.append("cannot check shape of object '%s' for argument '%s' "
"(expected shape '%s')" % fmt_val)
break
# unpack, e.g. if TimeSeries is passed for arg 'data', then TimeSeries.data is checked
argval = getattr(argval, argname)
valshape = get_data_shape(argval)
if valshape is not None and not __shape_okay_multi(argval, arg['shape']):
fmt_val = (argname, valshape, arg['shape'])
value_errors.append("incorrect shape for '%s' (got '%s', expected '%s')" % fmt_val)
ret[argname] = argval
if valshape is not None and not __shape_okay_multi(argval, arg['shape']):
fmt_val = (argname, valshape, arg['shape'])
value_errors.append("incorrect shape for '%s' (got '%s', expected '%s')" % fmt_val)
ret[argname] = argval
argsi += 1
arg = next(it)

# process arguments of the docval specification with a default value
while True:
argname = arg['name']
if argname in kwargs:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_io_hdf5_h5tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,7 @@ def setUp(self):
self.path = get_temp_filepath()
foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14)
bucket1 = FooBucket('test_bucket1', [foo1])
self.foofile1 = FooFile('test_foofile1', buckets=[bucket1])
self.foofile1 = FooFile(buckets=[bucket1])

with HDF5IO(self.path, manager=_get_manager(), mode='w') as temp_io:
temp_io.write(self.foofile1)
Expand Down Expand Up @@ -1069,7 +1069,7 @@ class HDF5IOWriteNoFile(unittest.TestCase):
def setUp(self):
foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14)
bucket1 = FooBucket('test_bucket1', [foo1])
self.foofile1 = FooFile('test_foofile1', buckets=[bucket1])
self.foofile1 = FooFile(buckets=[bucket1])
self.path = 'test_write_nofile.h5'

def tearDown(self):
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/utils_test/test_docval.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def basic_add2_kw(self, **kwargs):
def basic_only_kw(self, **kwargs):
return kwargs

@docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'},
{'name': 'arg2', 'type': 'int', 'doc': 'argument2 is a int'},
{'name': 'arg3', 'type': bool, 'doc': 'argument3 is a bool. it defaults to False', 'default': False},
allow_extra=True)
def basic_add2_kw_allow_extra(self, **kwargs):
return kwargs


class MyTestSubclass(MyTestClass):

Expand Down Expand Up @@ -350,6 +357,57 @@ def test_extra_kwarg(self):
with self.assertRaises(TypeError):
self.test_obj.basic_add2_kw('a string', 100, bar=1000)

def test_extra_args_pos_only(self):
"""Test that docval raises an error if too many positional
arguments are specified
"""
with self.assertRaisesRegex(TypeError, r'Expected at most 3 arguments, got 4'):
self.test_obj.basic_add2_kw('a string', 100, True, 'extra')

def test_extra_args_pos_kw(self):
"""Test that docval raises an error if too many positional
arguments are specified and a keyword arg is specified
"""
with self.assertRaisesRegex(TypeError, r'Expected at most 3 arguments, got 4'):
self.test_obj.basic_add2_kw('a string', 'extra', 100, arg3=True)

def test_extra_kwargs_pos_kw(self):
"""Test that docval raises an error if extra keyword
arguments are specified
"""
with self.assertRaisesRegex(TypeError, r'Expected at most 3 arguments, got 4'):
self.test_obj.basic_add2_kw('a string', 100, extra='extra', arg3=True)

def test_extra_args_pos_only_ok(self):
"""Test that docval raises an error if too many positional
arguments are specified even if allow_extra is True
"""
with self.assertRaisesRegex(TypeError, r'Expected at most 3 arguments, got 4'):
self.test_obj.basic_add2_kw_allow_extra('a string', 100, True, 'extra', extra='extra')

def test_extra_args_pos_kw_ok(self):
"""Test that docval does not raise an error if too many
keyword arguments are specified and allow_extra is True
"""
kwargs = self.test_obj.basic_add2_kw_allow_extra('a string', 100, True, extra='extra')
self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': True, 'extra': 'extra'})

def test_dup_kw(self):
"""Test that docval raises an error if a keyword argument
captures a positional argument before all positional
arguments have been resolved
"""
with self.assertRaisesRegex(TypeError, r"got multiple values for argument 'arg1'"):
self.test_obj.basic_add2_kw('a string', 100, arg1='extra')

def test_extra_args_dup_kw(self):
"""Test that docval raises an error if a keyword argument
captures a positional argument before all positional
arguments have been resolved and allow_extra is True
"""
with self.assertRaisesRegex(TypeError, r"got multiple values for argument 'arg1'"):
self.test_obj.basic_add2_kw_allow_extra('a string', 100, True, arg1='extra')

def test_unsupported_docval_term(self):
"""Test that docval does not allow setting of arguments
marked as unsupported
Expand Down

0 comments on commit 8ea0416

Please sign in to comment.