Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge revisions #25

Merged
merged 27 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a57bf18
RNN example, CCA
agosztolai Apr 9, 2024
4c50289
readme
agosztolai Apr 10, 2024
bbddd84
version control for OSX arm
agosztolai Apr 10, 2024
21ca95d
macaque data
agosztolai Apr 10, 2024
74be637
test
agosztolai Apr 10, 2024
dde02cf
Merge branch 'revision' of github.com:agosztolai/MARBLE into revision
agosztolai Apr 10, 2024
e391b9d
Update run_marble.py
agosztolai Apr 10, 2024
da9c710
Merge completed
agosztolai Apr 10, 2024
99017b9
Merge branch 'revision' of github.com:agosztolai/MARBLE into revision
agosztolai Apr 10, 2024
134d729
RNN example updated with fixed CCA analysis
agosztolai Apr 16, 2024
d7d793c
fixed RNN cca example
peach-lucien Apr 17, 2024
1d54434
RNN update
agosztolai Apr 17, 2024
f6cf48e
RNN example
agosztolai Apr 18, 2024
28ab123
macaque example
agosztolai Apr 26, 2024
9dedc45
macaque example
agosztolai Apr 26, 2024
b46d85e
added robustness analyses for rat data with varying kernel widths and…
peach-lucien May 1, 2024
da9f4e5
Merge branch 'revision' of https://github.com/agosztolai/MARBLE into …
peach-lucien May 1, 2024
c4da153
updating rat examples
peach-lucien May 2, 2024
2c3012f
examples
agosztolai May 16, 2024
cbdba9c
added robustness analysis for decoding rat hippocampus
peach-lucien May 18, 2024
23efcbb
Merge branch 'revision' of https://github.com/agosztolai/MARBLE into …
peach-lucien May 18, 2024
20038f4
rat_example
agosztolai May 21, 2024
adec336
rat example
agosztolai May 22, 2024
d4e12ba
clean notbeook
arnaudon May 22, 2024
5ae3570
Delete examples/macaque_reaching/run_marble.sh
agosztolai May 22, 2024
f848403
partial fix
arnaudon May 22, 2024
0c9cb7b
fix test
arnaudon May 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MARBLE/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""MARBLE main functions."""

from MARBLE.main import net
from MARBLE.postprocessing import distribution_distances
from MARBLE.postprocessing import embed_in_2D
Expand Down
1 change: 1 addition & 0 deletions MARBLE/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Data loader module."""

import torch
from torch_cluster import random_walk
from torch_geometric.loader import NeighborSampler as NS
Expand Down
2 changes: 1 addition & 1 deletion MARBLE/default_params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ order: 2 # order to which to compute the directional derivatives
inner_product_features: False
diffusion: True
frac_sampled_nb: -1 # fraction of neighbours to sample for gradient computation (if -1 then all neighbours)
include_positions: False # include positions as features (warning: this is untested!)
include_positions: False # include positions as features
include_self: True # include vector at the center of feature

# network parameters
Expand Down
1 change: 1 addition & 0 deletions MARBLE/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

TODO: clean this up
"""

import sys

import numpy as np
Expand Down
25 changes: 15 additions & 10 deletions MARBLE/geometry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Geometry module."""

import numpy as np
import ot
import scipy.sparse as sp
Expand Down Expand Up @@ -93,7 +94,7 @@ def cluster(x, cluster_typ="kmeans", n_clusters=15, seed=0):
return clusters


def embed(x, embed_typ="umap", dim_emb=2, manifold=None, seed=0, **kwargs):
def embed(x, embed_typ="umap", dim_emb=2, manifold=None, verbose=True, seed=0, **kwargs):
"""Embed data to 2D space.

Args:
Expand Down Expand Up @@ -149,7 +150,8 @@ def embed(x, embed_typ="umap", dim_emb=2, manifold=None, seed=0, **kwargs):
else:
raise NotImplementedError

print(f"Performed {embed_typ} embedding on embedded results.")
if verbose:
print(f"Performed {embed_typ} embedding on embedded results.")

return emb, manifold

Expand Down Expand Up @@ -195,7 +197,7 @@ def compute_distribution_distances(clusters=None, data=None, slices=None):
centroid_distances: distances between cluster centroids
"""
s = slices

pdists, cdists = None, None
if clusters is not None:
# compute discrete measures supported on cluster centroids
labels = clusters["labels"]
Expand Down Expand Up @@ -230,7 +232,7 @@ def compute_distribution_distances(clusters=None, data=None, slices=None):
for j in range(i + 1, nl):
mu, nu = bins_dataset[i], bins_dataset[j]

if data is not None:
if data is not None and pdists is not None:
cdists = pdists[s[i] : s[i + 1], s[j] : s[j + 1]]

dist[i, j] = ot.emd2(mu, nu, cdists)
Expand Down Expand Up @@ -354,11 +356,11 @@ def manifold_dimension(Sigma, frac_explained=0.9):
return int(dim_man)


def fit_graph(x, graph_type="cknn", par=1, delta=1.0):
def fit_graph(x, graph_type="cknn", par=1, delta=1.0, metric="euclidean"):
"""Fit graph to node positions"""

if graph_type == "cknn":
edge_index = cknneighbors_graph(x, n_neighbors=par, delta=delta).tocoo()
edge_index = cknneighbors_graph(x, n_neighbors=par, delta=delta, metric=metric).tocoo()
edge_index = np.vstack([edge_index.row, edge_index.col])
edge_index = utils.np2torch(edge_index, dtype="double")

Expand Down Expand Up @@ -511,7 +513,7 @@ def compute_connections(data, gauges, processes=1):

Args:
data: Pytorch geometric data object
gauges (n,d,d matrix): Orthogonal unit vectors for each node
gauges (nxdxd matrix): Orthogonal unit vectors for each node
processes: number of CPUs to use

Returns:
Expand Down Expand Up @@ -568,7 +570,7 @@ def compute_eigendecomposition(A, k=None, eps=1e-8):
return None

if k is None:
A = A.to_dense()
A = A.to_dense().double()
else:
indices, values, size = A.indices(), A.values(), A.size()
A = sp.coo_array((values, (indices[0], indices[1])), shape=size)
Expand All @@ -592,6 +594,9 @@ def compute_eigendecomposition(A, k=None, eps=1e-8):
raise ValueError("failed to compute eigendecomp") from e
failcount += 1
print("--- decomp failed; adding eps ===> count: " + str(failcount))
A += sp.eye(A.shape[0]) * (eps * 10 ** (failcount - 1))
if k is None:
A += torch.eye(A.shape[0]) * (eps * 10 ** (failcount - 1))
else:
A += sp.eye(A.shape[0]) * (eps * 10 ** (failcount - 1))

return evals, evecs
return evals.float(), evecs.float()
3 changes: 2 additions & 1 deletion MARBLE/layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Layer module."""

import torch
from torch import nn
from torch_geometric.nn.conv import MessagePassing
Expand Down Expand Up @@ -86,7 +87,7 @@ def __init__(self, C, D):

self.O_mat = nn.ModuleList()
for _ in range(C):
self.O_mat.append(nn.Linear(D, D, bias=False))
self.O_mat.append(nn.Linear(D, D, bias=True))

self.reset_parameters()

Expand Down
1 change: 1 addition & 0 deletions MARBLE/lib/cknn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module imported and adapted from https://github.com/chlorochrule/cknn."""

import numpy as np
from scipy.sparse import csr_matrix
from scipy.spatial.distance import pdist
Expand Down
10 changes: 6 additions & 4 deletions MARBLE/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Main network"""

import glob
import os
import warnings
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(self, data, loadpath=None, params=None, verbose=True):
self.setup_layers()
self.loss = loss_fun()
self.reset_parameters()
self.timestamp = None

if verbose:
utils.print_settings(self)
Expand Down Expand Up @@ -170,6 +172,9 @@ def setup_layers(self):

# cumulated number of channels after gradient features
cum_channels = s * (1 - d ** (o + 1)) // (1 - d)
if not self.params["include_self"]:
cum_channels -= s

if self.params["inner_product_features"]:
cum_channels //= s
if s == 1:
Expand All @@ -182,9 +187,6 @@ def setup_layers(self):
if self.params["include_positions"]:
cum_channels += d

if not self.params["include_self"]:
cum_channels -= s

# encoder
if not isinstance(self.params["hidden_channels"], list):
self.params["hidden_channels"] = [self.params["hidden_channels"]]
Expand Down Expand Up @@ -350,7 +352,7 @@ def fit(self, data, outdir=None, verbose=False):

self.timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

print("\n---- Timestamp: {}".format(self.timestamp))
print(f"\n---- Timestamp: {self.timestamp}")

# load to gpu (if possible)
# pylint: disable=self-cls-assignment
Expand Down
11 changes: 6 additions & 5 deletions MARBLE/plotting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Plotting module."""

import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
Expand Down Expand Up @@ -182,6 +183,8 @@ def embedding(
emb = data.emb_2D
elif isinstance(data, np.ndarray) or torch.is_tensor(data):
emb = data
else:
raise TypeError

dim = emb.shape[1]
assert dim in [2, 3], f"Embedding dimension is {dim} which cannot be displayed."
Expand Down Expand Up @@ -531,7 +534,7 @@ def trajectories(
Args:
X (np array): Positions
V (np array): Velocities
ax (matplotlib axes object): If specificed, it will plot on existing axes. The default is None
ax (matplotlib axes object): If specificed, it will plot on existing axes. Default is None
style (string): Plotting style. 'o' for scatter plot or '-' for line plot
node_feature: Color lines. The default is None
lw (int): Line width
Expand Down Expand Up @@ -605,9 +608,7 @@ def trajectories(
X, V = X[skip], V[skip]
plot_arrows(X, V, ax, c, width=lw, scale=scale)
else:
raise Exception(
"Data dimension is: {}. It needs to be 2 or 3 to allow plotting.".format(dim)
)
raise Exception(f"Data dimension is: {dim}. It needs to be 2 or 3 to allow plotting.")

set_axes(ax, axes_visible=axes_visible)

Expand Down Expand Up @@ -690,7 +691,7 @@ def create_axis(*args, fig=None):
elif dim == 3:
ax = fig.add_subplot(*args, projection="3d")
else:
raise Exception("Data dimension is {}. We can only plot 2D or 3D data.".format(dim))
raise Exception(f"Data dimension is {dim}. We can only plot 2D or 3D data.")

return fig, ax

Expand Down
1 change: 1 addition & 0 deletions MARBLE/postprocessing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Postprocessing module."""

import numpy as np

from MARBLE import geometry as g
Expand Down
17 changes: 9 additions & 8 deletions MARBLE/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Preprocessing module."""

import torch
from torch_geometric.data import Batch
from torch_geometric.data import Data
Expand All @@ -16,13 +17,13 @@ def construct_dataset(
graph_type="cknn",
k=20,
delta=1.0,
n_eigenvalues=None,
frac_geodesic_nb=1.5,
spacing=0.0,
number_of_resamples=1,
var_explained=0.9,
local_gauges=False,
seed=None,
metric="euclidean",
):
"""Construct PyG dataset from node positions and features.

Expand All @@ -34,16 +35,14 @@ def construct_dataset(
graph_type: type of nearest-neighbours graph: cknn (default), knn or radius
k: number of nearest-neighbours to construct the graph
delta: argument for cknn graph construction to decide the radius for each points.
n_eigenvalues: number of eigenvalue/eigenvector pairs to compute (None means all,
but this can be slow)
frac_geodesic_nb: number of geodesic neighbours to fit the gauges to
to map to tangent space k*frac_geodesic_nb
stop_crit: stopping criterion for furthest point sampling
number_of_resamples: number of furthest point sampling runs to prevent bias (experimental)
var_explained: fraction of variance explained by the local gauges
local_gauges: is True, it will try to compute local gauges if it can (signal dim is > 2,
embedding dimension is > 2 or dim embedding is not dim of manifold)
seed: Specify for reproducibility in the furthest point sampling.
seed: Specify for reproducibility in the furthest point sampling.
The default is None, which means a random starting vertex.
"""

Expand Down Expand Up @@ -73,14 +72,16 @@ def construct_dataset(
start_idx = torch.randint(low=0, high=len(a), size=(1,))
else:
start_idx = 0

sample_ind, _ = g.furthest_point_sampling(a, spacing=spacing, start_idx=start_idx)
sample_ind, _ = torch.sort(sample_ind) # this will make postprocessing easier
a_, v_, l_, m_ = a[sample_ind], v[sample_ind], l[sample_ind], m[sample_ind]

# fit graph to point cloud
edge_index, edge_weight = g.fit_graph(a_, graph_type=graph_type, par=k, delta=delta)

edge_index, edge_weight = g.fit_graph(
a_, graph_type=graph_type, par=k, delta=delta, metric=metric
)

# define data object
data_ = Data(
pos=a_,
Expand Down
3 changes: 1 addition & 2 deletions MARBLE/smoothing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Smoothing module."""

import torch

Expand Down
1 change: 1 addition & 0 deletions MARBLE/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utils module."""

import multiprocessing
from functools import partial
from typing import NamedTuple
Expand Down
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# MARBLE - MAnifold Representation Basis LEarning

MARBLE is a fully unsupervised geometric deep-learning method that can
MARBLE is an unsupervised geometric deep-learning method that can

1. Find interpretable representations of neural dynamics, more generally dynamical systems, or even more generally vector fields over manifolds.
1. Find interpretable latent representations of neural dynamics. It also applies to non-linear dynamical systems in other domains or, more generally, vector fields over manifolds.
2. Perform unbiased comparisons across conditions within the same animal (or dynamical system).
3. Compare dynamics across animals or artificial neural networks.

Expand Down Expand Up @@ -70,14 +70,18 @@ conda activate MARBLE
pip install .
```

### Note on examples

Running the scripts in the `/examples` folder to reproduce our results will rely on several dependencies that are not core to the MARBLE code. On Macs, run `brew install wget`, which you will need to download the datasets automatically. Further dependencies will install automatically when running the example notebooks.

## Quick start

We suggest you study at least the example of a [simple vector fields over flat surfaces](https://github.com/agosztolai/MARBLE/blob/main/examples/toy_examples/ex_vector_field_flat_surface.py) to understand what behaviour to expect.

Briefly, MARBLE takes two inputs

1. `pos` - a list of `nxd` arrays, each defining a cloud of anchor points describing the geometry of a manifold
2. `x` - a list of `nxD` arrays, defining a vector signal over the respective manifolds in 1. For dynamical systems, D=d, but our code can also handle signals of other dimensions. Read more about [inputs](#inputs) and [different conditions](#conditions).
1. `anchor` - a list of `nxd` arrays, each defining a cloud of anchor points describing the manifold
2. `vector` - a list of `nxD` arrays, defining a vector signal at the anchor points over the respective manifolds. For dynamical systems, D=d, but our code can also handle signals of other dimensions. Read more about [inputs](#inputs) and [different conditions](#conditions).

Using these inputs, you can construct a dataset for MARBLE.

Expand All @@ -86,7 +90,7 @@ import MARBLE
data = MARBLE.construct_dataset(anchor=pos, vector=x)
```

The main attributes are `data.pos` - manifold positions concatenated, `data.x` - manifold signals concatenated and `data.y` - identifiers that tell you which manifold the point belongs to. Read more about [other usedul data attributed](#construct).
The main attributes are `data.pos` - manifold positions concatenated, `data.x` - manifold signals concatenated and `data.y` - identifiers that tell you which manifold the point belongs to. Read more about [other useful data attributes](#construct).

Now, you can initialise and train a MARBLE model. Read more about [training parameters](#training).

Expand Down
Loading
Loading