Skip to content

Commit

Permalink
Merge pull request #96 from eruijsena/state_transitions
Browse files Browse the repository at this point in the history
add plots to show the change in maxContrib state
  • Loading branch information
candidechamp authored Oct 3, 2023
2 parents 49a5fed + 1060309 commit c401350
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 2 deletions.
1 change: 1 addition & 0 deletions devtools/conda-envs/full_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- pandas
- pytables
- matplotlib
- plotly
- mpmath

#Docs
Expand Down
1 change: 1 addition & 0 deletions devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- pandas
- pytables
- matplotlib
- plotly


# Pip-only installs
Expand Down
60 changes: 60 additions & 0 deletions reeds/function_libs/analysis/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd

import reeds.function_libs.visualization.sampling_plots
from pygromos.files.repdat import ExpandedRepdat


def undersampling_occurence_potential_threshold_densityClustering(ene_trajs: List[pd.DataFrame],
Expand Down Expand Up @@ -349,7 +350,66 @@ def sampling_analysis(ene_trajs: List[pd.DataFrame],

return final_results, out_path

def analyse_state_transitions(repdat: ExpandedRepdat, min_s: int = None, normalize: bool = False, bidirectional: bool = False):
"""
Count the number of times a transition occurs between pairs of states, based on the repdat info.
Parameters
----------
repdat: ExpandedRepdat
ExpandedRepdat object (created from a Repdat) which contains all the exchange information of a
RE-EDS simulation plus the potential energies of the end-states
min_s: int, optional
Index of the lowest s_value to consider for the transitions. If None, consider all s values.
normalize: bool, optional
Normalize the transitions by the total number of outgoing transitions per state
bidirectional: bool, optional
Count the transitions symmetrically (state A to B together with state B to A)
Returns
-------
np.ndarray
number of transitions between all pairs of states
"""
if normalize and bidirectional:
raise Exception("Transitions cannot be normalized w.r.t leaving state and bidirectional")

num_replicas = len(repdat.system.s)
num_states = len(repdat.system.state_eir)

# Initialize transition counts to zero for all pairs of states
transition_counts = np.zeros((num_states, num_states))

for replica in range(1, num_replicas+1):
# Get exchange data per state
if min_s:
state_repdat = repdat.DATA.query(f"coord_ID == {replica} & ID <= {min_s}")
else:
state_repdat = repdat.DATA.query(f"coord_ID == {replica}")
state_trajectory = state_repdat[["Vmin", "run"]].reset_index(drop=True).copy()

# Count the transitions between different states
for i in range(len(state_trajectory) - 1):
current_state = int("".join([char for char in state_trajectory["Vmin"][i] if char.isdigit()])) # Take the i in Vri
next_state = int("".join([char for char in state_trajectory["Vmin"][i + 1] if char.isdigit()]))
current_run = state_trajectory["run"][i] # Check that you are actually comparing consecutive exchanges
next_run = state_trajectory["run"][i+1]
if next_run == current_run +1 and current_state != next_state:
transition_counts[current_state-1][next_state-1] += 1

if normalize:
# Normalize by total number of transitions per state
tot_trans = np.sum(transition_counts, axis=1)
transition_counts = transition_counts / tot_trans[:, np.newaxis]

elif bidirectional:
# Consider exchanges in both directions together
bidirectional_counts = np.zeros((num_states, num_states))
for state1 in range(len(transition_counts)):
for state2 in range(len(transition_counts[state1])):
bidirectional_counts[state1][state2] += transition_counts[state1][state2]
bidirectional_counts[state2][state1] += transition_counts[state1][state2]
transition_counts = bidirectional_counts

return transition_counts

def detect_undersampling(ene_trajs: List[pd.DataFrame],
state_potential_treshold: List[float],
Expand Down
80 changes: 78 additions & 2 deletions reeds/function_libs/visualization/sampling_plots.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from typing import List
from typing import Union, List

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import Colormap, to_rgba

import plotly.graph_objects as go
from plotly.colors import convert_to_RGB_255

from reeds.function_libs.visualization import plots_style as ps
from reeds.function_libs.visualization.utils import nice_s_vals

import reeds.function_libs.visualization.plots_style as ps


def plot_sampling_convergence(ene_trajs, opt_trajs, outfile, title = None, trim_beg = 0.1):
"""
Expand Down Expand Up @@ -377,3 +381,75 @@ def plot_stateOccurence_matrix(data: dict,
if (not out_dir is None):
fig.savefig(out_dir + '/sampling_maxContrib_matrix.png', bbox_inches='tight')
plt.close()

def plot_state_transitions(state_transitions: np.ndarray, title: str = None, colors: Union[List[str], Colormap] = ps.qualitative_tab_map, out_path: str = None):
"""
Make a Sankey plot showing the flows between states.
Parameters
----------
state_transitions : np.ndarray
num_states * num_states 2D array containing the number of transitions between states
title: str, optional
printed title of the plot
colors: Union[List[str], Colormap], optional
if you don't like the default colors
out_path: str, optional
path to save the image to. if none, the image is returned as a plotly figure
Returns
-------
None or fig
plotly figure if if was not saved
"""
num_states = len(state_transitions)

if isinstance(colors, Colormap):
colors = [colors(i) for i in np.linspace(0, 1, num_states)]
elif len(colors) < num_states:
raise Exception("Insufficient colors to plot all states")

def v_distribute(total_transitions):
# Vertically distribute nodes in plot based on total number of transitions per state
box_sizes = total_transitions / total_transitions.sum()
box_vplace = [np.sum(box_sizes[:i]) + box_sizes[i]/2 for i in range(len(box_sizes))]
return box_vplace

y_placements = v_distribute(np.sum(state_transitions, axis=1)) + v_distribute(np.sum(state_transitions, axis=0))

# Convert colors to plotly format and make them transparent
rgba_colors = []
for color in colors:
rgba = to_rgba(color)
rgba_plotly = convert_to_RGB_255(rgba[:-1])
# Add opacity
rgba_plotly = rgba_plotly + (0.8,)
# Make string
rgba_colors.append("rgba" + str(rgba_plotly))

# Indices 0..n-1 are the source and n..2n-1 are the target.
fig = go.Figure(data=[go.Sankey(
node = dict(
pad = 5,
thickness = 20,
line = dict(color = "black", width = 2),
label = [f"state {i+1}" for i in range(num_states)]*2,
color = rgba_colors[:num_states]*2,
x = [0.1]*num_states + [1]*num_states,
y = y_placements
),
link = dict(
arrowlen = 30,
source = np.array([[i]*num_states for i in range(num_states)]).flatten(),
target = np.array([[i for i in range(num_states, 2*num_states)] for _ in range(num_states)]).flatten(),
value = state_transitions.flatten(),
color = np.array([[c]*num_states for c in rgba_colors[:num_states]]).flatten()
),
arrangement="fixed",
)])
fig.update_layout(title_text=title, font_size=20, title_x=0.5, height=max(600, num_states*100))

if out_path:
fig.write_image(out_path)
return None
else:
return fig

0 comments on commit c401350

Please sign in to comment.