Skip to content

Commit

Permalink
Merge branch 'main' of github.com:agosztolai/MARBLE into main
Browse files Browse the repository at this point in the history
  • Loading branch information
agosztolai committed Feb 22, 2024
2 parents fec5460 + 6615440 commit 80b7d58
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
4 changes: 2 additions & 2 deletions MARBLE/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def furthest_point_sampling(x, N=None, spacing=0.0, start_idx=0):
return perm, lambdas


def cluster(x, cluster_typ="meanshift", n_clusters=15, seed=0):
def cluster(x, cluster_typ="kmeans", n_clusters=15, seed=0):
"""Cluster data.
Args:
Expand All @@ -78,7 +78,7 @@ def cluster(x, cluster_typ="meanshift", n_clusters=15, seed=0):
"""
clusters = {}
if cluster_typ == "kmeans":
kmeans = KMeans(n_clusters=n_clusters, random_state=seed, n_init="auto").fit(x)
kmeans = KMeans(n_clusters=n_clusters, random_state=seed).fit(x)
clusters["n_clusters"] = n_clusters
clusters["labels"] = kmeans.labels_
clusters["centroids"] = kmeans.cluster_centers_
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ See full documentation [here](https://agosztolai.github.io/MARBLE/).

## Installation

The code is tested for CPU and GPU (CUDA) machines running Linux or OSX. Although smaller examples run fast on CPU, for larger datasets, it is highly recommended that you use a GPU machine.
The code is tested for CPU and GPU (CUDA) machines running Linux, Mac OSX or Windows. Although smaller examples run fast on CPU, for larger datasets, it is highly recommended that you use a GPU machine.

We recommend you install the code in a fresh Anaconda virtual environment, as follows.

Expand Down Expand Up @@ -76,14 +76,14 @@ We suggest you study at least the example of a [simple vector fields over flat s

Briefly, MARBLE takes two inputs

1. `pos` - a list of `nxd` arrays, each defining a point cloud describing the geometry of a manifold
2. `x` - a list of `nxD` arrays, defining a signal over the respective manifolds in 1. For dynamical systems, D=d, but our code can also handle signals of other dimensions. Read more about [inputs](#inputs) and [different conditions](#conditions).
1. `pos` - a list of `nxd` arrays, each defining a cloud of anchor points describing the geometry of a manifold
2. `x` - a list of `nxD` arrays, defining a vector signal over the respective manifolds in 1. For dynamical systems, D=d, but our code can also handle signals of other dimensions. Read more about [inputs](#inputs) and [different conditions](#conditions).

Using these inputs, you can construct a dataset for MARBLE.

```
import MARBLE
data = MARBLE.construct_dataset(pos, features=x)
data = MARBLE.construct_dataset(anchor=pos, vector=x)
```

The main attributes are `data.pos` - manifold positions concatenated, `data.x` - manifold signals concatenated and `data.y` - identifiers that tell you which manifold the point belongs to. Read more about [other usedul data attributed](#construct).
Expand All @@ -96,7 +96,7 @@ model = MARBLE.net(data)
model.fit(data)
```

By default, MARBLE operated in geometry-aware mode. You can enable the geometry-agnostic mode by changing the initialisation step to
By default, MARBLE operates in geometry-aware mode. You can enable the geometry-agnostic mode by changing the initialisation step to

```
model = MARBLE.net(data, params = {'inner_product_features': True})
Expand Down
5 changes: 3 additions & 2 deletions examples/macaque_reaching/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ The scripts and notebooks are as follows.

- convert_data.py: convert original spiking and kinematics data to data compatible with MARBLE (both input and generated data are on dataverse)
- run_marble.py: train MARBLE networks (this takes a while to run - 1-2h/session)
- analysis.ipynb: create plots of some panels of Figure 3

3. Decoding into kinematics
3. If you want to simply reproduce the results, run - analysis.ipynb: create plots of some panels of Figure 3

4. Decoding into kinematics
- decoding.ipynb
6 changes: 3 additions & 3 deletions examples/toy_examples/vanderpol.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@
}
],
"source": [
"data = MARBLE.construct_dataset(pos, features=vel, stop_crit=0.03)"
"data = MARBLE.construct_dataset(anchor=pos, vector=vel, spacing=0.03)"
]
},
{
Expand Down Expand Up @@ -513,7 +513,7 @@
"mus = np.linspace(-0.1, 0.1, n_mus)\n",
"pos, vel = get_pos_vel(mus)\n",
"\n",
"data = MARBLE.construct_dataset(pos, features=vel, k=20, stop_crit=0.03)\n",
"data = MARBLE.construct_dataset(anchor=pos, vector=vel, k=20, spacing=0.03)\n",
"model = MARBLE.net(data, params=params)\n",
"model.fit(data, outdir='model_zoom')"
]
Expand Down Expand Up @@ -759,7 +759,7 @@
"mus = np.linspace(-0.1, 0.1, n_mus)\n",
"pos, vel = get_pos_vel_noise(mus, alpha=0.2)\n",
"\n",
"data = MARBLE.construct_dataset(pos, features=vel, k=20, stop_crit=0.03)\n",
"data = MARBLE.construct_dataset(anchor=pos, vector=vel, k=20, spacing=0.03)\n",
"\n",
"params['inner_product_features'] = True\n",
"model = MARBLE.net(data, params=params)\n",
Expand Down

0 comments on commit 80b7d58

Please sign in to comment.