Skip to content

Commit

Permalink
Merge branch 'main' into celia/fix-790-public
Browse files Browse the repository at this point in the history
  • Loading branch information
MMathisLab authored Sep 25, 2024
2 parents e11b4dd + 9e14790 commit 6410a1f
Show file tree
Hide file tree
Showing 12 changed files with 34 additions and 25 deletions.
14 changes: 8 additions & 6 deletions cebra/data/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
#

import hashlib
import os
import re
import warnings
from pathlib import Path
from typing import Optional

import requests
Expand Down Expand Up @@ -57,8 +57,10 @@ def download_file_with_progress_bar(url: str,
"""

# Check if the file already exists in the location
file_path = os.path.join(location, file_name)
if os.path.exists(file_path):
location_path = Path(location)
file_path = location_path / file_name

if file_path.exists():
existing_checksum = calculate_checksum(file_path)
if existing_checksum == expected_checksum:
return file_path
Expand Down Expand Up @@ -91,10 +93,10 @@ def download_file_with_progress_bar(url: str,
)

# Create the directory and any necessary parent directories
os.makedirs(location, exist_ok=True)
location_path.mkdir(exist_ok=True)

filename = filename_match.group(1)
file_path = os.path.join(location, filename)
file_path = location_path / filename

total_size = int(response.headers.get("Content-Length", 0))
checksum = hashlib.md5() # create checksum
Expand All @@ -111,7 +113,7 @@ def download_file_with_progress_bar(url: str,
downloaded_checksum = checksum.hexdigest() # Get the checksum value
if downloaded_checksum != expected_checksum:
warnings.warn(f"Checksum verification failed. Deleting '{file_path}'.")
os.remove(file_path)
file_path.unlink()
warnings.warn("File deleted. Retrying download...")

# Retry download using a for loop
Expand Down
1 change: 0 additions & 1 deletion cebra/datasets/allen/ca_movie_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

import glob
import hashlib
import os
import pathlib

import h5py
Expand Down
1 change: 0 additions & 1 deletion cebra/datasets/allen/neuropixel_movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
"""
import glob
import hashlib
import os
import pathlib

import h5py
Expand Down
1 change: 0 additions & 1 deletion cebra/datasets/allen/neuropixel_movie_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
"""
import glob
import hashlib
import os
import pathlib

import h5py
Expand Down
4 changes: 3 additions & 1 deletion cebra/datasets/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib
from typing import Tuple

import joblib as jl
Expand Down Expand Up @@ -51,7 +52,8 @@ def __init__(self, noise: str = "poisson"):
super().__init__()
self.noise = noise
data = jl.load(
get_datapath(f"synthetic/continuous_label_{self.noise}.jl"))
pathlib.Path(_DEFAULT_DATADIR) / "synthetic" /
f"continuous_label_{self.noise}.jl")
self.latent = data["z"]
self.index = torch.from_numpy(data["u"]).float()
self.neural = torch.from_numpy(data["x"]).float()
Expand Down
4 changes: 2 additions & 2 deletions cebra/datasets/generate_synthetic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
Adapted from pi-VAE: https://github.com/zhd96/pi-vae/blob/main/code/pi_vae.py
"""
import argparse
import os
import pathlib
import sys

import joblib as jl
Expand Down Expand Up @@ -245,5 +245,5 @@ def refractory_poisson(x):
"lam": lam_true,
"x": x
},
os.path.join(args.save_path, f"continuous_label_{args.noise}.jl"),
pathlib.Path(args.save_path) / f"continuous_label_{args.noise}.jl",
)
6 changes: 3 additions & 3 deletions cebra/datasets/hippocampus.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"""

import hashlib
import os
import pathlib

import joblib
import numpy as np
Expand Down Expand Up @@ -94,8 +94,8 @@ class SingleRatDataset(cebra.data.SingleSessionDataset):
"""

def __init__(self, name="achilles", root=_DEFAULT_DATADIR, download=True):
location = os.path.join(root, "rat_hippocampus")
file_path = os.path.join(location, f"{name}.jl")
location = pathlib.Path(root) / "rat_hippocampus"
file_path = location / f"{name}.jl"

super().__init__(download=download,
data_url=rat_dataset_urls[name]["url"],
Expand Down
23 changes: 16 additions & 7 deletions cebra/datasets/monkey_reaching.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
"""

import hashlib
import os
import pathlib
import pickle as pk
from typing import Union

import joblib as jl
import numpy as np
Expand All @@ -41,10 +42,11 @@
from cebra.datasets import get_datapath
from cebra.datasets import register

_DEFAULT_DATADIR = get_datapath()


def _load_data(
path: str = get_datapath(
"s1_reaching/sub-Han_desc-train_behavior+ecephys.nwb"),
path: Union[str, pathlib.Path] = None,
session: str = "active",
split: str = "train",
):
Expand All @@ -61,6 +63,13 @@ def _load_data(
"""

if path is None:
path = pathlib.Path(
_DEFAULT_DATADIR
) / "s1_reaching" / "sub-Han_desc-train_behavior+ecephys.nwb"
else:
path = pathlib.Path(path)

try:
from nlb_tools.nwb_interface import NWBDataset
except ImportError as e:
Expand Down Expand Up @@ -259,7 +268,7 @@ def __init__(self,
)

self.data = jl.load(
os.path.join(self.path, f"{self.load_session}_all.jl"))
pathlib.Path(self.path) / f"{self.load_session}_all.jl")
self._post_load()

def split(self, split):
Expand All @@ -285,7 +294,7 @@ def split(self, split):
file_name=f"{self.load_session}_{split}.jl",
)
self.data = jl.load(
os.path.join(self.path, f"{self.load_session}_{split}.jl"))
pathlib.Path(self.path) / f"{self.load_session}_{split}.jl")
self._post_load()

def _post_load(self):
Expand Down Expand Up @@ -407,7 +416,7 @@ def _create_area2_dataset():
"""

PATH = get_datapath("monkey_reaching_preload_smth_40")
PATH = pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40"
for session_type in ["active", "passive", "active-passive", "all"]:

@register(f"area2-bump-pos-{session_type}")
Expand Down Expand Up @@ -506,7 +515,7 @@ def _create_area2_shuffled_dataset():
"""

PATH = get_datapath("monkey_reaching_preload_smth_40/")
PATH = pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40"
for session_type in ["active", "active-passive"]:

@register(f"area2-bump-pos-{session_type}-shuffled-trial")
Expand Down
1 change: 0 additions & 1 deletion cebra/solver/multi_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
"""Solver implementations for multi-session datasetes."""

import abc
import os
from collections.abc import Iterable
from typing import List, Optional

Expand Down
1 change: 0 additions & 1 deletion cebra/solver/single_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import abc
import copy
import os
from collections.abc import Iterable
from typing import List

Expand Down
1 change: 0 additions & 1 deletion cebra/solver/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
as experimental/outdated, and the API for this particular package unstable.
"""
import abc
import os
from collections.abc import Iterable
from typing import List

Expand Down
2 changes: 2 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def test_monkey():
def test_allen():
from cebra.datasets import allen

pytest.skip("Test takes too long")

ca_dataset = cebra.datasets.init("allen-movie-one-ca-VISp-100-train-10-111")
ca_loader = cebra.data.ContinuousDataLoader(
dataset=ca_dataset,
Expand Down

0 comments on commit 6410a1f

Please sign in to comment.