Skip to content

Commit

Permalink
pyodide web: add demo scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
pthom committed Oct 15, 2024
1 parent be2a817 commit 02a33c8
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pyodide_web_demo/examples/examples.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
{
"label": "Fiatlight: Word Count",
"filename": "fl_word_count.py"
},
{
"label": "Fiatlight: Draw Scatter",
"filename": "scatter_fiatlight.py"
}
]
}
67 changes: 67 additions & 0 deletions pyodide_web_demo/examples/scatter_fiatlight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import matplotlib
matplotlib.use('Agg')

from matplotlib.figure import Figure
from matplotlib import pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
import numpy as np
from enum import Enum
import fiatlight as fl


from scatter_widget_bundle import ScatterData

class DecisionStrategy(Enum):
logistic_regression = LogisticRegression
decision_tree = DecisionTreeClassifier


def plot_boundary(df: pd.DataFrame, strategy: DecisionStrategy, eps: float=1.0) -> Figure | None:
if len(df) and (df['color'].nunique() > 1):
X = df[['x', 'y']].values
y = df['color']
fig, ax = plt.subplots()
if strategy == DecisionStrategy.logistic_regression:
classifier = LogisticRegression().fit(X, y)
else:
classifier = DecisionTreeClassifier().fit(X, y)
disp = DecisionBoundaryDisplay.from_estimator(
classifier, X,
response_method="predict_proba" if len(np.unique(df['color'])) == 2 else "predict",
xlabel="x", ylabel="y",
#alpha=0.5,
eps=eps,
ax=ax
)
disp.ax_.scatter(X[:, 0], X[:, 1], c=y, edgecolor="k")
ax.set_title(f"{classifier.__class__.__name__}")
return fig
else:
return None



def scatter_source(scatter_data: ScatterData) -> ScatterData:
return scatter_data

@fl.with_fiat_attributes(eps__range=(0.1, 10.0))
def scatter_to_figure(
scatter_data: ScatterData,
strategy: DecisionStrategy = DecisionStrategy.logistic_regression,
eps: float=1.0) -> Figure:
return plot_boundary(scatter_data.data_as_pandas(), strategy, eps)


def scatter_to_df(scatter_data: ScatterData) -> pd.DataFrame:
return scatter_data.data_as_pandas()

graph = fl.FunctionsGraph()
graph.add_function(scatter_source)
graph.add_function(scatter_to_figure)
graph.add_function(scatter_to_df)
graph.add_link(scatter_source, scatter_to_df)
graph.add_link(scatter_source, scatter_to_figure)
fl.run(graph)
5 changes: 4 additions & 1 deletion pyodide_web_demo/js/pyodide_loader.js
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ async function loadPyodideAndPackages() {
// 'ipython',
// 'sqlite3',

"pyodide_dist/fiatlight-0.1.0-py3-none-any.whl"
"pyodide_dist/fiatlight-0.1.0-py3-none-any.whl",

"scikit-learn",
"scipy",
"pyodide_dist/scatter_widget_bundle-0.1.0-py3-none-any.whl",
];

const totalSteps = packages.length;
Expand Down
2 changes: 2 additions & 0 deletions pyodide_web_demo/justfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ build_dist:
srv:
python3 -m http.server 8005

rsync_tq:
rsync -avz --delete . [email protected]:HTML/probabl/

0 comments on commit 02a33c8

Please sign in to comment.