Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "array-like" to _validate_type() #11713

Merged
merged 6 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ Bugs

API changes
~~~~~~~~~~~
- None yet
- The ``baseline`` argument can now be array-like (e.g. ``list``, ``tuple``, ``np.ndarray``, ...) instead of only a ``tuple`` (:gh:`11713` by `Clemens Brunner`_)
30 changes: 15 additions & 15 deletions mne/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from .utils import logger, verbose, _check_option
from .utils import logger, verbose, _check_option, _validate_type


def _log_rescale(baseline, mode="mean"):
Expand Down Expand Up @@ -143,42 +143,42 @@ def fun(d, m):


def _check_baseline(baseline, times, sfreq, on_baseline_outside_data="raise"):
"""Check if the baseline is valid, and adjust it if requested.
"""Check if the baseline is valid and adjust it if requested.

``None`` values inside the baseline parameter will be replaced with
``times[0]`` and ``times[-1]``.
``None`` values inside ``baseline`` will be replaced with ``times[0]`` and
``times[-1]``.

Parameters
----------
baseline : tuple | None
baseline : array-like, shape (2,) | None
Beginning and end of the baseline period, in seconds. If ``None``,
assume no baseline and return immediately.
times : array
The time points.
sfreq : float
The sampling rate.
on_baseline_outside_data : 'raise' | 'info' | 'adjust'
What do do if the baseline period exceeds the data.
What to do if the baseline period exceeds the data.
If ``'raise'``, raise an exception (default).
If ``'info'``, log an info message.
If ``'adjust'``, adjust the baseline such that it's within the data
range again.
If ``'adjust'``, adjust the baseline such that it is within the data range.

Returns
-------
(baseline_tmin, baseline_tmax) | None
The baseline with ``None`` values replaced with times, and with
adjusted times if ``on_baseline_outside_data='adjust'``; or ``None``
if the ``baseline`` parameter is ``None``.

The baseline with ``None`` values replaced with times, and with adjusted times
if ``on_baseline_outside_data='adjust'``; or ``None``, if ``baseline`` is
``None``.
"""
if baseline is None:
return None

if not isinstance(baseline, tuple) or len(baseline) != 2:
_validate_type(baseline, "array-like")
baseline = tuple(baseline)

if len(baseline) != 2:
raise ValueError(
f"`baseline={baseline}` is an invalid argument, must "
f"be a tuple of length 2 or None"
f"baseline must have exactly two elements (got {len(baseline)})."
)

tmin, tmax = times[0], times[-1]
Expand Down
2 changes: 1 addition & 1 deletion mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,7 @@ def test_epochs_io_preload(tmp_path, preload):
epochs_no_bl.save(temp_fname_no_bl, overwrite=True)
epochs_read = read_epochs(temp_fname)
epochs_no_bl_read = read_epochs(temp_fname_no_bl)
with pytest.raises(ValueError, match="invalid"):
with pytest.raises(ValueError, match="exactly two elements"):
epochs.apply_baseline(baseline=[1, 2, 3])
epochs_with_bl = epochs_no_bl_read.copy().apply_baseline(baseline)
assert isinstance(epochs_with_bl, BaseEpochs)
Expand Down
7 changes: 4 additions & 3 deletions mne/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# License: BSD-3-Clause

from builtins import input # no-op here but facilitates testing
from collections.abc import Sequence
from difflib import get_close_matches
from importlib import import_module
import operator
Expand Down Expand Up @@ -525,6 +526,7 @@ def __instancecheck__(cls, other):
"path-like": path_like,
"int-like": (int_like,),
"callable": (_Callable(),),
"array-like": (Sequence, np.ndarray),
hoechenberger marked this conversation as resolved.
Show resolved Hide resolved
}


Expand All @@ -538,9 +540,8 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra=""
types : type | str | tuple of types | tuple of str
The types to be checked against.
If str, must be one of {'int', 'int-like', 'str', 'numeric', 'info',
'path-like', 'callable'}.
If a tuple of str is passed, use 'int-like' and not 'int' for
integers.
'path-like', 'callable', 'array-like'}.
If a tuple of str is passed, use 'int-like' and not 'int' for integers.
item_name : str | None
Name of the item to show inside the error message.
type_name : str | None
Expand Down