WIP: burn-train in the browser #938

Closed
wants to merge 87 commits into from
Commits
5a030b9
pnpm create vite train-web --template vanilla-ts
AlexErrant Nov 2, 2023
2f09310
pnpm i
AlexErrant Nov 2, 2023
408d740
pnpm up -rL
AlexErrant Nov 2, 2023
8eab3a6
pnpm i -D prettier
AlexErrant Nov 2, 2023
74b9b57
add .prettierrc
AlexErrant Nov 2, 2023
86c6a06
pnpm exec prettier . --write
AlexErrant Nov 2, 2023
bddd331
move everything to web folder
AlexErrant Nov 2, 2023
501b360
update workspace
AlexErrant Nov 2, 2023
9ff1615
cargo new train --lib
AlexErrant Nov 2, 2023
0a279de
pnpm add ../train/pkg
AlexErrant Nov 3, 2023
0af481a
add vite config
AlexErrant Nov 3, 2023
2740d4c
can call train's run from web
AlexErrant Nov 4, 2023
cc423e6
add train.rs
AlexErrant Nov 4, 2023
40f4140
make dev friendly
AlexErrant Nov 4, 2023
a77829d
implemented spawn using webworkers
AlexErrant Nov 9, 2023
2e074e7
update notices
AlexErrant Nov 9, 2023
1cf8dd1
Arc => Rc
AlexErrant Nov 9, 2023
f688a3d
clippy
AlexErrant Nov 9, 2023
13f26c9
fix CI
AlexErrant Nov 9, 2023
7fab3a0
Merge branch 'main' into train-browser
AlexErrant Nov 12, 2023
11cb46b
pnpm i sql.js
AlexErrant Nov 13, 2023
5fa4463
add postinstall to cp sql-wasm.wasm into assets
AlexErrant Nov 13, 2023
94d156d
can load mnist data in js and send to rust
AlexErrant Nov 13, 2023
66a7ffd
extract init
AlexErrant Nov 14, 2023
7735e47
autoload/run mnist on pageload
AlexErrant Nov 14, 2023
2d5f8c0
clean up resources
AlexErrant Nov 14, 2023
3d863e7
can create a DataLoader
AlexErrant Nov 15, 2023
0b0fd86
can load train and test
AlexErrant Nov 15, 2023
6331ede
copied over model
AlexErrant Nov 15, 2023
469e270
copy over guide's `training.rs`, more or less
AlexErrant Nov 15, 2023
939f3d0
docs
AlexErrant Nov 15, 2023
437cf00
assert is valid png
AlexErrant Nov 15, 2023
e3c84a6
add pool from https://github.com/rustwasm/wasm-bindgen/blob/main/exam…
AlexErrant Nov 22, 2023
9a355f2
pool works
AlexErrant Nov 24, 2023
ea1cfe1
Merge branch 'main' into train-browser
AlexErrant Nov 29, 2023
70e1360
fix conflict
AlexErrant Nov 29, 2023
a1f5828
it builds
AlexErrant Nov 30, 2023
9f2285e
add rayon and wasm-bindgen-rayon
AlexErrant Dec 1, 2023
6d926e8
replace with rayon
AlexErrant Dec 1, 2023
d2672b4
Merge branch 'main' into train-browser
AlexErrant Dec 1, 2023
8035870
fix conflicts
AlexErrant Dec 1, 2023
7b27871
nix license notice
AlexErrant Dec 1, 2023
6a76157
Merge branch 'main' into train-browser
AlexErrant Dec 1, 2023
da36016
fix conflicts
AlexErrant Dec 1, 2023
9aa0439
fix
AlexErrant Dec 1, 2023
051b41f
fix?
AlexErrant Dec 1, 2023
34d6b3a
Merge branch 'main' into train-browser
AlexErrant Dec 1, 2023
7a8372c
fix??
AlexErrant Dec 1, 2023
2c9dc06
runchecks runs each package separately
AlexErrant Dec 2, 2023
1ee5cef
skip train-web
AlexErrant Dec 2, 2023
3f52a78
fix deps
AlexErrant Dec 3, 2023
f188b7a
converted train to a worker (as to not block main thread)
AlexErrant Nov 16, 2023
37a748e
add autotrain checkbox
AlexErrant Nov 16, 2023
fa63f6d
MetricsRenderer logs
AlexErrant Nov 16, 2023
f8ab1d4
add to_bytes
AlexErrant Dec 3, 2023
9e9dfb6
comment out off by one check
AlexErrant Nov 16, 2023
0a9a467
don't hardcode dirs
AlexErrant Dec 3, 2023
cfffb93
oops
AlexErrant Dec 3, 2023
d4106bb
fmt
AlexErrant Dec 3, 2023
c84d171
exclude `examples/train-web/train` less
AlexErrant Dec 4, 2023
3541503
I regret everything (in this file)
AlexErrant Dec 4, 2023
65724c8
Merge branch 'main' into train-browser
AlexErrant Dec 4, 2023
15af6fd
fix ci
AlexErrant Dec 4, 2023
49e2c38
rustup component add rust-src --toolchain nightly-2023-07-01-x86_64-u…
AlexErrant Dec 4, 2023
519b752
0
AlexErrant Dec 4, 2023
6d8c721
Revert "0"
AlexErrant Dec 4, 2023
c025d35
Revert "rustup component add rust-src --toolchain nightly-2023-07-01-…
AlexErrant Dec 4, 2023
b0fbeef
can I just get a ✔ please
AlexErrant Dec 5, 2023
57472fe
Revert "Revert "rustup component add rust-src --toolchain nightly-202…
AlexErrant Dec 4, 2023
8fa5a03
not always linux
AlexErrant Dec 5, 2023
1b6d619
1
AlexErrant Dec 5, 2023
c8575e3
2
AlexErrant Dec 5, 2023
de55ae8
Revert "can I just get a ✔ please"
AlexErrant Dec 5, 2023
bebab5e
3
AlexErrant Dec 5, 2023
64a11b5
4
AlexErrant Dec 5, 2023
944e53a
5
AlexErrant Dec 5, 2023
1bddd63
6
AlexErrant Dec 5, 2023
e56d489
Merge branch 'main' into train-browser
AlexErrant Dec 23, 2023
218e17d
fix breaking changes
AlexErrant Dec 24, 2023
b96df97
make windows work?
AlexErrant Dec 24, 2023
a92392f
Merge branch 'main' into train-browser
AlexErrant Jan 28, 2024
d3cabca
reject CI changes
AlexErrant Jan 28, 2024
2d0a488
less CI changes
AlexErrant Jan 28, 2024
f053b4d
bump wasm-bindgen
AlexErrant Jan 29, 2024
116d3f5
add generic
AlexErrant Jan 29, 2024
3b53c32
it builds
AlexErrant Jan 29, 2024
a527fae
add AsyncTask(Boxed)
AlexErrant Jan 29, 2024
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[workspace]
# Try
# require version 2 to avoid "feature" additiveness for dev-dependencies
# https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2
resolver = "2"
Expand Down Expand Up @@ -28,10 +27,11 @@ members = [
"xtask",
"examples/*",
"examples/pytorch-import/model",
"examples/train-web/train",
"backend-comparison",
]

exclude = ["examples/notebook"]
exclude = ["examples/notebook", "examples/train-web"]

[workspace.package]
edition = "2021"
Expand Down Expand Up @@ -86,7 +86,7 @@ tokio = { version = "1.35.1", features = ["rt", "macros"] }
tracing-appender = "0.2.3"
tracing-core = "0.1.32"
tracing-subscriber = "0.3.18"
wasm-bindgen = "0.2.88"
wasm-bindgen = "=0.2.90"
wasm-bindgen-futures = "0.4.38"
wasm-logger = "0.2.0"
wasm-timer = "0.2.5"
Expand Down
9 changes: 9 additions & 0 deletions burn-core/src/module/base.rs
Expand Up @@ -177,6 +177,15 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
recorder.record(record, file_path.into())
}

/// Serialize the module into bytes using the provided [bytes recorder](crate::record::BytesRecorder).
fn to_bytes<BR: crate::record::BytesRecorder<B>>(
self,
recorder: &BR,
) -> Result<Vec<u8>, crate::record::RecorderError> {
let record = Self::into_record(self);
recorder.record(record, ())
}

#[cfg(feature = "std")]
/// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
///
Expand Down
24 changes: 22 additions & 2 deletions burn-train/Cargo.toml
Expand Up @@ -11,12 +11,20 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-train"
version.workspace = true

[features]
default = ["metrics", "tui"]
default = ["metrics", "tui", "burn-core/default", "burn-core/dataset"]
metrics = ["nvml-wrapper", "sysinfo", "systemstat"]
tui = ["ratatui", "crossterm"]
browser = [
"js-sys",
"web-sys",
"wasm-bindgen",
"wasm-bindgen-rayon",
"rayon",
"burn-core/std",
]

[dependencies]
burn-core = { path = "../burn-core", version = "0.12.0", features = ["dataset"] }
burn-core = { path = "../burn-core", version = "0.12.0", default-features = false }

log = { workspace = true }
tracing-subscriber = { workspace = true }
Expand All @@ -36,5 +44,17 @@ crossterm = { version = "0.27", optional = true }
derive-new = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }

js-sys = { version = "0.3.64", optional = true }
web-sys = { version = "0.3.64", optional = true, features = [
"Worker",
"WorkerOptions",
"WorkerType",
"MessageEvent",
"ErrorEvent",
] }
wasm-bindgen = { workspace = true, optional = true }
wasm-bindgen-rayon = { version = "1.0.3", optional = true }
rayon = { workspace = true, optional = true }

[dev-dependencies]
burn-ndarray = { path = "../burn-ndarray", version = "0.12.0" }
4 changes: 2 additions & 2 deletions burn-train/src/checkpoint/strategy/metric.rs
Expand Up @@ -76,7 +76,7 @@ mod tests {
},
TestBackend,
};
use std::sync::Arc;
use std::rc::Rc;

use super::*;

Expand All @@ -93,7 +93,7 @@ mod tests {
store.register_logger_train(InMemoryMetricLogger::default());
// Register the loss metric.
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());
let store = Arc::new(EventStoreClient::new(store));
let store = Rc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());

// Two points for the first epoch. Mean 0.75
Expand Down
3 changes: 2 additions & 1 deletion burn-train/src/learner/base.rs
Expand Up @@ -7,6 +7,7 @@ use burn_core::module::Module;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::Backend;
use burn_core::tensor::Device;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

Expand All @@ -25,7 +26,7 @@ pub struct Learner<LC: LearnerComponents> {
pub(crate) interrupter: TrainingInterrupter,
pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
pub(crate) event_processor: LC::EventProcessor,
pub(crate) event_store: Arc<EventStoreClient>,
pub(crate) event_store: Rc<EventStoreClient>,
}

#[derive(new)]
Expand Down
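A note on the `Arc` to `Rc` swap that recurs through this diff: the sketch below, using only `std`, contrasts the two pointer types. `Store`, `is_send`, `share_locally`, and `share_across_threads` are hypothetical names for illustration and are not part of burn. The practical consequence is that any struct holding an `Rc` field (such as `Learner` after this change) can no longer be moved to another thread.

```rust
use std::rc::Rc;
use std::sync::Arc;

/// Hypothetical stand-in for `EventStoreClient`.
pub struct Store {
    pub name: &'static str,
}

/// Compiles only for values whose type may cross a thread boundary.
pub fn is_send<T: Send>(_value: &T) -> bool {
    true
}

/// Two owners of one store on a single thread, similar to how the learner
/// and the event processor both hold `event_store` in the diff.
pub fn share_locally() -> usize {
    let store = Rc::new(Store { name: "events" });
    let for_processor = Rc::clone(&store);
    // Both handles point at the same allocation, so the count is 2 here.
    let count = Rc::strong_count(&store);
    drop(for_processor);
    count
}

/// The `Arc` variant can be handed to another thread.
pub fn share_across_threads() -> &'static str {
    let store = Arc::new(Store { name: "events" });
    assert!(is_send(&store));
    // Calling `is_send` on an `Rc<Store>` would be rejected at compile time,
    // because `Rc` does not implement `Send`.
    let remote = Arc::clone(&store);
    std::thread::spawn(move || remote.name).join().unwrap()
}
```

`Rc` avoids atomic reference counting, at the cost of confining every owner to one thread.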
4 changes: 2 additions & 2 deletions burn-train/src/learner/builder.rs
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::rc::Rc;

use super::log::install_file_logger;
use super::Learner;
Expand Down Expand Up @@ -313,7 +313,7 @@ where
));
}

let event_store = Arc::new(EventStoreClient::new(self.event_store));
let event_store = Rc::new(EventStoreClient::new(self.event_store));
let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone());

let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
Expand Down
4 changes: 2 additions & 2 deletions burn-train/src/learner/early_stopping.rs
Expand Up @@ -113,7 +113,7 @@ impl MetricEarlyStoppingStrategy {

#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::rc::Rc;

use crate::{
logger::InMemoryMetricLogger,
Expand Down Expand Up @@ -197,7 +197,7 @@ mod tests {
store.register_logger_train(InMemoryMetricLogger::default());
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());

let store = Arc::new(EventStoreClient::new(store));
let store = Rc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());

let mut epoch = 1;
Expand Down
2 changes: 2 additions & 0 deletions burn-train/src/lib.rs
Expand Up @@ -2,6 +2,8 @@

//! A library for training neural networks using the burn crate.

pub mod util;

#[macro_use]
extern crate derive_new;

Expand Down
6 changes: 3 additions & 3 deletions burn-train/src/metric/processor/full.rs
@@ -1,22 +1,22 @@
use super::{Event, EventProcessor, Metrics};
use crate::metric::store::EventStoreClient;
use crate::renderer::{MetricState, MetricsRenderer};
use std::sync::Arc;
use std::rc::Rc;

/// An [event processor](EventProcessor) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
/// - Render metrics using a [metrics renderer](MetricsRenderer).
pub struct FullEventProcessor<T, V> {
metrics: Metrics<T, V>,
renderer: Box<dyn MetricsRenderer>,
store: Arc<EventStoreClient>,
store: Rc<EventStoreClient>,
}

impl<T, V> FullEventProcessor<T, V> {
pub(crate) fn new(
metrics: Metrics<T, V>,
renderer: Box<dyn MetricsRenderer>,
store: Arc<EventStoreClient>,
store: Rc<EventStoreClient>,
) -> Self {
Self {
metrics,
Expand Down
4 changes: 2 additions & 2 deletions burn-train/src/metric/processor/minimal.rs
@@ -1,13 +1,13 @@
use super::{Event, EventProcessor, Metrics};
use crate::metric::store::EventStoreClient;
use std::sync::Arc;
use std::rc::Rc;

/// An [event processor](EventProcessor) that handles:
/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
#[derive(new)]
pub(crate) struct MinimalEventProcessor<T, V> {
metrics: Metrics<T, V>,
store: Arc<EventStoreClient>,
store: Rc<EventStoreClient>,
}

impl<T, V> EventProcessor for MinimalEventProcessor<T, V> {
Expand Down
8 changes: 5 additions & 3 deletions burn-train/src/metric/store/client.rs
@@ -1,11 +1,13 @@
use super::EventStore;
use super::{Aggregate, Direction, Event, Split};
use std::{sync::mpsc, thread::JoinHandle};
use crate::util::{self, AsyncTaskBoxed};
use log::info;
use std::sync::mpsc;

/// Type that allows to communicate with an [event store](EventStore).
pub struct EventStoreClient {
sender: mpsc::Sender<Message>,
handler: Option<JoinHandle<()>>,
handler: Option<AsyncTaskBoxed>,
}

impl EventStoreClient {
Expand All @@ -17,7 +19,7 @@ impl EventStoreClient {
let (sender, receiver) = mpsc::channel();
let thread = WorkerThread::new(store, receiver);

let handler = std::thread::spawn(move || thread.run());
let handler = util::spawn(move || thread.run());
let handler = Some(handler);

Self { sender, handler }
Expand Down
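The client above follows a common pattern: a channel sender plus a handle to the worker that drains the channel. A minimal `std`-only sketch of that pattern follows. `Client`, the `Message` variants, and the shutdown-in-`Drop` logic are assumptions for illustration; the real `Message` enum and the rest of `EventStoreClient` are collapsed in this diff.

```rust
use std::sync::mpsc;
use std::thread::JoinHandle;

/// Hypothetical message set; the real `Message` enum is not shown in the diff.
enum Message {
    Add(u64),
    Total(mpsc::Sender<u64>),
    End,
}

/// Hypothetical stand-in for `EventStoreClient`.
pub struct Client {
    sender: mpsc::Sender<Message>,
    handler: Option<JoinHandle<()>>,
}

impl Client {
    pub fn new() -> Self {
        let (sender, receiver) = mpsc::channel();
        // The worker owns the state and the receiving end of the channel.
        let handler = std::thread::spawn(move || {
            let mut total = 0u64;
            for message in receiver {
                match message {
                    Message::Add(value) => total += value,
                    Message::Total(reply) => {
                        let _ = reply.send(total);
                    }
                    Message::End => return,
                }
            }
        });
        Self { sender, handler: Some(handler) }
    }

    /// Fire-and-forget update, processed in send order by the worker.
    pub fn add(&self, value: u64) {
        self.sender.send(Message::Add(value)).expect("worker is alive");
    }

    /// Request/response query over a one-off reply channel.
    pub fn total(&self) -> u64 {
        let (reply, response) = mpsc::channel();
        self.sender.send(Message::Total(reply)).expect("worker is alive");
        response.recv().expect("worker replied")
    }
}

impl Drop for Client {
    fn drop(&mut self) {
        // Ask the worker to stop, then wait for it to finish.
        let _ = self.sender.send(Message::End);
        if let Some(handler) = self.handler.take() {
            let _ = handler.join();
        }
    }
}
```

One thing to keep in mind when reviewing: with the `browser` feature, the task returned by `util::spawn` in this PR has a `join` that returns `Ok(())` immediately, so any code that relies on `join` to wait for the worker only gets that guarantee on native targets.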
46 changes: 46 additions & 0 deletions burn-train/src/util.rs
@@ -0,0 +1,46 @@
#![allow(missing_docs)]

pub trait AsyncTask {
fn join(self: Box<Self>) -> Result<(), ()>;
}

pub type AsyncTaskBoxed = Box<dyn AsyncTask>;

struct Thread {
join: Box<dyn FnOnce() -> Result<(), ()>>,
}

impl AsyncTask for Thread {
fn join(self: Box<Self>) -> Result<(), ()> {
(self.join)()
}
}

impl Thread {
fn new(join: Box<dyn FnOnce() -> Result<(), ()>>) -> Self {
Thread { join }
}
}

#[cfg(not(feature = "browser"))]
pub fn spawn<F>(f: F) -> AsyncTaskBoxed
where
F: FnOnce(),
F: Send + 'static,
{
let handle = std::thread::spawn(f);
Box::new(Thread::new(Box::new(move || handle.join().map_err(|_| ()))))
}

#[cfg(feature = "browser")]
pub fn spawn<F>(f: F) -> AsyncTaskBoxed
where
F: FnOnce(),
F: Send + 'static,
{
rayon::spawn(f);
Box::new(Thread::new(Box::new(|| Ok(()))))
}

#[cfg(feature = "browser")]
pub use wasm_bindgen_rayon::init_thread_pool;
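To make the native path of `util.rs` easier to reason about, here is a self-contained sketch of the same `AsyncTask` shape using only `std`. `spawn_native` and `round_trip` are hypothetical names introduced for this example; the browser path is not reproduced because it depends on `rayon` and `wasm-bindgen-rayon`.

```rust
use std::sync::mpsc;

/// Mirrors the trait in the diff: a boxed task that can be joined once.
pub trait AsyncTask {
    fn join(self: Box<Self>) -> Result<(), ()>;
}

pub type AsyncTaskBoxed = Box<dyn AsyncTask>;

/// Wraps the "how to join" logic in a closure so callers never see the
/// platform-specific handle type.
struct Thread {
    join: Box<dyn FnOnce() -> Result<(), ()>>,
}

impl AsyncTask for Thread {
    fn join(self: Box<Self>) -> Result<(), ()> {
        (self.join)()
    }
}

/// Hypothetical stand-in for the non-browser `spawn` in the diff.
pub fn spawn_native<F>(f: F) -> AsyncTaskBoxed
where
    F: FnOnce() + Send + 'static,
{
    let handle = std::thread::spawn(f);
    Box::new(Thread {
        // A panic in the spawned closure surfaces as `Err(())` on join.
        join: Box::new(move || handle.join().map_err(|_| ())),
    })
}

/// Hypothetical helper: send a value through a spawned task and read it back.
pub fn round_trip(value: i32) -> Option<i32> {
    let (tx, rx) = mpsc::channel();
    let task = spawn_native(move || {
        let _ = tx.send(value);
    });
    task.join().ok()?;
    rx.recv().ok()
}
```

Because the join logic is type-erased behind a closure, the caller cannot tell a real join from the browser variant's no-op; a dedicated error type instead of `()` would also let callers see why a join failed.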
2 changes: 2 additions & 0 deletions burn/Cargo.toml
Expand Up @@ -27,6 +27,8 @@ metrics = ["burn-train?/metrics"]
# Useful when targeting WASM and not using WGPU.
wasm-sync = ["burn-core/wasm-sync"]

browser = ["burn-train/browser"]

# Datasets
dataset = ["burn-core/dataset"]

Expand Down
18 changes: 18 additions & 0 deletions examples/train-web/readme.md
@@ -0,0 +1,18 @@
## Getting Started

Your `wasm-bindgen-cli` version must *exactly* match the `wasm-bindgen` version in [Cargo.toml](../../Cargo.toml) since `wasm-bindgen-cli` is implicitly used by `wasm-pack`.

For example, run `cargo install --version 0.2.90 wasm-bindgen-cli --force`. The version in this example command is not guaranteed to be up to date!

Install [PNPM](https://pnpm.io/).

Install [cargo-watch](https://crates.io/crates/cargo-watch).

The [`postinstall.sh`](./web/postinstall.sh) script expects the MNIST database to be at `~/.cache/burn-dataset/mnist.db`. Running `burn/examples/guide` will generate this file. Alternatively, you can download it from [Hugging Face](https://huggingface.co/datasets/mnist).

Then in separate terminals:

1. `cd train && ./dev.sh`
2. `cd web && pnpm i && pnpm dev`

Any changes to `/train` or `burn` should trigger a recompilation. When a new binary is generated, `web` will automatically refresh the page.
6 changes: 6 additions & 0 deletions examples/train-web/train/.cargo/config.toml
@@ -0,0 +1,6 @@
[unstable]
build-std = ['std', 'panic_abort']

[build]
target = "wasm32-unknown-unknown"
rustflags = '-Ctarget-feature=+atomics,+bulk-memory,+mutable-globals'
1 change: 1 addition & 0 deletions examples/train-web/train/.gitignore
@@ -0,0 +1 @@
pkg
22 changes: 22 additions & 0 deletions examples/train-web/train/Cargo.toml
@@ -0,0 +1,22 @@
[package]
name = "train-web"
version = "0.1.0"
edition = "2021"

[lib]
crate-type = ["cdylib"]

[dependencies]
wasm-bindgen = { workspace = true }
log = { workspace = true }
console_error_panic_hook = "0.1.7"
console_log = { version = "1", features = ["color"] }
burn = { path = "../../../burn", default-features = false, features = [
"autodiff",
"train",
"ndarray",
"wasm-sync",
"browser",
] }
serde = { workspace = true }
image = { version = "0.24.7", features = ["png"] }
24 changes: 24 additions & 0 deletions examples/train-web/train/build-for-web.sh
@@ -0,0 +1,24 @@
#!/usr/bin/env bash

set -eo pipefail # https://stackoverflow.com/a/2871034

# Add wasm32 target for compiler.
rustup target add wasm32-unknown-unknown

if ! command -v wasm-pack &>/dev/null; then
echo "wasm-pack could not be found. Installing ..."
cargo install wasm-pack
exit
fi

# Set optimization flags
if [[ $1 == "release" ]]; then
export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 --cfg web_sys_unstable_apis"
else
# sets $1 to "dev"
set -- dev
fi

# Run wasm pack tool to build JS wrapper files and copy wasm to pkg directory.
mkdir -p pkg
wasm-pack build --out-dir pkg --$1 --target web --no-default-features
3 changes: 3 additions & 0 deletions examples/train-web/train/dev.sh
@@ -0,0 +1,3 @@
#!/usr/bin/env bash

cargo watch -- ./build-for-web.sh
2 changes: 2 additions & 0 deletions examples/train-web/train/rust-toolchain.toml
@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly-2023-07-01"