From 34134db8be6dd9a494026758836503812654a3ea Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Mon, 21 Aug 2023 10:26:49 -0700 Subject: [PATCH] WIP: Fix Trodes Video --- src/spyglass/common/common_behav.py | 14 +++- src/spyglass/position/v1/dlc_utils.py | 13 ++-- .../position/v1/position_trodes_position.py | 68 +++++++++++++++---- 3 files changed, 75 insertions(+), 20 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 72ea78adc..c775069b0 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -295,6 +295,13 @@ class VideoFile(dj.Imported): """ def make(self, key): + self._no_transaction_make(key) + + def _no_transaction_make(self, key, verbose=True): + if not self.connection.in_transaction: + self.populate(key) + return + nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) @@ -340,8 +347,11 @@ def make(self, key): self.insert1(key) is_found = True - if not is_found: - print(f"No video found corresponding to epoch {interval_list_name}") + if not is_found and verbose: + print( + f"No video found corresponding to file {nwb_file_name}, " + + f"epoch {interval_list_name}" + ) def fetch_nwb(self, *attrs, **kwargs): return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py index eb2cf2d1b..ce3e3b3fc 100644 --- a/src/spyglass/position/v1/dlc_utils.py +++ b/src/spyglass/position/v1/dlc_utils.py @@ -323,10 +323,15 @@ def get_video_path(key): from ...common.common_behav import VideoFile - video_info = ( - VideoFile() - & {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]} - ).fetch1() + vf_key = {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]} + VideoFile()._no_transaction_make(vf_key, verbose=False) + video_query = VideoFile & vf_key + + if len(video_query) != 1: + print(f"Found {len(video_query)} videos for {vf_key}") + return None, None, None, None + + video_info = video_query.fetch1() nwb_path = f"{raw_dir}/{video_info['nwb_file_name']}" with pynwb.NWBHDF5IO(path=nwb_path, mode="r") as in_out: diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index e90e8a17a..da7456232 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -41,12 +41,13 @@ class TrodesPosParams(dj.Manual): params: longblob """ - @classmethod - def insert_default(cls, **kwargs): - """ - Insert default parameter set for position determination - """ - params = { + @property + def default_pk(self): + return {"trodes_pos_params_name": "default"} + + @property + def default_params(self): + return { "max_separation": 9.0, "max_speed": 300.0, "position_smoothing_duration": 0.125, @@ -57,24 +58,29 @@ def insert_default(cls, **kwargs): "upsampling_sampling_rate": None, "upsampling_interpolation_method": "linear", } + + @classmethod + def insert_default(cls, **kwargs): + """ + Insert default parameter set for position determination + """ cls.insert1( - {"trodes_pos_params_name": "default", "params": params}, + {**cls().default_pk, "params": cls().default_params}, skip_duplicates=True, ) @classmethod def get_default(cls): - query = cls & {"trodes_pos_params_name": "default"} + query = cls & cls().default_pk if not len(query) > 0: cls().insert_default(skip_duplicates=True) - return (cls & {"trodes_pos_params_name": "default"}).fetch1() + return (cls & cls().default_pk).fetch1() return query.fetch1() @classmethod def get_accepted_params(cls): - default = cls.get_default() - return list(default["params"].keys()) + return [k for k in cls().default_params.keys()] @schema @@ -91,16 +97,27 @@ class TrodesPosSelection(dj.Manual): @classmethod def insert_with_default( - cls, key: dict, skip_duplicates: bool = False + cls, + key: dict, + skip_duplicates: bool = False, + edit_defaults: dict = {}, + edit_name: str = None, ) -> None: """Insert key with default parameters. + To change defaults, supply a dict as edit_defaults with a name for + the new paramset as edit_name. + Parameters ---------- key: Union[dict, str] Restriction uniquely identifying entr(y/ies) in RawPosition. skip_duplicates: bool, optional Skip duplicate entries. + edit_defauts: dict, optional + Dictionary of overrides to default parameters. + edit_name: str, optional + If edit_defauts is passed, the name of the new entry Raises ------ @@ -111,11 +128,26 @@ def insert_with_default( if not query: raise ValueError(f"Found no entries found for {key}") - _ = TrodesPosParams.get_default() + param_pk, param_name = list(TrodesPosParams().default_pk.items())[0] + + if bool(edit_defaults) ^ bool(edit_name): # XOR: only one of them + raise ValueError("Must specify both edit_defauts and edit_name") + + elif edit_defaults and edit_name: + TrodesPosParams.insert1( + { + param_pk: edit_name, + "params": { + **TrodesPosParams().default_params, + **edit_defaults, + }, + }, + skip_duplicates=skip_duplicates, + ) cls.insert( [ - dict(**k, trodes_pos_params_name="default") + {**k, param_pk: edit_name or param_name} for k in query.fetch("KEY", as_dict=True) ], skip_duplicates=skip_duplicates, @@ -487,6 +519,8 @@ class TrodesPosVideo(dj.Computed): definition = """ -> TrodesPosV1 + --- + has_video : bool """ def make(self, key): @@ -520,6 +554,11 @@ def make(self, key): ) = get_video_path( {"nwb_file_name": key["nwb_file_name"], "epoch": epoch} ) + + if not video_path: + self.insert1(dict(**key, has_video=False)) + return + video_dir = os.path.dirname(video_path) + "/" video_path = check_videofile( video_path=video_dir, video_filename=video_filename @@ -553,6 +592,7 @@ def make(self, key): cm_to_pixels=cm_per_pixel, disable_progressbar=False, ) + self.insert1(dict(**key, has_video=True)) @staticmethod def convert_to_pixels(data, frame_size, cm_to_pixels=1.0):