From 0c971b58de488c3f8304d16240d9589fcbb8b902 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 26 Sep 2024 12:50:25 -0400 Subject: [PATCH 1/4] ENH: Better validation of info and subject_info --- doc/changes/bugfix.rst | 2 + mne/_fiff/meas_info.py | 292 ++++++++++++++++++------------ mne/_fiff/tests/test_meas_info.py | 20 +- mne/io/cnt/cnt.py | 1 + mne/io/persyst/persyst.py | 1 + 5 files changed, 197 insertions(+), 119 deletions(-) create mode 100644 doc/changes/bugfix.rst diff --git a/doc/changes/bugfix.rst b/doc/changes/bugfix.rst new file mode 100644 index 00000000000..c4fa57e9100 --- /dev/null +++ b/doc/changes/bugfix.rst @@ -0,0 +1,2 @@ +Fix bug where invalid data types (e.g., ``np.ndarray``s) could be used in some +:class:`mne.io.Info` fields like ``info["subject_info"]["weight"]``, by `Eric Larson`_. \ No newline at end of file diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 2759d2332a3..9a1b7122b2d 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -5,10 +5,12 @@ import contextlib import datetime import operator +import re import string from collections import Counter, OrderedDict from collections.abc import Mapping from copy import deepcopy +from functools import partial from io import BytesIO from textwrap import shorten @@ -305,6 +307,9 @@ def _unique_channel_names(ch_names, max_length=None, verbose=None): return ch_names +# %% Mixin classes + + class MontageMixin: """Mixin for Montage getting and setting.""" @@ -922,6 +927,152 @@ def get_channel_types(self, picks=None, unique=False, only_data_chs=False): return ch_types +# %% ValidatedDict class + + +class ValidatedDict(dict): + _attributes = {} # subclasses should set this to validated attributes + + def __init__(self, *args, **kwargs): + self._unlocked = True + super().__init__(*args, **kwargs) + self._unlocked = False + + def __getstate__(self): + """Get state (for pickling).""" + return {"_unlocked": self._unlocked} + + def __setstate__(self, state): + """Set state (for pickling).""" + self._unlocked = state["_unlocked"] + + def __setitem__(self, key, val): + """Attribute setter.""" + # During unpickling, the _unlocked attribute has not been set, so + # let __setstate__ do it later and act unlocked now + unlocked = getattr(self, "_unlocked", True) + if key in self._attributes: + if isinstance(self._attributes[key], str): + if not unlocked: + raise RuntimeError(self._attributes[key]) + else: + val = self._attributes[key]( + val, info=self + ) # attribute checker function + else: + class_name = self.__class__.__name__ + extra = "" + if "temp" in self._attributes: + var_name = _camel_to_snake(class_name) + extra = ( + f"You can set {var_name}['temp'] to store temporary objects in " + f"{class_name} instances, but these will not survive an I/O " + "round-trip." + ) + raise RuntimeError( + f"{class_name} does not support directly setting the key {repr(key)}. " + + extra + ) + super().__setitem__(key, val) + + def update(self, other=None, **kwargs): + """Update method using __setitem__().""" + iterable = other.items() if isinstance(other, Mapping) else other + if other is not None: + for key, val in iterable: + self[key] = val + for key, val in kwargs.items(): + self[key] = val + + def copy(self): + """Copy the instance. + + Returns + ------- + info : instance of Info + The copied info. + """ + return deepcopy(self) + + def __repr__(self): + """Return a string representation.""" + mapping = ", ".join(f"{key}: {val}" for key, val in self.items()) + return f"<{_camel_to_snake(self.__class__.__name__)} | {mapping}>" + + +# %% Subject info + + +def _check_types(x, *, info, name, types, cast=None): + _validate_type(x, types, name) + if cast is not None and x is not None: + x = cast(x) + return x + + +class SubjectInfo(ValidatedDict): + _attributes = { + "id": partial(_check_types, name='subject_info["id"]', types=int), + "his_id": partial(_check_types, name='subject_info["his_id"]', types=str), + "last_name": partial(_check_types, name='subject_info["last_name"]', types=str), + "first_name": partial( + _check_types, name='subject_info["first_name"]', types=str + ), + "middle_name": partial( + _check_types, name='subject_info["middle_name"]', types=str + ), + "birthday": partial( + _check_types, name='subject_info["birthday"]', types=(datetime.date, None) + ), + "sex": partial(_check_types, name='subject_info["sex"]', types=int), + "hand": partial(_check_types, name='subject_info["hand"]', types=int), + "weight": partial( + _check_types, name='subject_info["weight"]', types="numeric", cast=float + ), + "height": partial( + _check_types, name='subject_info["height"]', types="numeric", cast=float + ), + } + + def __init__(self, initial): + _validate_type(initial, dict, "subject_info") + super().__init__() + for key, val in initial.items(): + self[key] = val + + +class HeliumInfo(ValidatedDict): + _attributes = { + "he_level_raw": partial( + _check_types, + name='helium_info["he_level_raw"]', + types="numeric", + cast=float, + ), + "helium_level": partial( + _check_types, + name='helium_info["helium_level"]', + types="numeric", + cast=float, + ), + "orig_file_guid": partial( + _check_types, name='helium_info["orig_file_guid"]', types=str + ), + "meas_date": partial( + _check_types, name='helium_info["meas_date"]', types=datetime.datetime + ), + } + + def __init__(self, initial): + _validate_type(initial, dict, "subject_info") + super().__init__() + for key, val in initial.items(): + self[key] = val + + +# %% Info class and helpers + + def _format_trans(obj, key): from ..transforms import Transform @@ -993,11 +1144,6 @@ def _check_bads(bads, *, info): return MNEBadsList(bads=bads, info=info) -def _check_description(description, *, info): - _validate_type(description, (None, str), "info['description']") - return description - - def _check_dev_head_t(dev_head_t, *, info): from ..transforms import Transform, _ensure_trans @@ -1007,62 +1153,8 @@ def _check_dev_head_t(dev_head_t, *, info): return dev_head_t -def _check_experimenter(experimenter, *, info): - _validate_type(experimenter, (None, str), "experimenter") - return experimenter - - -def _check_line_freq(line_freq, *, info): - _validate_type(line_freq, (None, "numeric"), "line_freq") - line_freq = float(line_freq) if line_freq is not None else line_freq - return line_freq - - -def _check_subject_info(subject_info, *, info): - _validate_type(subject_info, (None, dict), "subject_info") - if isinstance(subject_info, dict): - if "birthday" in subject_info: - _validate_type( - subject_info["birthday"], - (datetime.date, None), - "subject_info['birthday']", - ) - return subject_info - - -def _check_device_info(device_info, *, info): - _validate_type( - device_info, - ( - None, - dict, - ), - "device_info", - ) - return device_info - - -def _check_helium_info(helium_info, *, info): - _validate_type( - helium_info, - ( - None, - dict, - ), - "helium_info", - ) - if isinstance(helium_info, dict): - if "meas_date" in helium_info: - _validate_type( - helium_info["meas_date"], - datetime.datetime, - "helium_info['meas_date']", - ) - return helium_info - - # TODO: Add fNIRS convention to loc -class Info(dict, SetChannelsMixin, MontageMixin, ContainsMixin): +class Info(ValidatedDict, SetChannelsMixin, MontageMixin, ContainsMixin): """Measurement information. This data structure behaves like a dictionary. It contains all metadata @@ -1502,24 +1594,28 @@ class Info(dict, SetChannelsMixin, MontageMixin, ContainsMixin): "custom_ref_applied": "custom_ref_applied cannot be set directly. " "Please use method inst.set_eeg_reference() " "instead.", - "description": _check_description, + "description": partial(_check_types, name="description", types=(str, None)), "dev_ctf_t": "dev_ctf_t cannot be set directly.", "dev_head_t": _check_dev_head_t, - "device_info": _check_device_info, + "device_info": partial(_check_types, name="device_info", types=(dict, None)), "dig": "dig cannot be set directly. " "Please use method inst.set_montage() instead.", "events": "events cannot be set directly.", - "experimenter": _check_experimenter, + "experimenter": partial(_check_types, name="experimenter", types=(str, None)), "file_id": "file_id cannot be set directly.", "gantry_angle": "gantry_angle cannot be set directly.", - "helium_info": _check_helium_info, + "helium_info": partial( + _check_types, name="helium_info", types=(dict, None), cast=HeliumInfo + ), "highpass": "highpass cannot be set directly. " "Please use method inst.filter() instead.", "hpi_meas": "hpi_meas can not be set directly.", "hpi_results": "hpi_results cannot be set directly.", "hpi_subsystem": "hpi_subsystem cannot be set directly.", "kit_system_id": "kit_system_id cannot be set directly.", - "line_freq": _check_line_freq, + "line_freq": partial( + _check_types, name="line_freq", types=("numeric", None), cast=float + ), "lowpass": "lowpass cannot be set directly. " "Please use method inst.filter() instead.", "maxshield": "maxshield cannot be set directly.", @@ -1541,7 +1637,9 @@ class Info(dict, SetChannelsMixin, MontageMixin, ContainsMixin): "instead.", "sfreq": "sfreq cannot be set directly. " "Please use method inst.resample() instead.", - "subject_info": _check_subject_info, + "subject_info": partial( + _check_types, name="subject_info", types=(dict, None), cast=SubjectInfo + ), "temp": lambda x, info=None: x, "utc_offset": "utc_offset cannot be set directly.", "working_dir": "working_dir cannot be set directly.", @@ -1549,8 +1647,8 @@ class Info(dict, SetChannelsMixin, MontageMixin, ContainsMixin): } def __init__(self, *args, **kwargs): - self._unlocked = True super().__init__(*args, **kwargs) + self._unlocked = True # Deal with h5io writing things as dict if "bads" in self: self["bads"] = MNEBadsList(bads=self["bads"], info=self) @@ -1579,46 +1677,16 @@ def __init__(self, *args, **kwargs): else: self["meas_date"] = _ensure_meas_date_none_or_dt(meas_date) self._unlocked = False - - def __getstate__(self): - """Get state (for pickling).""" - return {"_unlocked": self._unlocked} + # with validation and casting + for key in ("helium_info", "subject_info"): + if key in self: + self[key] = self[key] def __setstate__(self, state): """Set state (for pickling).""" - self._unlocked = state["_unlocked"] + super().__setstate__(state) self["bads"] = MNEBadsList(bads=self["bads"], info=self) - def __setitem__(self, key, val): - """Attribute setter.""" - # During unpickling, the _unlocked attribute has not been set, so - # let __setstate__ do it later and act unlocked now - unlocked = getattr(self, "_unlocked", True) - if key in self._attributes: - if isinstance(self._attributes[key], str): - if not unlocked: - raise RuntimeError(self._attributes[key]) - else: - val = self._attributes[key]( - val, info=self - ) # attribute checker function - else: - raise RuntimeError( - f"Info does not support directly setting the key {repr(key)}. " - "You can set info['temp'] to store temporary objects in an " - "Info instance, but these will not survive an I/O round-trip." - ) - super().__setitem__(key, val) - - def update(self, other=None, **kwargs): - """Update method using __setitem__().""" - iterable = other.items() if isinstance(other, Mapping) else other - if other is not None: - for key, val in iterable: - self[key] = val - for key, val in kwargs.items(): - self[key] = val - @contextlib.contextmanager def _unlock(self, *, update_redundant=False, check_after=False): """Context manager unlocking access to attributes.""" @@ -1638,16 +1706,6 @@ def _unlock(self, *, update_redundant=False, check_after=False): finally: self._unlocked = state - def copy(self): - """Copy the instance. - - Returns - ------- - info : instance of Info - The copied info. - """ - return deepcopy(self) - def normalize_proj(self): """(Re-)Normalize projection vectors after subselection. @@ -1738,6 +1796,8 @@ def __repr__(self): entr = str(bool(v)) if not v: non_empty -= 1 # don't count if 0 + elif isinstance(v, ValidatedDict): + entr = repr(v) else: try: this_len = len(v) @@ -2378,10 +2438,10 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): si["hand"] = int(tag.data.item()) elif kind == FIFF.FIFF_SUBJ_WEIGHT: tag = read_tag(fid, pos) - si["weight"] = tag.data + si["weight"] = float(tag.data.item()) elif kind == FIFF.FIFF_SUBJ_HEIGHT: tag = read_tag(fid, pos) - si["height"] = tag.data + si["height"] = float(tag.data.item()) info["subject_info"] = si del si @@ -3697,3 +3757,7 @@ def _get_fnirs_ch_pos(info): for optode in [*srcs, *dets]: ch_pos[optode] = _optode_position(info, optode) return ch_pos + + +def _camel_to_snake(s): + return re.sub(r"(? Date: Thu, 26 Sep 2024 14:50:40 -0400 Subject: [PATCH 2/4] Update mne/_fiff/meas_info.py Co-authored-by: Daniel McCloy --- mne/_fiff/meas_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 9a1b7122b2d..c881822e44a 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -1064,7 +1064,7 @@ class HeliumInfo(ValidatedDict): } def __init__(self, initial): - _validate_type(initial, dict, "subject_info") + _validate_type(initial, dict, "helium_info") super().__init__() for key, val in initial.items(): self[key] = val From 5ad7ceefaaa14ad8c4d25bb6ed31ad5006a2f21c Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 26 Sep 2024 14:50:52 -0400 Subject: [PATCH 3/4] FIX: Doc --- doc/changes/{ => devel}/bugfix.rst | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename doc/changes/{ => devel}/bugfix.rst (100%) diff --git a/doc/changes/bugfix.rst b/doc/changes/devel/bugfix.rst similarity index 100% rename from doc/changes/bugfix.rst rename to doc/changes/devel/bugfix.rst From 5af21c3fb1c09e0247ebd452efc53045b244716f Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Thu, 26 Sep 2024 18:51:22 +0000 Subject: [PATCH 4/4] [autofix.ci] apply automated fixes --- doc/changes/devel/{bugfix.rst => 12875.bugfix.rst} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename doc/changes/devel/{bugfix.rst => 12875.bugfix.rst} (100%) diff --git a/doc/changes/devel/bugfix.rst b/doc/changes/devel/12875.bugfix.rst similarity index 100% rename from doc/changes/devel/bugfix.rst rename to doc/changes/devel/12875.bugfix.rst