Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* Fix LorenFrankLab#1098

* Update changelog
  • Loading branch information
CBroz1 authored Sep 11, 2024
1 parent d9115f6 commit a7b13fa
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ TrackGraph.alter() # Comment regarding the change
- Linearization
- Add edge_map parameter to LinearizedPositionV1 #1091

- Position

- Restore #973, allow DLC without position tracking #1100

- Spike Sorting

- Fix bug in `get_group_by_shank` #1096
Expand Down
9 changes: 6 additions & 3 deletions src/spyglass/position/v1/position_dlc_centroid.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,19 +254,22 @@ def _logged_make(self, key):
total_nan = np.sum(final_df.loc[:, idx[("x", "y")]].isna().any(axis=1))

logger.info(f"total NaNs in centroid dataset: {total_nan}")
spatial_series = (RawPosition() & key).fetch_nwb()[0]["raw_position"]
position = pynwb.behavior.Position()
velocity = pynwb.behavior.BehavioralTimeSeries()
if query := (RawPosition() & key):
spatial_series = query.fetch_nwb()[0]["raw_position"]
else:
spatial_series = None

common_attrs = {
"conversion": METERS_PER_CM,
"comments": spatial_series.comments,
"comments": getattr(spatial_series, "comments", ""),
}
position.create_spatial_series(
name="position",
timestamps=final_df.index.to_numpy(),
data=final_df.loc[:, idx[("x", "y")]].to_numpy(),
reference_frame=spatial_series.reference_frame,
reference_frame=getattr(spatial_series, "reference_frame", ""),
description="x_position, y_position",
**common_attrs,
)
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/position/v1/position_dlc_orient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
red_led_bisector_orientation,
two_pt_head_orientation,
)
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils import SpyglassMixin

from .position_dlc_cohort import DLCSmoothInterpCohort

Expand Down
25 changes: 15 additions & 10 deletions src/spyglass/position/v1/position_dlc_pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,17 +258,20 @@ def _logged_make(self, key):
populate_missing=False,
)
)
spatial_series = (
RawPosition() & {**key, "interval_list_name": interval_list_name}
).fetch_nwb()[0]["raw_position"]
_, _, _, video_time = get_video_info(key)
pos_time = spatial_series.timestamps
if interval_list_name:
spatial_series = (
RawPosition()
& {**key, "interval_list_name": interval_list_name}
).fetch_nwb()[0]["raw_position"]
else:
spatial_series = None

_, _, meters_per_pixel, video_time = get_video_info(key)
key["meters_per_pixel"] = meters_per_pixel

# TODO: should get timestamps from VideoFile, but need the
# video_frame_ind from RawPosition, which also has timestamps

key["meters_per_pixel"] = spatial_series.conversion

# Insert entry into DLCPoseEstimation
logger.info(
"Inserting %s, epoch %02d into DLCPoseEsimation",
Expand Down Expand Up @@ -296,7 +299,9 @@ def _logged_make(self, key):
part_df = convert_to_cm(part_df, meters_per_pixel)
logger.info("adding timestamps to DataFrame")
part_df = add_timestamps(
part_df, pos_time=pos_time, video_time=video_time
part_df,
pos_time=getattr(spatial_series, "timestamps", video_time),
video_time=video_time,
)
key["bodypart"] = body_part
key["analysis_file_name"] = AnalysisNwbfile().create(
Expand All @@ -309,8 +314,8 @@ def _logged_make(self, key):
timestamps=part_df.time.to_numpy(),
conversion=METERS_PER_CM,
data=part_df.loc[:, idx[("x", "y")]].to_numpy(),
reference_frame=spatial_series.reference_frame,
comments=spatial_series.comments,
reference_frame=getattr(spatial_series, "reference_frame", ""),
comments=getattr(spatial_series, "comments", "no comments"),
description="x_position, y_position",
)
likelihood.create_timeseries(
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ def populate(self, *restrictions, **kwargs):
for key in keys:
self.make(key)
if upstream_hash != self._hash_upstream(keys):
(self & keys).delete(force=True)
(self & keys).delete(safemode=False)
logger.error(
"Upstream tables changed during non-transaction "
+ "populate. Please try again."
Expand Down

0 comments on commit a7b13fa

Please sign in to comment.