[WIP] Added SBI tutorial for parameter inference with HNN-core #856

Open · wants to merge 5 commits into master
197 changes: 197 additions & 0 deletions examples/howto/sbi_hnncore_tutorial.py
@@ -0,0 +1,197 @@
"""
==========================================
09. SBI with HNN-core: Parameter Inference
==========================================

This tutorial demonstrates how to use Simulation-Based Inference (SBI)
with HNN-core to infer network parameters from simulated data. We'll
simulate neural network activity, then use SBI to infer the parameters
that generated this activity.
"""

# Authors: Abdul Samad Siddiqui <[email protected]>
# Nick Tolley <[email protected]>
# Ryan Thorpe <[email protected]>
# Mainak Jas <[email protected]>

###############################################################################
# First, we need to import the necessary packages. The Simulation-Based
# Inference (SBI) package can be installed with ``pip install sbi``.
# Let us import ``hnn_core`` and all the necessary libraries.

import numpy as np
import torch
from hnn_core.batch_simulate import BatchSimulate
from sbi import utils as sbi_utils
from sbi import inference as sbi_inference
import matplotlib.pyplot as plt
from hnn_core.network_models import jones_2009_model

###############################################################################
# Now, we'll set up our simulation parameters. We're using a small number of
# simulations for this example, but in practice, you might want to use more.

n_simulations = 100
n_jobs = 20

###############################################################################
# This function sets the parameters for our neural network model. It adds an
# evoked drive to the network with specific weights and delays.
#
# The `add_evoked_drive` function simulates external input to the network,
# mimicking sensory stimulation or other external events.
#
# - `evprox` indicates a proximal drive, targeting dendrites near the cell
# bodies.
# - `mu=40` and `sigma=5` define the timing (mean and spread) of the input.
# - `weights_ampa` and `synaptic_delays` control the strength and
# timing of the input.
#
# This evoked drive causes the initial positive deflection in the dipole
# signal, triggering a cascade of activity through the network and
# resulting in the complex waveforms observed.


def set_params(param_values, net=None):
    weight_pyr = 10**float(param_values['weight_pyr'])
    weights_ampa = {'L5_pyramidal': weight_pyr}
    synaptic_delays = {'L5_pyramidal': 1.}

    net.add_evoked_drive('evprox',
                         mu=40,
                         sigma=5,
                         numspikes=1,
                         location='proximal',
                         weights_ampa=weights_ampa,
                         synaptic_delays=synaptic_delays)
    return net

###############################################################################
# Here, we generate our parameter grid and run the simulations. We vary the
# `weight_pyr` parameter between 1e-4 and 1e-1 by spacing its log10 value
# uniformly between -4 and -1; `set_params` converts each exponent back to a
# weight via `10**value`.


val = np.linspace(-4, -1, n_simulations)
param_grid = {
    'weight_pyr': val.tolist()
}

net = jones_2009_model(mesh_shape=(3, 3))
batch_simulator = BatchSimulate(set_params=set_params,
                                net=net,
                                tstop=170)
simulation_results = batch_simulator.run(
    param_grid, n_jobs=n_jobs, combinations=False)

print(
    f"Number of simulations run: {len(simulation_results['simulated_data'])}")

###############################################################################
# This function extracts the dipole data from our simulation results. Dipole
# data represents the aggregate electrical activity of the neural population.


def extract_dipole_data(sim_results):
    dipole_data_list = []
    for result in sim_results['simulated_data']:
        for sim_data in result:
            dipole_data_list.append(sim_data['dpl'][0].data['agg'])
    return dipole_data_list


dipole_data = extract_dipole_data(simulation_results)

###############################################################################
# Now we prepare our data for the SBI algorithm. `thetas` are our parameters,
# and `xs` are our observed data (the dipole activity). These will be used by
# the SBI algorithm to learn the relationship between parameters and
# the resulting neural activity.

thetas = torch.tensor(param_grid['weight_pyr'],
                      dtype=torch.float32).reshape(-1, 1)
xs = torch.stack([torch.tensor(data, dtype=torch.float32)
                  for data in dipole_data])
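
###############################################################################
# As a quick sanity check, we can print the shapes: SBI pairs one row of
# ``thetas`` with one row of ``xs`` per simulation, so both should have
# ``n_simulations`` rows. This print is purely illustrative.

print(f"thetas shape: {tuple(thetas.shape)}, xs shape: {tuple(xs.shape)}")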

###############################################################################
# Here we set up our SBI inference. We define a prior distribution for our
# parameter and create our inference object.

prior = sbi_utils.BoxUniform(
    low=torch.tensor([-4.0]), high=torch.tensor([-1.0]))
inference = sbi_inference.SNPE(prior=prior)
density_estimator = inference.append_simulations(thetas, xs).train()
posterior = inference.build_posterior(density_estimator)

###############################################################################
# The prior distribution represents our initial guess about the range of
# possible values for `weight_pyr` (expressed as a log10 exponent, matching
# the values stored in ``thetas``). The SBI algorithm will use this prior,
# along with the simulated data, to build a posterior distribution, which
# represents our updated belief about `weight_pyr` after seeing the data.
#
# This function allows us to simulate data for a single parameter value.


def simulator_batch(param):
    param_grid_single = {'weight_pyr': [float(param)]}
    results = batch_simulator.run(
        param_grid_single, n_jobs=1, combinations=False)
    return torch.tensor(extract_dipole_data(results)[0], dtype=torch.float32)

###############################################################################
# Now we'll infer parameters for "unknown" data. We generate this data using
# a parameter value that we pretend we don't know.


np.random.seed(42)
unknown_param = torch.tensor([[np.random.choice(np.linspace(-4, -1, 100))]])
x_o = simulator_batch(unknown_param.item())
samples = posterior.sample((1000,), x=x_o)

print(f"True (unknown) parameter: {unknown_param.item()}")
print(f"Inferred parameter (mean): {samples.mean().item()}")
print(f"Inferred parameter (median): {samples.median().item()}")

###############################################################################
# Let's visualize the posterior distribution of our inferred parameters.

plt.figure(figsize=(10, 6))
plt.hist(samples.numpy(), bins=30, density=True, alpha=0.7)
plt.xlabel('Parameter Value')
plt.ylabel('Density')
plt.title('Posterior Distribution of Parameters')
plt.axvline(unknown_param.item(), color='r', linestyle='dashed',
            linewidth=2, label='True Parameter')
plt.legend()
plt.xlim([-4, -1])
plt.show()

###############################################################################
# This plot shows the posterior distribution of the inferred parameter values.
# If the inferred posterior distribution is centered around the true parameter
# value, it suggests that the SBI method is accurately capturing the underlying
# parameter. The red dashed line represents the true parameter value.
#
# Finally, we'll evaluate the performance of our SBI method on multiple
# unseen parameter values.

unseen_params = np.linspace(-4, -1, 10)
unseen_data = [simulator_batch(param) for param in unseen_params]
unseen_samples = [posterior.sample((100,), x=x) for x in unseen_data]

plt.figure(figsize=(12, 6))
plt.boxplot([s.numpy().flatten() for s in unseen_samples],
            positions=unseen_params, widths=0.05, vert=True)
plt.xlabel('True Parameter')
plt.ylabel('Inferred Parameter')
plt.title('SBI Performance on Unseen Data')
plt.xticks(ticks=unseen_params, labels=[f'{param:.2f}'
                                        for param in unseen_params])
plt.xlim(-4.1, -0.9)
plt.show()

###############################################################################
# This boxplot visualizes the distribution of inferred parameters for each
# unseen true parameter value. The true parameters are shown on the x-axis,
# and the boxes represent the spread of inferred values. If the inferred
# parameters closely follow the diagonal (where inferred = true), it indicates
# that the SBI method is performing well.
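
###############################################################################
# As a rough numerical complement to the boxplot, we can compare the posterior
# mean for each unseen parameter against its true value. This is a minimal
# sketch built on the variables defined above; the mean absolute error is a
# convenient summary, not a formal validation metric.

posterior_means = np.array([s.numpy().mean() for s in unseen_samples])
mae = np.mean(np.abs(posterior_means - unseen_params))
print(f"Posterior means: {np.round(posterior_means, 2)}")
print(f"Mean absolute error across unseen parameters: {mae:.3f}")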