Skip to content

Commit

Permalink
WIP: Fix Trodes Video
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Aug 21, 2023
1 parent bb7580c commit 34134db
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 20 deletions.
14 changes: 12 additions & 2 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,13 @@ class VideoFile(dj.Imported):
"""

def make(self, key):
self._no_transaction_make(key)

def _no_transaction_make(self, key, verbose=True):
if not self.connection.in_transaction:
self.populate(key)
return

nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
nwbf = get_nwb_file(nwb_file_abspath)
Expand Down Expand Up @@ -340,8 +347,11 @@ def make(self, key):
self.insert1(key)
is_found = True

if not is_found:
print(f"No video found corresponding to epoch {interval_list_name}")
if not is_found and verbose:
print(
f"No video found corresponding to file {nwb_file_name}, "
+ f"epoch {interval_list_name}"
)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs)
Expand Down
13 changes: 9 additions & 4 deletions src/spyglass/position/v1/dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,15 @@ def get_video_path(key):

from ...common.common_behav import VideoFile

video_info = (
VideoFile()
& {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]}
).fetch1()
vf_key = {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]}
VideoFile()._no_transaction_make(vf_key, verbose=False)
video_query = VideoFile & vf_key

if len(video_query) != 1:
print(f"Found {len(video_query)} videos for {vf_key}")
return None, None, None, None

video_info = video_query.fetch1()
nwb_path = f"{raw_dir}/{video_info['nwb_file_name']}"

with pynwb.NWBHDF5IO(path=nwb_path, mode="r") as in_out:
Expand Down
68 changes: 54 additions & 14 deletions src/spyglass/position/v1/position_trodes_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ class TrodesPosParams(dj.Manual):
params: longblob
"""

@classmethod
def insert_default(cls, **kwargs):
"""
Insert default parameter set for position determination
"""
params = {
@property
def default_pk(self):
return {"trodes_pos_params_name": "default"}

@property
def default_params(self):
return {
"max_separation": 9.0,
"max_speed": 300.0,
"position_smoothing_duration": 0.125,
Expand All @@ -57,24 +58,29 @@ def insert_default(cls, **kwargs):
"upsampling_sampling_rate": None,
"upsampling_interpolation_method": "linear",
}

@classmethod
def insert_default(cls, **kwargs):
"""
Insert default parameter set for position determination
"""
cls.insert1(
{"trodes_pos_params_name": "default", "params": params},
{**cls().default_pk, "params": cls().default_params},
skip_duplicates=True,
)

@classmethod
def get_default(cls):
query = cls & {"trodes_pos_params_name": "default"}
query = cls & cls().default_pk
if not len(query) > 0:
cls().insert_default(skip_duplicates=True)
return (cls & {"trodes_pos_params_name": "default"}).fetch1()
return (cls & cls().default_pk).fetch1()

return query.fetch1()

@classmethod
def get_accepted_params(cls):
default = cls.get_default()
return list(default["params"].keys())
return [k for k in cls().default_params.keys()]


@schema
Expand All @@ -91,16 +97,27 @@ class TrodesPosSelection(dj.Manual):

@classmethod
def insert_with_default(
cls, key: dict, skip_duplicates: bool = False
cls,
key: dict,
skip_duplicates: bool = False,
edit_defaults: dict = {},
edit_name: str = None,
) -> None:
"""Insert key with default parameters.
To change defaults, supply a dict as edit_defaults with a name for
the new paramset as edit_name.
Parameters
----------
key: Union[dict, str]
Restriction uniquely identifying entr(y/ies) in RawPosition.
skip_duplicates: bool, optional
Skip duplicate entries.
edit_defauts: dict, optional
Dictionary of overrides to default parameters.
edit_name: str, optional
If edit_defauts is passed, the name of the new entry
Raises
------
Expand All @@ -111,11 +128,26 @@ def insert_with_default(
if not query:
raise ValueError(f"Found no entries found for {key}")

_ = TrodesPosParams.get_default()
param_pk, param_name = list(TrodesPosParams().default_pk.items())[0]

if bool(edit_defaults) ^ bool(edit_name): # XOR: only one of them
raise ValueError("Must specify both edit_defauts and edit_name")

elif edit_defaults and edit_name:
TrodesPosParams.insert1(
{
param_pk: edit_name,
"params": {
**TrodesPosParams().default_params,
**edit_defaults,
},
},
skip_duplicates=skip_duplicates,
)

cls.insert(
[
dict(**k, trodes_pos_params_name="default")
{**k, param_pk: edit_name or param_name}
for k in query.fetch("KEY", as_dict=True)
],
skip_duplicates=skip_duplicates,
Expand Down Expand Up @@ -487,6 +519,8 @@ class TrodesPosVideo(dj.Computed):

definition = """
-> TrodesPosV1
---
has_video : bool
"""

def make(self, key):
Expand Down Expand Up @@ -520,6 +554,11 @@ def make(self, key):
) = get_video_path(
{"nwb_file_name": key["nwb_file_name"], "epoch": epoch}
)

if not video_path:
self.insert1(dict(**key, has_video=False))
return

video_dir = os.path.dirname(video_path) + "/"
video_path = check_videofile(
video_path=video_dir, video_filename=video_filename
Expand Down Expand Up @@ -553,6 +592,7 @@ def make(self, key):
cm_to_pixels=cm_per_pixel,
disable_progressbar=False,
)
self.insert1(dict(**key, has_video=True))

@staticmethod
def convert_to_pixels(data, frame_size, cm_to_pixels=1.0):
Expand Down

0 comments on commit 34134db

Please sign in to comment.