From 5a030b9445bc84796ef37ac3c176392b0d774690 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 2 Nov 2023 16:52:06 -0500 Subject: [PATCH 01/79] pnpm create vite train-web --template vanilla-ts --- examples/train-web/.gitignore | 24 +++++++ examples/train-web/index.html | 13 ++++ examples/train-web/package.json | 15 +++++ examples/train-web/public/vite.svg | 1 + examples/train-web/src/counter.ts | 9 +++ examples/train-web/src/main.ts | 24 +++++++ examples/train-web/src/style.css | 97 +++++++++++++++++++++++++++ examples/train-web/src/typescript.svg | 1 + examples/train-web/src/vite-env.d.ts | 1 + examples/train-web/tsconfig.json | 23 +++++++ 10 files changed, 208 insertions(+) create mode 100644 examples/train-web/.gitignore create mode 100644 examples/train-web/index.html create mode 100644 examples/train-web/package.json create mode 100644 examples/train-web/public/vite.svg create mode 100644 examples/train-web/src/counter.ts create mode 100644 examples/train-web/src/main.ts create mode 100644 examples/train-web/src/style.css create mode 100644 examples/train-web/src/typescript.svg create mode 100644 examples/train-web/src/vite-env.d.ts create mode 100644 examples/train-web/tsconfig.json diff --git a/examples/train-web/.gitignore b/examples/train-web/.gitignore new file mode 100644 index 0000000000..a547bf36d8 --- /dev/null +++ b/examples/train-web/.gitignore @@ -0,0 +1,24 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/examples/train-web/index.html b/examples/train-web/index.html new file mode 100644 index 0000000000..44a933506f --- /dev/null +++ b/examples/train-web/index.html @@ -0,0 +1,13 @@ + + + + + + + Vite + TS + + +
+ + + diff --git a/examples/train-web/package.json b/examples/train-web/package.json new file mode 100644 index 0000000000..9d90e0784e --- /dev/null +++ b/examples/train-web/package.json @@ -0,0 +1,15 @@ +{ + "name": "train-web", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "tsc && vite build", + "preview": "vite preview" + }, + "devDependencies": { + "typescript": "^5.0.2", + "vite": "^4.4.5" + } +} diff --git a/examples/train-web/public/vite.svg b/examples/train-web/public/vite.svg new file mode 100644 index 0000000000..e7b8dfb1b2 --- /dev/null +++ b/examples/train-web/public/vite.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/examples/train-web/src/counter.ts b/examples/train-web/src/counter.ts new file mode 100644 index 0000000000..09e5afd2d8 --- /dev/null +++ b/examples/train-web/src/counter.ts @@ -0,0 +1,9 @@ +export function setupCounter(element: HTMLButtonElement) { + let counter = 0 + const setCounter = (count: number) => { + counter = count + element.innerHTML = `count is ${counter}` + } + element.addEventListener('click', () => setCounter(counter + 1)) + setCounter(0) +} diff --git a/examples/train-web/src/main.ts b/examples/train-web/src/main.ts new file mode 100644 index 0000000000..791547b0d3 --- /dev/null +++ b/examples/train-web/src/main.ts @@ -0,0 +1,24 @@ +import './style.css' +import typescriptLogo from './typescript.svg' +import viteLogo from '/vite.svg' +import { setupCounter } from './counter.ts' + +document.querySelector('#app')!.innerHTML = ` +
+ + + + + + +

Vite + TypeScript

+
+ +
+

+ Click on the Vite and TypeScript logos to learn more +

+
+` + +setupCounter(document.querySelector('#counter')!) diff --git a/examples/train-web/src/style.css b/examples/train-web/src/style.css new file mode 100644 index 0000000000..b528b6cc28 --- /dev/null +++ b/examples/train-web/src/style.css @@ -0,0 +1,97 @@ +:root { + font-family: Inter, system-ui, Avenir, Helvetica, Arial, sans-serif; + line-height: 1.5; + font-weight: 400; + + color-scheme: light dark; + color: rgba(255, 255, 255, 0.87); + background-color: #242424; + + font-synthesis: none; + text-rendering: optimizeLegibility; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + -webkit-text-size-adjust: 100%; +} + +a { + font-weight: 500; + color: #646cff; + text-decoration: inherit; +} +a:hover { + color: #535bf2; +} + +body { + margin: 0; + display: flex; + place-items: center; + min-width: 320px; + min-height: 100vh; +} + +h1 { + font-size: 3.2em; + line-height: 1.1; +} + +#app { + max-width: 1280px; + margin: 0 auto; + padding: 2rem; + text-align: center; +} + +.logo { + height: 6em; + padding: 1.5em; + will-change: filter; + transition: filter 300ms; +} +.logo:hover { + filter: drop-shadow(0 0 2em #646cffaa); +} +.logo.vanilla:hover { + filter: drop-shadow(0 0 2em #3178c6aa); +} + +.card { + padding: 2em; +} + +.read-the-docs { + color: #888; +} + +button { + border-radius: 8px; + border: 1px solid transparent; + padding: 0.6em 1.2em; + font-size: 1em; + font-weight: 500; + font-family: inherit; + background-color: #1a1a1a; + cursor: pointer; + transition: border-color 0.25s; +} +button:hover { + border-color: #646cff; +} +button:focus, +button:focus-visible { + outline: 4px auto -webkit-focus-ring-color; +} + +@media (prefers-color-scheme: light) { + :root { + color: #213547; + background-color: #ffffff; + } + a:hover { + color: #747bff; + } + button { + background-color: #f9f9f9; + } +} diff --git a/examples/train-web/src/typescript.svg b/examples/train-web/src/typescript.svg new file mode 100644 index 0000000000..d91c910cc3 --- /dev/null +++ b/examples/train-web/src/typescript.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/examples/train-web/src/vite-env.d.ts b/examples/train-web/src/vite-env.d.ts new file mode 100644 index 0000000000..11f02fe2a0 --- /dev/null +++ b/examples/train-web/src/vite-env.d.ts @@ -0,0 +1 @@ +/// diff --git a/examples/train-web/tsconfig.json b/examples/train-web/tsconfig.json new file mode 100644 index 0000000000..75abdef265 --- /dev/null +++ b/examples/train-web/tsconfig.json @@ -0,0 +1,23 @@ +{ + "compilerOptions": { + "target": "ES2020", + "useDefineForClassFields": true, + "module": "ESNext", + "lib": ["ES2020", "DOM", "DOM.Iterable"], + "skipLibCheck": true, + + /* Bundler mode */ + "moduleResolution": "bundler", + "allowImportingTsExtensions": true, + "resolveJsonModule": true, + "isolatedModules": true, + "noEmit": true, + + /* Linting */ + "strict": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "noFallthroughCasesInSwitch": true + }, + "include": ["src"] +} From 2f09310173c6aaf0276f2ce9a36007845db247bc Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 2 Nov 2023 16:53:12 -0500 Subject: [PATCH 02/79] pnpm i --- examples/train-web/pnpm-lock.yaml | 324 ++++++++++++++++++++++++++++++ 1 file changed, 324 insertions(+) create mode 100644 examples/train-web/pnpm-lock.yaml diff --git a/examples/train-web/pnpm-lock.yaml b/examples/train-web/pnpm-lock.yaml new file mode 100644 index 0000000000..492b18fe75 --- /dev/null +++ b/examples/train-web/pnpm-lock.yaml @@ -0,0 +1,324 @@ +lockfileVersion: '6.1' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +devDependencies: + typescript: + specifier: ^5.0.2 + version: 5.0.2 + vite: + specifier: ^4.4.5 + version: 4.4.5 + +packages: + + /@esbuild/android-arm64@0.18.20: + resolution: {integrity: sha512-Nz4rJcchGDtENV0eMKUNa6L12zz2zBDXuhj/Vjh18zGqB44Bi7MBMSXjgunJgjRhCmKOjnPuZp4Mb6OKqtMHLQ==} + engines: {node: '>=12'} + cpu: [arm64] + os: [android] + requiresBuild: true + dev: true + optional: true + + /@esbuild/android-arm@0.18.20: + resolution: {integrity: sha512-fyi7TDI/ijKKNZTUJAQqiG5T7YjJXgnzkURqmGj13C6dCqckZBLdl4h7bkhHt/t0WP+zO9/zwroDvANaOqO5Sw==} + engines: {node: '>=12'} + cpu: [arm] + os: [android] + requiresBuild: true + dev: true + optional: true + + /@esbuild/android-x64@0.18.20: + resolution: {integrity: sha512-8GDdlePJA8D6zlZYJV/jnrRAi6rOiNaCC/JclcXpB+KIuvfBN4owLtgzY2bsxnx666XjJx2kDPUmnTtR8qKQUg==} + engines: {node: '>=12'} + cpu: [x64] + os: [android] + requiresBuild: true + dev: true + optional: true + + /@esbuild/darwin-arm64@0.18.20: + resolution: {integrity: sha512-bxRHW5kHU38zS2lPTPOyuyTm+S+eobPUnTNkdJEfAddYgEcll4xkT8DB9d2008DtTbl7uJag2HuE5NZAZgnNEA==} + engines: {node: '>=12'} + cpu: [arm64] + os: [darwin] + requiresBuild: true + dev: true + optional: true + + /@esbuild/darwin-x64@0.18.20: + resolution: {integrity: sha512-pc5gxlMDxzm513qPGbCbDukOdsGtKhfxD1zJKXjCCcU7ju50O7MeAZ8c4krSJcOIJGFR+qx21yMMVYwiQvyTyQ==} + engines: {node: '>=12'} + cpu: [x64] + os: [darwin] + requiresBuild: true + dev: true + optional: true + + /@esbuild/freebsd-arm64@0.18.20: + resolution: {integrity: sha512-yqDQHy4QHevpMAaxhhIwYPMv1NECwOvIpGCZkECn8w2WFHXjEwrBn3CeNIYsibZ/iZEUemj++M26W3cNR5h+Tw==} + engines: {node: '>=12'} + cpu: [arm64] + os: [freebsd] + requiresBuild: true + dev: true + optional: true + + /@esbuild/freebsd-x64@0.18.20: + resolution: {integrity: sha512-tgWRPPuQsd3RmBZwarGVHZQvtzfEBOreNuxEMKFcd5DaDn2PbBxfwLcj4+aenoh7ctXcbXmOQIn8HI6mCSw5MQ==} + engines: {node: '>=12'} + cpu: [x64] + os: [freebsd] + requiresBuild: true + dev: true + optional: true + + /@esbuild/linux-arm64@0.18.20: + resolution: {integrity: sha512-2YbscF+UL7SQAVIpnWvYwM+3LskyDmPhe31pE7/aoTMFKKzIc9lLbyGUpmmb8a8AixOL61sQ/mFh3jEjHYFvdA==} + engines: {node: '>=12'} + cpu: [arm64] + os: [linux] + requiresBuild: true + dev: true + optional: true + + /@esbuild/linux-arm@0.18.20: + resolution: {integrity: sha512-/5bHkMWnq1EgKr1V+Ybz3s1hWXok7mDFUMQ4cG10AfW3wL02PSZi5kFpYKrptDsgb2WAJIvRcDm+qIvXf/apvg==} + engines: {node: '>=12'} + cpu: [arm] + os: [linux] + requiresBuild: true + dev: true + optional: true + + /@esbuild/linux-ia32@0.18.20: + resolution: {integrity: sha512-P4etWwq6IsReT0E1KHU40bOnzMHoH73aXp96Fs8TIT6z9Hu8G6+0SHSw9i2isWrD2nbx2qo5yUqACgdfVGx7TA==} + engines: {node: '>=12'} + cpu: [ia32] + os: [linux] + requiresBuild: true + dev: true + optional: true + + /@esbuild/linux-loong64@0.18.20: + resolution: {integrity: sha512-nXW8nqBTrOpDLPgPY9uV+/1DjxoQ7DoB2N8eocyq8I9XuqJ7BiAMDMf9n1xZM9TgW0J8zrquIb/A7s3BJv7rjg==} + engines: {node: '>=12'} + cpu: [loong64] + os: [linux] + requiresBuild: true + dev: true + optional: true + + /@esbuild/linux-mips64el@0.18.20: + resolution: {integrity: sha512-d5NeaXZcHp8PzYy5VnXV3VSd2D328Zb+9dEq5HE6bw6+N86JVPExrA6O68OPwobntbNJ0pzCpUFZTo3w0GyetQ==} + engines: {node: '>=12'} + cpu: [mips64el] + os: [linux] + requiresBuild: true + dev: true + optional: true + + /@esbuild/linux-ppc64@0.18.20: + resolution: {integrity: sha512-WHPyeScRNcmANnLQkq6AfyXRFr5D6N2sKgkFo2FqguP44Nw2eyDlbTdZwd9GYk98DZG9QItIiTlFLHJHjxP3FA==} + engines: {node: '>=12'} + cpu: [ppc64] + os: [linux] + requiresBuild: true + dev: true + optional: true + + /@esbuild/linux-riscv64@0.18.20: + resolution: {integrity: sha512-WSxo6h5ecI5XH34KC7w5veNnKkju3zBRLEQNY7mv5mtBmrP/MjNBCAlsM2u5hDBlS3NGcTQpoBvRzqBcRtpq1A==} + engines: {node: '>=12'} + cpu: [riscv64] + os: [linux] + requiresBuild: true + dev: true + optional: true + + /@esbuild/linux-s390x@0.18.20: + resolution: {integrity: sha512-+8231GMs3mAEth6Ja1iK0a1sQ3ohfcpzpRLH8uuc5/KVDFneH6jtAJLFGafpzpMRO6DzJ6AvXKze9LfFMrIHVQ==} + engines: {node: '>=12'} + cpu: [s390x] + os: [linux] + requiresBuild: true + dev: true + optional: true + + /@esbuild/linux-x64@0.18.20: + resolution: {integrity: sha512-UYqiqemphJcNsFEskc73jQ7B9jgwjWrSayxawS6UVFZGWrAAtkzjxSqnoclCXxWtfwLdzU+vTpcNYhpn43uP1w==} + engines: {node: '>=12'} + cpu: [x64] + os: [linux] + requiresBuild: true + dev: true + optional: true + + /@esbuild/netbsd-x64@0.18.20: + resolution: {integrity: sha512-iO1c++VP6xUBUmltHZoMtCUdPlnPGdBom6IrO4gyKPFFVBKioIImVooR5I83nTew5UOYrk3gIJhbZh8X44y06A==} + engines: {node: '>=12'} + cpu: [x64] + os: [netbsd] + requiresBuild: true + dev: true + optional: true + + /@esbuild/openbsd-x64@0.18.20: + resolution: {integrity: sha512-e5e4YSsuQfX4cxcygw/UCPIEP6wbIL+se3sxPdCiMbFLBWu0eiZOJ7WoD+ptCLrmjZBK1Wk7I6D/I3NglUGOxg==} + engines: {node: '>=12'} + cpu: [x64] + os: [openbsd] + requiresBuild: true + dev: true + optional: true + + /@esbuild/sunos-x64@0.18.20: + resolution: {integrity: sha512-kDbFRFp0YpTQVVrqUd5FTYmWo45zGaXe0X8E1G/LKFC0v8x0vWrhOWSLITcCn63lmZIxfOMXtCfti/RxN/0wnQ==} + engines: {node: '>=12'} + cpu: [x64] + os: [sunos] + requiresBuild: true + dev: true + optional: true + + /@esbuild/win32-arm64@0.18.20: + resolution: {integrity: sha512-ddYFR6ItYgoaq4v4JmQQaAI5s7npztfV4Ag6NrhiaW0RrnOXqBkgwZLofVTlq1daVTQNhtI5oieTvkRPfZrePg==} + engines: {node: '>=12'} + cpu: [arm64] + os: [win32] + requiresBuild: true + dev: true + optional: true + + /@esbuild/win32-ia32@0.18.20: + resolution: {integrity: sha512-Wv7QBi3ID/rROT08SABTS7eV4hX26sVduqDOTe1MvGMjNd3EjOz4b7zeexIR62GTIEKrfJXKL9LFxTYgkyeu7g==} + engines: {node: '>=12'} + cpu: [ia32] + os: [win32] + requiresBuild: true + dev: true + optional: true + + /@esbuild/win32-x64@0.18.20: + resolution: {integrity: sha512-kTdfRcSiDfQca/y9QIkng02avJ+NCaQvrMejlsB3RRv5sE9rRoeBPISaZpKxHELzRxZyLvNts1P27W3wV+8geQ==} + engines: {node: '>=12'} + cpu: [x64] + os: [win32] + requiresBuild: true + dev: true + optional: true + + /esbuild@0.18.20: + resolution: {integrity: sha512-ceqxoedUrcayh7Y7ZX6NdbbDzGROiyVBgC4PriJThBKSVPWnnFHZAkfI1lJT8QFkOwH4qOS2SJkS4wvpGl8BpA==} + engines: {node: '>=12'} + hasBin: true + requiresBuild: true + optionalDependencies: + '@esbuild/android-arm': 0.18.20 + '@esbuild/android-arm64': 0.18.20 + '@esbuild/android-x64': 0.18.20 + '@esbuild/darwin-arm64': 0.18.20 + '@esbuild/darwin-x64': 0.18.20 + '@esbuild/freebsd-arm64': 0.18.20 + '@esbuild/freebsd-x64': 0.18.20 + '@esbuild/linux-arm': 0.18.20 + '@esbuild/linux-arm64': 0.18.20 + '@esbuild/linux-ia32': 0.18.20 + '@esbuild/linux-loong64': 0.18.20 + '@esbuild/linux-mips64el': 0.18.20 + '@esbuild/linux-ppc64': 0.18.20 + '@esbuild/linux-riscv64': 0.18.20 + '@esbuild/linux-s390x': 0.18.20 + '@esbuild/linux-x64': 0.18.20 + '@esbuild/netbsd-x64': 0.18.20 + '@esbuild/openbsd-x64': 0.18.20 + '@esbuild/sunos-x64': 0.18.20 + '@esbuild/win32-arm64': 0.18.20 + '@esbuild/win32-ia32': 0.18.20 + '@esbuild/win32-x64': 0.18.20 + dev: true + + /fsevents@2.3.3: + resolution: {integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==} + engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} + os: [darwin] + requiresBuild: true + dev: true + optional: true + + /nanoid@3.3.6: + resolution: {integrity: sha512-BGcqMMJuToF7i1rt+2PWSNVnWIkGCU78jBG3RxO/bZlnZPK2Cmi2QaffxGO/2RvWi9sL+FAiRiXMgsyxQ1DIDA==} + engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} + hasBin: true + dev: true + + /picocolors@1.0.0: + resolution: {integrity: sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==} + dev: true + + /postcss@8.4.31: + resolution: {integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==} + engines: {node: ^10 || ^12 || >=14} + dependencies: + nanoid: 3.3.6 + picocolors: 1.0.0 + source-map-js: 1.0.2 + dev: true + + /rollup@3.29.4: + resolution: {integrity: sha512-oWzmBZwvYrU0iJHtDmhsm662rC15FRXmcjCk1xD771dFDx5jJ02ufAQQTn0etB2emNk4J9EZg/yWKpsn9BWGRw==} + engines: {node: '>=14.18.0', npm: '>=8.0.0'} + hasBin: true + optionalDependencies: + fsevents: 2.3.3 + dev: true + + /source-map-js@1.0.2: + resolution: {integrity: sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==} + engines: {node: '>=0.10.0'} + dev: true + + /typescript@5.0.2: + resolution: {integrity: sha512-wVORMBGO/FAs/++blGNeAVdbNKtIh1rbBL2EyQ1+J9lClJ93KiiKe8PmFIVdXhHcyv44SL9oglmfeSsndo0jRw==} + engines: {node: '>=12.20'} + hasBin: true + dev: true + + /vite@4.4.5: + resolution: {integrity: sha512-4m5kEtAWHYr0O1Fu7rZp64CfO1PsRGZlD3TAB32UmQlpd7qg15VF7ROqGN5CyqN7HFuwr7ICNM2+fDWRqFEKaA==} + engines: {node: ^14.18.0 || >=16.0.0} + hasBin: true + peerDependencies: + '@types/node': '>= 14' + less: '*' + lightningcss: ^1.21.0 + sass: '*' + stylus: '*' + sugarss: '*' + terser: ^5.4.0 + peerDependenciesMeta: + '@types/node': + optional: true + less: + optional: true + lightningcss: + optional: true + sass: + optional: true + stylus: + optional: true + sugarss: + optional: true + terser: + optional: true + dependencies: + esbuild: 0.18.20 + postcss: 8.4.31 + rollup: 3.29.4 + optionalDependencies: + fsevents: 2.3.3 + dev: true From 408d740d20ad49934de201d3f341ea59a9d86fd1 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 2 Nov 2023 16:53:47 -0500 Subject: [PATCH 03/79] pnpm up -rL --- examples/train-web/package.json | 4 ++-- examples/train-web/pnpm-lock.yaml | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/train-web/package.json b/examples/train-web/package.json index 9d90e0784e..b5d9b469a7 100644 --- a/examples/train-web/package.json +++ b/examples/train-web/package.json @@ -9,7 +9,7 @@ "preview": "vite preview" }, "devDependencies": { - "typescript": "^5.0.2", - "vite": "^4.4.5" + "typescript": "^5.2.2", + "vite": "^4.5.0" } } diff --git a/examples/train-web/pnpm-lock.yaml b/examples/train-web/pnpm-lock.yaml index 492b18fe75..adcc50e3be 100644 --- a/examples/train-web/pnpm-lock.yaml +++ b/examples/train-web/pnpm-lock.yaml @@ -6,11 +6,11 @@ settings: devDependencies: typescript: - specifier: ^5.0.2 - version: 5.0.2 + specifier: ^5.2.2 + version: 5.2.2 vite: - specifier: ^4.4.5 - version: 4.4.5 + specifier: ^4.5.0 + version: 4.5.0 packages: @@ -282,14 +282,14 @@ packages: engines: {node: '>=0.10.0'} dev: true - /typescript@5.0.2: - resolution: {integrity: sha512-wVORMBGO/FAs/++blGNeAVdbNKtIh1rbBL2EyQ1+J9lClJ93KiiKe8PmFIVdXhHcyv44SL9oglmfeSsndo0jRw==} - engines: {node: '>=12.20'} + /typescript@5.2.2: + resolution: {integrity: sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==} + engines: {node: '>=14.17'} hasBin: true dev: true - /vite@4.4.5: - resolution: {integrity: sha512-4m5kEtAWHYr0O1Fu7rZp64CfO1PsRGZlD3TAB32UmQlpd7qg15VF7ROqGN5CyqN7HFuwr7ICNM2+fDWRqFEKaA==} + /vite@4.5.0: + resolution: {integrity: sha512-ulr8rNLA6rkyFAlVWw2q5YJ91v098AFQ2R0PRFwPzREXOUJQPtFUG0t+/ZikhaOCDqFoDhN6/v8Sq0o4araFAw==} engines: {node: ^14.18.0 || >=16.0.0} hasBin: true peerDependencies: From 8eab3a6815fdb9c6b8710bf6a98e93db0760c28c Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 2 Nov 2023 16:56:32 -0500 Subject: [PATCH 04/79] pnpm i -D prettier --- examples/train-web/package.json | 1 + examples/train-web/pnpm-lock.yaml | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/examples/train-web/package.json b/examples/train-web/package.json index b5d9b469a7..c3088b9641 100644 --- a/examples/train-web/package.json +++ b/examples/train-web/package.json @@ -9,6 +9,7 @@ "preview": "vite preview" }, "devDependencies": { + "prettier": "3.0.3", "typescript": "^5.2.2", "vite": "^4.5.0" } diff --git a/examples/train-web/pnpm-lock.yaml b/examples/train-web/pnpm-lock.yaml index adcc50e3be..4a7c7c2cd6 100644 --- a/examples/train-web/pnpm-lock.yaml +++ b/examples/train-web/pnpm-lock.yaml @@ -5,6 +5,9 @@ settings: excludeLinksFromLockfile: false devDependencies: + prettier: + specifier: 3.0.3 + version: 3.0.3 typescript: specifier: ^5.2.2 version: 5.2.2 @@ -269,6 +272,12 @@ packages: source-map-js: 1.0.2 dev: true + /prettier@3.0.3: + resolution: {integrity: sha512-L/4pUDMxcNa8R/EthV08Zt42WBO4h1rarVtK0K+QJG0X187OLo7l699jWw0GKuwzkPQ//jMFA/8Xm6Fh3J/DAg==} + engines: {node: '>=14'} + hasBin: true + dev: true + /rollup@3.29.4: resolution: {integrity: sha512-oWzmBZwvYrU0iJHtDmhsm662rC15FRXmcjCk1xD771dFDx5jJ02ufAQQTn0etB2emNk4J9EZg/yWKpsn9BWGRw==} engines: {node: '>=14.18.0', npm: '>=8.0.0'} From 74b9b5706afaa8af4b4a55df81fe6cfd154c22ce Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 2 Nov 2023 16:59:10 -0500 Subject: [PATCH 05/79] add .prettierrc --- examples/train-web/.prettierrc | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 examples/train-web/.prettierrc diff --git a/examples/train-web/.prettierrc b/examples/train-web/.prettierrc new file mode 100644 index 0000000000..c5f6f03cf5 --- /dev/null +++ b/examples/train-web/.prettierrc @@ -0,0 +1,5 @@ +{ + "semi": false, + "singleQuote": true, + "useTabs": true +} From 86c6a06d871dd889d4b870dce55d5647aab0ea85 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 2 Nov 2023 16:59:16 -0500 Subject: [PATCH 06/79] pnpm exec prettier . --write --- examples/train-web/index.html | 20 +-- examples/train-web/package.json | 28 ++-- examples/train-web/pnpm-lock.yaml | 223 +++++++++++++++++++++--------- examples/train-web/src/counter.ts | 14 +- examples/train-web/src/style.css | 110 +++++++-------- examples/train-web/tsconfig.json | 38 ++--- 6 files changed, 264 insertions(+), 169 deletions(-) diff --git a/examples/train-web/index.html b/examples/train-web/index.html index 44a933506f..32a4f54abc 100644 --- a/examples/train-web/index.html +++ b/examples/train-web/index.html @@ -1,13 +1,13 @@ - - - - - Vite + TS - - -
- - + + + + + Vite + TS + + +
+ + diff --git a/examples/train-web/package.json b/examples/train-web/package.json index c3088b9641..a89dfe068f 100644 --- a/examples/train-web/package.json +++ b/examples/train-web/package.json @@ -1,16 +1,16 @@ { - "name": "train-web", - "private": true, - "version": "0.0.0", - "type": "module", - "scripts": { - "dev": "vite", - "build": "tsc && vite build", - "preview": "vite preview" - }, - "devDependencies": { - "prettier": "3.0.3", - "typescript": "^5.2.2", - "vite": "^4.5.0" - } + "name": "train-web", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "tsc && vite build", + "preview": "vite preview" + }, + "devDependencies": { + "prettier": "3.0.3", + "typescript": "^5.2.2", + "vite": "^4.5.0" + } } diff --git a/examples/train-web/pnpm-lock.yaml b/examples/train-web/pnpm-lock.yaml index 4a7c7c2cd6..41dbb1ab05 100644 --- a/examples/train-web/pnpm-lock.yaml +++ b/examples/train-web/pnpm-lock.yaml @@ -16,10 +16,12 @@ devDependencies: version: 4.5.0 packages: - /@esbuild/android-arm64@0.18.20: - resolution: {integrity: sha512-Nz4rJcchGDtENV0eMKUNa6L12zz2zBDXuhj/Vjh18zGqB44Bi7MBMSXjgunJgjRhCmKOjnPuZp4Mb6OKqtMHLQ==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-Nz4rJcchGDtENV0eMKUNa6L12zz2zBDXuhj/Vjh18zGqB44Bi7MBMSXjgunJgjRhCmKOjnPuZp4Mb6OKqtMHLQ==, + } + engines: { node: '>=12' } cpu: [arm64] os: [android] requiresBuild: true @@ -27,8 +29,11 @@ packages: optional: true /@esbuild/android-arm@0.18.20: - resolution: {integrity: sha512-fyi7TDI/ijKKNZTUJAQqiG5T7YjJXgnzkURqmGj13C6dCqckZBLdl4h7bkhHt/t0WP+zO9/zwroDvANaOqO5Sw==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-fyi7TDI/ijKKNZTUJAQqiG5T7YjJXgnzkURqmGj13C6dCqckZBLdl4h7bkhHt/t0WP+zO9/zwroDvANaOqO5Sw==, + } + engines: { node: '>=12' } cpu: [arm] os: [android] requiresBuild: true @@ -36,8 +41,11 @@ packages: optional: true /@esbuild/android-x64@0.18.20: - resolution: {integrity: sha512-8GDdlePJA8D6zlZYJV/jnrRAi6rOiNaCC/JclcXpB+KIuvfBN4owLtgzY2bsxnx666XjJx2kDPUmnTtR8qKQUg==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-8GDdlePJA8D6zlZYJV/jnrRAi6rOiNaCC/JclcXpB+KIuvfBN4owLtgzY2bsxnx666XjJx2kDPUmnTtR8qKQUg==, + } + engines: { node: '>=12' } cpu: [x64] os: [android] requiresBuild: true @@ -45,8 +53,11 @@ packages: optional: true /@esbuild/darwin-arm64@0.18.20: - resolution: {integrity: sha512-bxRHW5kHU38zS2lPTPOyuyTm+S+eobPUnTNkdJEfAddYgEcll4xkT8DB9d2008DtTbl7uJag2HuE5NZAZgnNEA==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-bxRHW5kHU38zS2lPTPOyuyTm+S+eobPUnTNkdJEfAddYgEcll4xkT8DB9d2008DtTbl7uJag2HuE5NZAZgnNEA==, + } + engines: { node: '>=12' } cpu: [arm64] os: [darwin] requiresBuild: true @@ -54,8 +65,11 @@ packages: optional: true /@esbuild/darwin-x64@0.18.20: - resolution: {integrity: sha512-pc5gxlMDxzm513qPGbCbDukOdsGtKhfxD1zJKXjCCcU7ju50O7MeAZ8c4krSJcOIJGFR+qx21yMMVYwiQvyTyQ==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-pc5gxlMDxzm513qPGbCbDukOdsGtKhfxD1zJKXjCCcU7ju50O7MeAZ8c4krSJcOIJGFR+qx21yMMVYwiQvyTyQ==, + } + engines: { node: '>=12' } cpu: [x64] os: [darwin] requiresBuild: true @@ -63,8 +77,11 @@ packages: optional: true /@esbuild/freebsd-arm64@0.18.20: - resolution: {integrity: sha512-yqDQHy4QHevpMAaxhhIwYPMv1NECwOvIpGCZkECn8w2WFHXjEwrBn3CeNIYsibZ/iZEUemj++M26W3cNR5h+Tw==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-yqDQHy4QHevpMAaxhhIwYPMv1NECwOvIpGCZkECn8w2WFHXjEwrBn3CeNIYsibZ/iZEUemj++M26W3cNR5h+Tw==, + } + engines: { node: '>=12' } cpu: [arm64] os: [freebsd] requiresBuild: true @@ -72,8 +89,11 @@ packages: optional: true /@esbuild/freebsd-x64@0.18.20: - resolution: {integrity: sha512-tgWRPPuQsd3RmBZwarGVHZQvtzfEBOreNuxEMKFcd5DaDn2PbBxfwLcj4+aenoh7ctXcbXmOQIn8HI6mCSw5MQ==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-tgWRPPuQsd3RmBZwarGVHZQvtzfEBOreNuxEMKFcd5DaDn2PbBxfwLcj4+aenoh7ctXcbXmOQIn8HI6mCSw5MQ==, + } + engines: { node: '>=12' } cpu: [x64] os: [freebsd] requiresBuild: true @@ -81,8 +101,11 @@ packages: optional: true /@esbuild/linux-arm64@0.18.20: - resolution: {integrity: sha512-2YbscF+UL7SQAVIpnWvYwM+3LskyDmPhe31pE7/aoTMFKKzIc9lLbyGUpmmb8a8AixOL61sQ/mFh3jEjHYFvdA==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-2YbscF+UL7SQAVIpnWvYwM+3LskyDmPhe31pE7/aoTMFKKzIc9lLbyGUpmmb8a8AixOL61sQ/mFh3jEjHYFvdA==, + } + engines: { node: '>=12' } cpu: [arm64] os: [linux] requiresBuild: true @@ -90,8 +113,11 @@ packages: optional: true /@esbuild/linux-arm@0.18.20: - resolution: {integrity: sha512-/5bHkMWnq1EgKr1V+Ybz3s1hWXok7mDFUMQ4cG10AfW3wL02PSZi5kFpYKrptDsgb2WAJIvRcDm+qIvXf/apvg==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-/5bHkMWnq1EgKr1V+Ybz3s1hWXok7mDFUMQ4cG10AfW3wL02PSZi5kFpYKrptDsgb2WAJIvRcDm+qIvXf/apvg==, + } + engines: { node: '>=12' } cpu: [arm] os: [linux] requiresBuild: true @@ -99,8 +125,11 @@ packages: optional: true /@esbuild/linux-ia32@0.18.20: - resolution: {integrity: sha512-P4etWwq6IsReT0E1KHU40bOnzMHoH73aXp96Fs8TIT6z9Hu8G6+0SHSw9i2isWrD2nbx2qo5yUqACgdfVGx7TA==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-P4etWwq6IsReT0E1KHU40bOnzMHoH73aXp96Fs8TIT6z9Hu8G6+0SHSw9i2isWrD2nbx2qo5yUqACgdfVGx7TA==, + } + engines: { node: '>=12' } cpu: [ia32] os: [linux] requiresBuild: true @@ -108,8 +137,11 @@ packages: optional: true /@esbuild/linux-loong64@0.18.20: - resolution: {integrity: sha512-nXW8nqBTrOpDLPgPY9uV+/1DjxoQ7DoB2N8eocyq8I9XuqJ7BiAMDMf9n1xZM9TgW0J8zrquIb/A7s3BJv7rjg==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-nXW8nqBTrOpDLPgPY9uV+/1DjxoQ7DoB2N8eocyq8I9XuqJ7BiAMDMf9n1xZM9TgW0J8zrquIb/A7s3BJv7rjg==, + } + engines: { node: '>=12' } cpu: [loong64] os: [linux] requiresBuild: true @@ -117,8 +149,11 @@ packages: optional: true /@esbuild/linux-mips64el@0.18.20: - resolution: {integrity: sha512-d5NeaXZcHp8PzYy5VnXV3VSd2D328Zb+9dEq5HE6bw6+N86JVPExrA6O68OPwobntbNJ0pzCpUFZTo3w0GyetQ==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-d5NeaXZcHp8PzYy5VnXV3VSd2D328Zb+9dEq5HE6bw6+N86JVPExrA6O68OPwobntbNJ0pzCpUFZTo3w0GyetQ==, + } + engines: { node: '>=12' } cpu: [mips64el] os: [linux] requiresBuild: true @@ -126,8 +161,11 @@ packages: optional: true /@esbuild/linux-ppc64@0.18.20: - resolution: {integrity: sha512-WHPyeScRNcmANnLQkq6AfyXRFr5D6N2sKgkFo2FqguP44Nw2eyDlbTdZwd9GYk98DZG9QItIiTlFLHJHjxP3FA==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-WHPyeScRNcmANnLQkq6AfyXRFr5D6N2sKgkFo2FqguP44Nw2eyDlbTdZwd9GYk98DZG9QItIiTlFLHJHjxP3FA==, + } + engines: { node: '>=12' } cpu: [ppc64] os: [linux] requiresBuild: true @@ -135,8 +173,11 @@ packages: optional: true /@esbuild/linux-riscv64@0.18.20: - resolution: {integrity: sha512-WSxo6h5ecI5XH34KC7w5veNnKkju3zBRLEQNY7mv5mtBmrP/MjNBCAlsM2u5hDBlS3NGcTQpoBvRzqBcRtpq1A==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-WSxo6h5ecI5XH34KC7w5veNnKkju3zBRLEQNY7mv5mtBmrP/MjNBCAlsM2u5hDBlS3NGcTQpoBvRzqBcRtpq1A==, + } + engines: { node: '>=12' } cpu: [riscv64] os: [linux] requiresBuild: true @@ -144,8 +185,11 @@ packages: optional: true /@esbuild/linux-s390x@0.18.20: - resolution: {integrity: sha512-+8231GMs3mAEth6Ja1iK0a1sQ3ohfcpzpRLH8uuc5/KVDFneH6jtAJLFGafpzpMRO6DzJ6AvXKze9LfFMrIHVQ==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-+8231GMs3mAEth6Ja1iK0a1sQ3ohfcpzpRLH8uuc5/KVDFneH6jtAJLFGafpzpMRO6DzJ6AvXKze9LfFMrIHVQ==, + } + engines: { node: '>=12' } cpu: [s390x] os: [linux] requiresBuild: true @@ -153,8 +197,11 @@ packages: optional: true /@esbuild/linux-x64@0.18.20: - resolution: {integrity: sha512-UYqiqemphJcNsFEskc73jQ7B9jgwjWrSayxawS6UVFZGWrAAtkzjxSqnoclCXxWtfwLdzU+vTpcNYhpn43uP1w==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-UYqiqemphJcNsFEskc73jQ7B9jgwjWrSayxawS6UVFZGWrAAtkzjxSqnoclCXxWtfwLdzU+vTpcNYhpn43uP1w==, + } + engines: { node: '>=12' } cpu: [x64] os: [linux] requiresBuild: true @@ -162,8 +209,11 @@ packages: optional: true /@esbuild/netbsd-x64@0.18.20: - resolution: {integrity: sha512-iO1c++VP6xUBUmltHZoMtCUdPlnPGdBom6IrO4gyKPFFVBKioIImVooR5I83nTew5UOYrk3gIJhbZh8X44y06A==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-iO1c++VP6xUBUmltHZoMtCUdPlnPGdBom6IrO4gyKPFFVBKioIImVooR5I83nTew5UOYrk3gIJhbZh8X44y06A==, + } + engines: { node: '>=12' } cpu: [x64] os: [netbsd] requiresBuild: true @@ -171,8 +221,11 @@ packages: optional: true /@esbuild/openbsd-x64@0.18.20: - resolution: {integrity: sha512-e5e4YSsuQfX4cxcygw/UCPIEP6wbIL+se3sxPdCiMbFLBWu0eiZOJ7WoD+ptCLrmjZBK1Wk7I6D/I3NglUGOxg==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-e5e4YSsuQfX4cxcygw/UCPIEP6wbIL+se3sxPdCiMbFLBWu0eiZOJ7WoD+ptCLrmjZBK1Wk7I6D/I3NglUGOxg==, + } + engines: { node: '>=12' } cpu: [x64] os: [openbsd] requiresBuild: true @@ -180,8 +233,11 @@ packages: optional: true /@esbuild/sunos-x64@0.18.20: - resolution: {integrity: sha512-kDbFRFp0YpTQVVrqUd5FTYmWo45zGaXe0X8E1G/LKFC0v8x0vWrhOWSLITcCn63lmZIxfOMXtCfti/RxN/0wnQ==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-kDbFRFp0YpTQVVrqUd5FTYmWo45zGaXe0X8E1G/LKFC0v8x0vWrhOWSLITcCn63lmZIxfOMXtCfti/RxN/0wnQ==, + } + engines: { node: '>=12' } cpu: [x64] os: [sunos] requiresBuild: true @@ -189,8 +245,11 @@ packages: optional: true /@esbuild/win32-arm64@0.18.20: - resolution: {integrity: sha512-ddYFR6ItYgoaq4v4JmQQaAI5s7npztfV4Ag6NrhiaW0RrnOXqBkgwZLofVTlq1daVTQNhtI5oieTvkRPfZrePg==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-ddYFR6ItYgoaq4v4JmQQaAI5s7npztfV4Ag6NrhiaW0RrnOXqBkgwZLofVTlq1daVTQNhtI5oieTvkRPfZrePg==, + } + engines: { node: '>=12' } cpu: [arm64] os: [win32] requiresBuild: true @@ -198,8 +257,11 @@ packages: optional: true /@esbuild/win32-ia32@0.18.20: - resolution: {integrity: sha512-Wv7QBi3ID/rROT08SABTS7eV4hX26sVduqDOTe1MvGMjNd3EjOz4b7zeexIR62GTIEKrfJXKL9LFxTYgkyeu7g==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-Wv7QBi3ID/rROT08SABTS7eV4hX26sVduqDOTe1MvGMjNd3EjOz4b7zeexIR62GTIEKrfJXKL9LFxTYgkyeu7g==, + } + engines: { node: '>=12' } cpu: [ia32] os: [win32] requiresBuild: true @@ -207,8 +269,11 @@ packages: optional: true /@esbuild/win32-x64@0.18.20: - resolution: {integrity: sha512-kTdfRcSiDfQca/y9QIkng02avJ+NCaQvrMejlsB3RRv5sE9rRoeBPISaZpKxHELzRxZyLvNts1P27W3wV+8geQ==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-kTdfRcSiDfQca/y9QIkng02avJ+NCaQvrMejlsB3RRv5sE9rRoeBPISaZpKxHELzRxZyLvNts1P27W3wV+8geQ==, + } + engines: { node: '>=12' } cpu: [x64] os: [win32] requiresBuild: true @@ -216,8 +281,11 @@ packages: optional: true /esbuild@0.18.20: - resolution: {integrity: sha512-ceqxoedUrcayh7Y7ZX6NdbbDzGROiyVBgC4PriJThBKSVPWnnFHZAkfI1lJT8QFkOwH4qOS2SJkS4wvpGl8BpA==} - engines: {node: '>=12'} + resolution: + { + integrity: sha512-ceqxoedUrcayh7Y7ZX6NdbbDzGROiyVBgC4PriJThBKSVPWnnFHZAkfI1lJT8QFkOwH4qOS2SJkS4wvpGl8BpA==, + } + engines: { node: '>=12' } hasBin: true requiresBuild: true optionalDependencies: @@ -246,26 +314,38 @@ packages: dev: true /fsevents@2.3.3: - resolution: {integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==} - engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} + resolution: + { + integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==, + } + engines: { node: ^8.16.0 || ^10.6.0 || >=11.0.0 } os: [darwin] requiresBuild: true dev: true optional: true /nanoid@3.3.6: - resolution: {integrity: sha512-BGcqMMJuToF7i1rt+2PWSNVnWIkGCU78jBG3RxO/bZlnZPK2Cmi2QaffxGO/2RvWi9sL+FAiRiXMgsyxQ1DIDA==} - engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} + resolution: + { + integrity: sha512-BGcqMMJuToF7i1rt+2PWSNVnWIkGCU78jBG3RxO/bZlnZPK2Cmi2QaffxGO/2RvWi9sL+FAiRiXMgsyxQ1DIDA==, + } + engines: { node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1 } hasBin: true dev: true /picocolors@1.0.0: - resolution: {integrity: sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==} + resolution: + { + integrity: sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==, + } dev: true /postcss@8.4.31: - resolution: {integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==} - engines: {node: ^10 || ^12 || >=14} + resolution: + { + integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==, + } + engines: { node: ^10 || ^12 || >=14 } dependencies: nanoid: 3.3.6 picocolors: 1.0.0 @@ -273,33 +353,48 @@ packages: dev: true /prettier@3.0.3: - resolution: {integrity: sha512-L/4pUDMxcNa8R/EthV08Zt42WBO4h1rarVtK0K+QJG0X187OLo7l699jWw0GKuwzkPQ//jMFA/8Xm6Fh3J/DAg==} - engines: {node: '>=14'} + resolution: + { + integrity: sha512-L/4pUDMxcNa8R/EthV08Zt42WBO4h1rarVtK0K+QJG0X187OLo7l699jWw0GKuwzkPQ//jMFA/8Xm6Fh3J/DAg==, + } + engines: { node: '>=14' } hasBin: true dev: true /rollup@3.29.4: - resolution: {integrity: sha512-oWzmBZwvYrU0iJHtDmhsm662rC15FRXmcjCk1xD771dFDx5jJ02ufAQQTn0etB2emNk4J9EZg/yWKpsn9BWGRw==} - engines: {node: '>=14.18.0', npm: '>=8.0.0'} + resolution: + { + integrity: sha512-oWzmBZwvYrU0iJHtDmhsm662rC15FRXmcjCk1xD771dFDx5jJ02ufAQQTn0etB2emNk4J9EZg/yWKpsn9BWGRw==, + } + engines: { node: '>=14.18.0', npm: '>=8.0.0' } hasBin: true optionalDependencies: fsevents: 2.3.3 dev: true /source-map-js@1.0.2: - resolution: {integrity: sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==} - engines: {node: '>=0.10.0'} + resolution: + { + integrity: sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==, + } + engines: { node: '>=0.10.0' } dev: true /typescript@5.2.2: - resolution: {integrity: sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==} - engines: {node: '>=14.17'} + resolution: + { + integrity: sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==, + } + engines: { node: '>=14.17' } hasBin: true dev: true /vite@4.5.0: - resolution: {integrity: sha512-ulr8rNLA6rkyFAlVWw2q5YJ91v098AFQ2R0PRFwPzREXOUJQPtFUG0t+/ZikhaOCDqFoDhN6/v8Sq0o4araFAw==} - engines: {node: ^14.18.0 || >=16.0.0} + resolution: + { + integrity: sha512-ulr8rNLA6rkyFAlVWw2q5YJ91v098AFQ2R0PRFwPzREXOUJQPtFUG0t+/ZikhaOCDqFoDhN6/v8Sq0o4araFAw==, + } + engines: { node: ^14.18.0 || >=16.0.0 } hasBin: true peerDependencies: '@types/node': '>= 14' diff --git a/examples/train-web/src/counter.ts b/examples/train-web/src/counter.ts index 09e5afd2d8..164395d512 100644 --- a/examples/train-web/src/counter.ts +++ b/examples/train-web/src/counter.ts @@ -1,9 +1,9 @@ export function setupCounter(element: HTMLButtonElement) { - let counter = 0 - const setCounter = (count: number) => { - counter = count - element.innerHTML = `count is ${counter}` - } - element.addEventListener('click', () => setCounter(counter + 1)) - setCounter(0) + let counter = 0 + const setCounter = (count: number) => { + counter = count + element.innerHTML = `count is ${counter}` + } + element.addEventListener('click', () => setCounter(counter + 1)) + setCounter(0) } diff --git a/examples/train-web/src/style.css b/examples/train-web/src/style.css index b528b6cc28..b425acb01b 100644 --- a/examples/train-web/src/style.css +++ b/examples/train-web/src/style.css @@ -1,97 +1,97 @@ :root { - font-family: Inter, system-ui, Avenir, Helvetica, Arial, sans-serif; - line-height: 1.5; - font-weight: 400; + font-family: Inter, system-ui, Avenir, Helvetica, Arial, sans-serif; + line-height: 1.5; + font-weight: 400; - color-scheme: light dark; - color: rgba(255, 255, 255, 0.87); - background-color: #242424; + color-scheme: light dark; + color: rgba(255, 255, 255, 0.87); + background-color: #242424; - font-synthesis: none; - text-rendering: optimizeLegibility; - -webkit-font-smoothing: antialiased; - -moz-osx-font-smoothing: grayscale; - -webkit-text-size-adjust: 100%; + font-synthesis: none; + text-rendering: optimizeLegibility; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + -webkit-text-size-adjust: 100%; } a { - font-weight: 500; - color: #646cff; - text-decoration: inherit; + font-weight: 500; + color: #646cff; + text-decoration: inherit; } a:hover { - color: #535bf2; + color: #535bf2; } body { - margin: 0; - display: flex; - place-items: center; - min-width: 320px; - min-height: 100vh; + margin: 0; + display: flex; + place-items: center; + min-width: 320px; + min-height: 100vh; } h1 { - font-size: 3.2em; - line-height: 1.1; + font-size: 3.2em; + line-height: 1.1; } #app { - max-width: 1280px; - margin: 0 auto; - padding: 2rem; - text-align: center; + max-width: 1280px; + margin: 0 auto; + padding: 2rem; + text-align: center; } .logo { - height: 6em; - padding: 1.5em; - will-change: filter; - transition: filter 300ms; + height: 6em; + padding: 1.5em; + will-change: filter; + transition: filter 300ms; } .logo:hover { - filter: drop-shadow(0 0 2em #646cffaa); + filter: drop-shadow(0 0 2em #646cffaa); } .logo.vanilla:hover { - filter: drop-shadow(0 0 2em #3178c6aa); + filter: drop-shadow(0 0 2em #3178c6aa); } .card { - padding: 2em; + padding: 2em; } .read-the-docs { - color: #888; + color: #888; } button { - border-radius: 8px; - border: 1px solid transparent; - padding: 0.6em 1.2em; - font-size: 1em; - font-weight: 500; - font-family: inherit; - background-color: #1a1a1a; - cursor: pointer; - transition: border-color 0.25s; + border-radius: 8px; + border: 1px solid transparent; + padding: 0.6em 1.2em; + font-size: 1em; + font-weight: 500; + font-family: inherit; + background-color: #1a1a1a; + cursor: pointer; + transition: border-color 0.25s; } button:hover { - border-color: #646cff; + border-color: #646cff; } button:focus, button:focus-visible { - outline: 4px auto -webkit-focus-ring-color; + outline: 4px auto -webkit-focus-ring-color; } @media (prefers-color-scheme: light) { - :root { - color: #213547; - background-color: #ffffff; - } - a:hover { - color: #747bff; - } - button { - background-color: #f9f9f9; - } + :root { + color: #213547; + background-color: #ffffff; + } + a:hover { + color: #747bff; + } + button { + background-color: #f9f9f9; + } } diff --git a/examples/train-web/tsconfig.json b/examples/train-web/tsconfig.json index 75abdef265..e338064653 100644 --- a/examples/train-web/tsconfig.json +++ b/examples/train-web/tsconfig.json @@ -1,23 +1,23 @@ { - "compilerOptions": { - "target": "ES2020", - "useDefineForClassFields": true, - "module": "ESNext", - "lib": ["ES2020", "DOM", "DOM.Iterable"], - "skipLibCheck": true, + "compilerOptions": { + "target": "ES2020", + "useDefineForClassFields": true, + "module": "ESNext", + "lib": ["ES2020", "DOM", "DOM.Iterable"], + "skipLibCheck": true, - /* Bundler mode */ - "moduleResolution": "bundler", - "allowImportingTsExtensions": true, - "resolveJsonModule": true, - "isolatedModules": true, - "noEmit": true, + /* Bundler mode */ + "moduleResolution": "bundler", + "allowImportingTsExtensions": true, + "resolveJsonModule": true, + "isolatedModules": true, + "noEmit": true, - /* Linting */ - "strict": true, - "noUnusedLocals": true, - "noUnusedParameters": true, - "noFallthroughCasesInSwitch": true - }, - "include": ["src"] + /* Linting */ + "strict": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "noFallthroughCasesInSwitch": true + }, + "include": ["src"] } From bddd331615882af9544f8d0ff74d9759c897ffc3 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 2 Nov 2023 17:06:11 -0500 Subject: [PATCH 07/79] move everything to web folder --- examples/train-web/{ => web}/.gitignore | 0 examples/train-web/{ => web}/.prettierrc | 0 examples/train-web/{ => web}/index.html | 0 examples/train-web/{ => web}/package.json | 0 examples/train-web/{ => web}/pnpm-lock.yaml | 0 examples/train-web/{ => web}/public/vite.svg | 0 examples/train-web/{ => web}/src/counter.ts | 0 examples/train-web/{ => web}/src/main.ts | 0 examples/train-web/{ => web}/src/style.css | 0 examples/train-web/{ => web}/src/typescript.svg | 0 examples/train-web/{ => web}/src/vite-env.d.ts | 0 examples/train-web/{ => web}/tsconfig.json | 0 12 files changed, 0 insertions(+), 0 deletions(-) rename examples/train-web/{ => web}/.gitignore (100%) rename examples/train-web/{ => web}/.prettierrc (100%) rename examples/train-web/{ => web}/index.html (100%) rename examples/train-web/{ => web}/package.json (100%) rename examples/train-web/{ => web}/pnpm-lock.yaml (100%) rename examples/train-web/{ => web}/public/vite.svg (100%) rename examples/train-web/{ => web}/src/counter.ts (100%) rename examples/train-web/{ => web}/src/main.ts (100%) rename examples/train-web/{ => web}/src/style.css (100%) rename examples/train-web/{ => web}/src/typescript.svg (100%) rename examples/train-web/{ => web}/src/vite-env.d.ts (100%) rename examples/train-web/{ => web}/tsconfig.json (100%) diff --git a/examples/train-web/.gitignore b/examples/train-web/web/.gitignore similarity index 100% rename from examples/train-web/.gitignore rename to examples/train-web/web/.gitignore diff --git a/examples/train-web/.prettierrc b/examples/train-web/web/.prettierrc similarity index 100% rename from examples/train-web/.prettierrc rename to examples/train-web/web/.prettierrc diff --git a/examples/train-web/index.html b/examples/train-web/web/index.html similarity index 100% rename from examples/train-web/index.html rename to examples/train-web/web/index.html diff --git a/examples/train-web/package.json b/examples/train-web/web/package.json similarity index 100% rename from examples/train-web/package.json rename to examples/train-web/web/package.json diff --git a/examples/train-web/pnpm-lock.yaml b/examples/train-web/web/pnpm-lock.yaml similarity index 100% rename from examples/train-web/pnpm-lock.yaml rename to examples/train-web/web/pnpm-lock.yaml diff --git a/examples/train-web/public/vite.svg b/examples/train-web/web/public/vite.svg similarity index 100% rename from examples/train-web/public/vite.svg rename to examples/train-web/web/public/vite.svg diff --git a/examples/train-web/src/counter.ts b/examples/train-web/web/src/counter.ts similarity index 100% rename from examples/train-web/src/counter.ts rename to examples/train-web/web/src/counter.ts diff --git a/examples/train-web/src/main.ts b/examples/train-web/web/src/main.ts similarity index 100% rename from examples/train-web/src/main.ts rename to examples/train-web/web/src/main.ts diff --git a/examples/train-web/src/style.css b/examples/train-web/web/src/style.css similarity index 100% rename from examples/train-web/src/style.css rename to examples/train-web/web/src/style.css diff --git a/examples/train-web/src/typescript.svg b/examples/train-web/web/src/typescript.svg similarity index 100% rename from examples/train-web/src/typescript.svg rename to examples/train-web/web/src/typescript.svg diff --git a/examples/train-web/src/vite-env.d.ts b/examples/train-web/web/src/vite-env.d.ts similarity index 100% rename from examples/train-web/src/vite-env.d.ts rename to examples/train-web/web/src/vite-env.d.ts diff --git a/examples/train-web/tsconfig.json b/examples/train-web/web/tsconfig.json similarity index 100% rename from examples/train-web/tsconfig.json rename to examples/train-web/web/tsconfig.json From 501b360c890ef9c572825ab194f6d95736e778bf Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 2 Nov 2023 17:22:32 -0500 Subject: [PATCH 08/79] update workspace --- Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index fe5079484f..eba95dda4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,10 +24,11 @@ members = [ "burn-train", "xtask", "examples/*", + "examples/train-web/train", "backend-comparison", ] -exclude = ["examples/notebook"] +exclude = ["examples/notebook", "examples/train-web"] [workspace.dependencies] async-trait = "0.1.73" From 9ff1615f4cc2d643dbb500168e1f089623355026 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 2 Nov 2023 17:36:34 -0500 Subject: [PATCH 09/79] cargo new train --lib --- examples/train-web/train/Cargo.toml | 8 ++++++++ examples/train-web/train/src/lib.rs | 14 ++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 examples/train-web/train/Cargo.toml create mode 100644 examples/train-web/train/src/lib.rs diff --git a/examples/train-web/train/Cargo.toml b/examples/train-web/train/Cargo.toml new file mode 100644 index 0000000000..0aaad18c7c --- /dev/null +++ b/examples/train-web/train/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "train" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs new file mode 100644 index 0000000000..7d12d9af81 --- /dev/null +++ b/examples/train-web/train/src/lib.rs @@ -0,0 +1,14 @@ +pub fn add(left: usize, right: usize) -> usize { + left + right +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn it_works() { + let result = add(2, 2); + assert_eq!(result, 4); + } +} From 0a279deb47e07c5f46ad51a3225a655e43f78daf Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 3 Nov 2023 16:55:50 -0500 Subject: [PATCH 10/79] pnpm add ../train/pkg --- examples/train-web/web/package.json | 3 + examples/train-web/web/pnpm-lock.yaml | 228 ++++++++------------------ 2 files changed, 72 insertions(+), 159 deletions(-) diff --git a/examples/train-web/web/package.json b/examples/train-web/web/package.json index a89dfe068f..632a4ba0d1 100644 --- a/examples/train-web/web/package.json +++ b/examples/train-web/web/package.json @@ -12,5 +12,8 @@ "prettier": "3.0.3", "typescript": "^5.2.2", "vite": "^4.5.0" + }, + "dependencies": { + "train": "link:../train/pkg" } } diff --git a/examples/train-web/web/pnpm-lock.yaml b/examples/train-web/web/pnpm-lock.yaml index 41dbb1ab05..682004e490 100644 --- a/examples/train-web/web/pnpm-lock.yaml +++ b/examples/train-web/web/pnpm-lock.yaml @@ -4,6 +4,11 @@ settings: autoInstallPeers: true excludeLinksFromLockfile: false +dependencies: + train: + specifier: link:../train/pkg + version: link:../train/pkg + devDependencies: prettier: specifier: 3.0.3 @@ -16,12 +21,10 @@ devDependencies: version: 4.5.0 packages: + /@esbuild/android-arm64@0.18.20: - resolution: - { - integrity: sha512-Nz4rJcchGDtENV0eMKUNa6L12zz2zBDXuhj/Vjh18zGqB44Bi7MBMSXjgunJgjRhCmKOjnPuZp4Mb6OKqtMHLQ==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-Nz4rJcchGDtENV0eMKUNa6L12zz2zBDXuhj/Vjh18zGqB44Bi7MBMSXjgunJgjRhCmKOjnPuZp4Mb6OKqtMHLQ==} + engines: {node: '>=12'} cpu: [arm64] os: [android] requiresBuild: true @@ -29,11 +32,8 @@ packages: optional: true /@esbuild/android-arm@0.18.20: - resolution: - { - integrity: sha512-fyi7TDI/ijKKNZTUJAQqiG5T7YjJXgnzkURqmGj13C6dCqckZBLdl4h7bkhHt/t0WP+zO9/zwroDvANaOqO5Sw==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-fyi7TDI/ijKKNZTUJAQqiG5T7YjJXgnzkURqmGj13C6dCqckZBLdl4h7bkhHt/t0WP+zO9/zwroDvANaOqO5Sw==} + engines: {node: '>=12'} cpu: [arm] os: [android] requiresBuild: true @@ -41,11 +41,8 @@ packages: optional: true /@esbuild/android-x64@0.18.20: - resolution: - { - integrity: sha512-8GDdlePJA8D6zlZYJV/jnrRAi6rOiNaCC/JclcXpB+KIuvfBN4owLtgzY2bsxnx666XjJx2kDPUmnTtR8qKQUg==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-8GDdlePJA8D6zlZYJV/jnrRAi6rOiNaCC/JclcXpB+KIuvfBN4owLtgzY2bsxnx666XjJx2kDPUmnTtR8qKQUg==} + engines: {node: '>=12'} cpu: [x64] os: [android] requiresBuild: true @@ -53,11 +50,8 @@ packages: optional: true /@esbuild/darwin-arm64@0.18.20: - resolution: - { - integrity: sha512-bxRHW5kHU38zS2lPTPOyuyTm+S+eobPUnTNkdJEfAddYgEcll4xkT8DB9d2008DtTbl7uJag2HuE5NZAZgnNEA==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-bxRHW5kHU38zS2lPTPOyuyTm+S+eobPUnTNkdJEfAddYgEcll4xkT8DB9d2008DtTbl7uJag2HuE5NZAZgnNEA==} + engines: {node: '>=12'} cpu: [arm64] os: [darwin] requiresBuild: true @@ -65,11 +59,8 @@ packages: optional: true /@esbuild/darwin-x64@0.18.20: - resolution: - { - integrity: sha512-pc5gxlMDxzm513qPGbCbDukOdsGtKhfxD1zJKXjCCcU7ju50O7MeAZ8c4krSJcOIJGFR+qx21yMMVYwiQvyTyQ==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-pc5gxlMDxzm513qPGbCbDukOdsGtKhfxD1zJKXjCCcU7ju50O7MeAZ8c4krSJcOIJGFR+qx21yMMVYwiQvyTyQ==} + engines: {node: '>=12'} cpu: [x64] os: [darwin] requiresBuild: true @@ -77,11 +68,8 @@ packages: optional: true /@esbuild/freebsd-arm64@0.18.20: - resolution: - { - integrity: sha512-yqDQHy4QHevpMAaxhhIwYPMv1NECwOvIpGCZkECn8w2WFHXjEwrBn3CeNIYsibZ/iZEUemj++M26W3cNR5h+Tw==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-yqDQHy4QHevpMAaxhhIwYPMv1NECwOvIpGCZkECn8w2WFHXjEwrBn3CeNIYsibZ/iZEUemj++M26W3cNR5h+Tw==} + engines: {node: '>=12'} cpu: [arm64] os: [freebsd] requiresBuild: true @@ -89,11 +77,8 @@ packages: optional: true /@esbuild/freebsd-x64@0.18.20: - resolution: - { - integrity: sha512-tgWRPPuQsd3RmBZwarGVHZQvtzfEBOreNuxEMKFcd5DaDn2PbBxfwLcj4+aenoh7ctXcbXmOQIn8HI6mCSw5MQ==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-tgWRPPuQsd3RmBZwarGVHZQvtzfEBOreNuxEMKFcd5DaDn2PbBxfwLcj4+aenoh7ctXcbXmOQIn8HI6mCSw5MQ==} + engines: {node: '>=12'} cpu: [x64] os: [freebsd] requiresBuild: true @@ -101,11 +86,8 @@ packages: optional: true /@esbuild/linux-arm64@0.18.20: - resolution: - { - integrity: sha512-2YbscF+UL7SQAVIpnWvYwM+3LskyDmPhe31pE7/aoTMFKKzIc9lLbyGUpmmb8a8AixOL61sQ/mFh3jEjHYFvdA==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-2YbscF+UL7SQAVIpnWvYwM+3LskyDmPhe31pE7/aoTMFKKzIc9lLbyGUpmmb8a8AixOL61sQ/mFh3jEjHYFvdA==} + engines: {node: '>=12'} cpu: [arm64] os: [linux] requiresBuild: true @@ -113,11 +95,8 @@ packages: optional: true /@esbuild/linux-arm@0.18.20: - resolution: - { - integrity: sha512-/5bHkMWnq1EgKr1V+Ybz3s1hWXok7mDFUMQ4cG10AfW3wL02PSZi5kFpYKrptDsgb2WAJIvRcDm+qIvXf/apvg==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-/5bHkMWnq1EgKr1V+Ybz3s1hWXok7mDFUMQ4cG10AfW3wL02PSZi5kFpYKrptDsgb2WAJIvRcDm+qIvXf/apvg==} + engines: {node: '>=12'} cpu: [arm] os: [linux] requiresBuild: true @@ -125,11 +104,8 @@ packages: optional: true /@esbuild/linux-ia32@0.18.20: - resolution: - { - integrity: sha512-P4etWwq6IsReT0E1KHU40bOnzMHoH73aXp96Fs8TIT6z9Hu8G6+0SHSw9i2isWrD2nbx2qo5yUqACgdfVGx7TA==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-P4etWwq6IsReT0E1KHU40bOnzMHoH73aXp96Fs8TIT6z9Hu8G6+0SHSw9i2isWrD2nbx2qo5yUqACgdfVGx7TA==} + engines: {node: '>=12'} cpu: [ia32] os: [linux] requiresBuild: true @@ -137,11 +113,8 @@ packages: optional: true /@esbuild/linux-loong64@0.18.20: - resolution: - { - integrity: sha512-nXW8nqBTrOpDLPgPY9uV+/1DjxoQ7DoB2N8eocyq8I9XuqJ7BiAMDMf9n1xZM9TgW0J8zrquIb/A7s3BJv7rjg==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-nXW8nqBTrOpDLPgPY9uV+/1DjxoQ7DoB2N8eocyq8I9XuqJ7BiAMDMf9n1xZM9TgW0J8zrquIb/A7s3BJv7rjg==} + engines: {node: '>=12'} cpu: [loong64] os: [linux] requiresBuild: true @@ -149,11 +122,8 @@ packages: optional: true /@esbuild/linux-mips64el@0.18.20: - resolution: - { - integrity: sha512-d5NeaXZcHp8PzYy5VnXV3VSd2D328Zb+9dEq5HE6bw6+N86JVPExrA6O68OPwobntbNJ0pzCpUFZTo3w0GyetQ==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-d5NeaXZcHp8PzYy5VnXV3VSd2D328Zb+9dEq5HE6bw6+N86JVPExrA6O68OPwobntbNJ0pzCpUFZTo3w0GyetQ==} + engines: {node: '>=12'} cpu: [mips64el] os: [linux] requiresBuild: true @@ -161,11 +131,8 @@ packages: optional: true /@esbuild/linux-ppc64@0.18.20: - resolution: - { - integrity: sha512-WHPyeScRNcmANnLQkq6AfyXRFr5D6N2sKgkFo2FqguP44Nw2eyDlbTdZwd9GYk98DZG9QItIiTlFLHJHjxP3FA==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-WHPyeScRNcmANnLQkq6AfyXRFr5D6N2sKgkFo2FqguP44Nw2eyDlbTdZwd9GYk98DZG9QItIiTlFLHJHjxP3FA==} + engines: {node: '>=12'} cpu: [ppc64] os: [linux] requiresBuild: true @@ -173,11 +140,8 @@ packages: optional: true /@esbuild/linux-riscv64@0.18.20: - resolution: - { - integrity: sha512-WSxo6h5ecI5XH34KC7w5veNnKkju3zBRLEQNY7mv5mtBmrP/MjNBCAlsM2u5hDBlS3NGcTQpoBvRzqBcRtpq1A==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-WSxo6h5ecI5XH34KC7w5veNnKkju3zBRLEQNY7mv5mtBmrP/MjNBCAlsM2u5hDBlS3NGcTQpoBvRzqBcRtpq1A==} + engines: {node: '>=12'} cpu: [riscv64] os: [linux] requiresBuild: true @@ -185,11 +149,8 @@ packages: optional: true /@esbuild/linux-s390x@0.18.20: - resolution: - { - integrity: sha512-+8231GMs3mAEth6Ja1iK0a1sQ3ohfcpzpRLH8uuc5/KVDFneH6jtAJLFGafpzpMRO6DzJ6AvXKze9LfFMrIHVQ==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-+8231GMs3mAEth6Ja1iK0a1sQ3ohfcpzpRLH8uuc5/KVDFneH6jtAJLFGafpzpMRO6DzJ6AvXKze9LfFMrIHVQ==} + engines: {node: '>=12'} cpu: [s390x] os: [linux] requiresBuild: true @@ -197,11 +158,8 @@ packages: optional: true /@esbuild/linux-x64@0.18.20: - resolution: - { - integrity: sha512-UYqiqemphJcNsFEskc73jQ7B9jgwjWrSayxawS6UVFZGWrAAtkzjxSqnoclCXxWtfwLdzU+vTpcNYhpn43uP1w==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-UYqiqemphJcNsFEskc73jQ7B9jgwjWrSayxawS6UVFZGWrAAtkzjxSqnoclCXxWtfwLdzU+vTpcNYhpn43uP1w==} + engines: {node: '>=12'} cpu: [x64] os: [linux] requiresBuild: true @@ -209,11 +167,8 @@ packages: optional: true /@esbuild/netbsd-x64@0.18.20: - resolution: - { - integrity: sha512-iO1c++VP6xUBUmltHZoMtCUdPlnPGdBom6IrO4gyKPFFVBKioIImVooR5I83nTew5UOYrk3gIJhbZh8X44y06A==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-iO1c++VP6xUBUmltHZoMtCUdPlnPGdBom6IrO4gyKPFFVBKioIImVooR5I83nTew5UOYrk3gIJhbZh8X44y06A==} + engines: {node: '>=12'} cpu: [x64] os: [netbsd] requiresBuild: true @@ -221,11 +176,8 @@ packages: optional: true /@esbuild/openbsd-x64@0.18.20: - resolution: - { - integrity: sha512-e5e4YSsuQfX4cxcygw/UCPIEP6wbIL+se3sxPdCiMbFLBWu0eiZOJ7WoD+ptCLrmjZBK1Wk7I6D/I3NglUGOxg==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-e5e4YSsuQfX4cxcygw/UCPIEP6wbIL+se3sxPdCiMbFLBWu0eiZOJ7WoD+ptCLrmjZBK1Wk7I6D/I3NglUGOxg==} + engines: {node: '>=12'} cpu: [x64] os: [openbsd] requiresBuild: true @@ -233,11 +185,8 @@ packages: optional: true /@esbuild/sunos-x64@0.18.20: - resolution: - { - integrity: sha512-kDbFRFp0YpTQVVrqUd5FTYmWo45zGaXe0X8E1G/LKFC0v8x0vWrhOWSLITcCn63lmZIxfOMXtCfti/RxN/0wnQ==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-kDbFRFp0YpTQVVrqUd5FTYmWo45zGaXe0X8E1G/LKFC0v8x0vWrhOWSLITcCn63lmZIxfOMXtCfti/RxN/0wnQ==} + engines: {node: '>=12'} cpu: [x64] os: [sunos] requiresBuild: true @@ -245,11 +194,8 @@ packages: optional: true /@esbuild/win32-arm64@0.18.20: - resolution: - { - integrity: sha512-ddYFR6ItYgoaq4v4JmQQaAI5s7npztfV4Ag6NrhiaW0RrnOXqBkgwZLofVTlq1daVTQNhtI5oieTvkRPfZrePg==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-ddYFR6ItYgoaq4v4JmQQaAI5s7npztfV4Ag6NrhiaW0RrnOXqBkgwZLofVTlq1daVTQNhtI5oieTvkRPfZrePg==} + engines: {node: '>=12'} cpu: [arm64] os: [win32] requiresBuild: true @@ -257,11 +203,8 @@ packages: optional: true /@esbuild/win32-ia32@0.18.20: - resolution: - { - integrity: sha512-Wv7QBi3ID/rROT08SABTS7eV4hX26sVduqDOTe1MvGMjNd3EjOz4b7zeexIR62GTIEKrfJXKL9LFxTYgkyeu7g==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-Wv7QBi3ID/rROT08SABTS7eV4hX26sVduqDOTe1MvGMjNd3EjOz4b7zeexIR62GTIEKrfJXKL9LFxTYgkyeu7g==} + engines: {node: '>=12'} cpu: [ia32] os: [win32] requiresBuild: true @@ -269,11 +212,8 @@ packages: optional: true /@esbuild/win32-x64@0.18.20: - resolution: - { - integrity: sha512-kTdfRcSiDfQca/y9QIkng02avJ+NCaQvrMejlsB3RRv5sE9rRoeBPISaZpKxHELzRxZyLvNts1P27W3wV+8geQ==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-kTdfRcSiDfQca/y9QIkng02avJ+NCaQvrMejlsB3RRv5sE9rRoeBPISaZpKxHELzRxZyLvNts1P27W3wV+8geQ==} + engines: {node: '>=12'} cpu: [x64] os: [win32] requiresBuild: true @@ -281,11 +221,8 @@ packages: optional: true /esbuild@0.18.20: - resolution: - { - integrity: sha512-ceqxoedUrcayh7Y7ZX6NdbbDzGROiyVBgC4PriJThBKSVPWnnFHZAkfI1lJT8QFkOwH4qOS2SJkS4wvpGl8BpA==, - } - engines: { node: '>=12' } + resolution: {integrity: sha512-ceqxoedUrcayh7Y7ZX6NdbbDzGROiyVBgC4PriJThBKSVPWnnFHZAkfI1lJT8QFkOwH4qOS2SJkS4wvpGl8BpA==} + engines: {node: '>=12'} hasBin: true requiresBuild: true optionalDependencies: @@ -314,38 +251,26 @@ packages: dev: true /fsevents@2.3.3: - resolution: - { - integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==, - } - engines: { node: ^8.16.0 || ^10.6.0 || >=11.0.0 } + resolution: {integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==} + engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} os: [darwin] requiresBuild: true dev: true optional: true /nanoid@3.3.6: - resolution: - { - integrity: sha512-BGcqMMJuToF7i1rt+2PWSNVnWIkGCU78jBG3RxO/bZlnZPK2Cmi2QaffxGO/2RvWi9sL+FAiRiXMgsyxQ1DIDA==, - } - engines: { node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1 } + resolution: {integrity: sha512-BGcqMMJuToF7i1rt+2PWSNVnWIkGCU78jBG3RxO/bZlnZPK2Cmi2QaffxGO/2RvWi9sL+FAiRiXMgsyxQ1DIDA==} + engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} hasBin: true dev: true /picocolors@1.0.0: - resolution: - { - integrity: sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==, - } + resolution: {integrity: sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==} dev: true /postcss@8.4.31: - resolution: - { - integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==, - } - engines: { node: ^10 || ^12 || >=14 } + resolution: {integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==} + engines: {node: ^10 || ^12 || >=14} dependencies: nanoid: 3.3.6 picocolors: 1.0.0 @@ -353,48 +278,33 @@ packages: dev: true /prettier@3.0.3: - resolution: - { - integrity: sha512-L/4pUDMxcNa8R/EthV08Zt42WBO4h1rarVtK0K+QJG0X187OLo7l699jWw0GKuwzkPQ//jMFA/8Xm6Fh3J/DAg==, - } - engines: { node: '>=14' } + resolution: {integrity: sha512-L/4pUDMxcNa8R/EthV08Zt42WBO4h1rarVtK0K+QJG0X187OLo7l699jWw0GKuwzkPQ//jMFA/8Xm6Fh3J/DAg==} + engines: {node: '>=14'} hasBin: true dev: true /rollup@3.29.4: - resolution: - { - integrity: sha512-oWzmBZwvYrU0iJHtDmhsm662rC15FRXmcjCk1xD771dFDx5jJ02ufAQQTn0etB2emNk4J9EZg/yWKpsn9BWGRw==, - } - engines: { node: '>=14.18.0', npm: '>=8.0.0' } + resolution: {integrity: sha512-oWzmBZwvYrU0iJHtDmhsm662rC15FRXmcjCk1xD771dFDx5jJ02ufAQQTn0etB2emNk4J9EZg/yWKpsn9BWGRw==} + engines: {node: '>=14.18.0', npm: '>=8.0.0'} hasBin: true optionalDependencies: fsevents: 2.3.3 dev: true /source-map-js@1.0.2: - resolution: - { - integrity: sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==, - } - engines: { node: '>=0.10.0' } + resolution: {integrity: sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==} + engines: {node: '>=0.10.0'} dev: true /typescript@5.2.2: - resolution: - { - integrity: sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==, - } - engines: { node: '>=14.17' } + resolution: {integrity: sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==} + engines: {node: '>=14.17'} hasBin: true dev: true /vite@4.5.0: - resolution: - { - integrity: sha512-ulr8rNLA6rkyFAlVWw2q5YJ91v098AFQ2R0PRFwPzREXOUJQPtFUG0t+/ZikhaOCDqFoDhN6/v8Sq0o4araFAw==, - } - engines: { node: ^14.18.0 || >=16.0.0 } + resolution: {integrity: sha512-ulr8rNLA6rkyFAlVWw2q5YJ91v098AFQ2R0PRFwPzREXOUJQPtFUG0t+/ZikhaOCDqFoDhN6/v8Sq0o4araFAw==} + engines: {node: ^14.18.0 || >=16.0.0} hasBin: true peerDependencies: '@types/node': '>= 14' From 0af481a2779543b7b99c8ca7e97f08340edb6670 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 3 Nov 2023 16:59:36 -0500 Subject: [PATCH 11/79] add vite config --- examples/train-web/web/vite.config.ts | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 examples/train-web/web/vite.config.ts diff --git a/examples/train-web/web/vite.config.ts b/examples/train-web/web/vite.config.ts new file mode 100644 index 0000000000..d713a8fbff --- /dev/null +++ b/examples/train-web/web/vite.config.ts @@ -0,0 +1,10 @@ +import { defineConfig } from 'vite' + +export default defineConfig({ + server: { + fs: { + // Allow serving files from one level up to the project root + allow: ['..'], + }, + }, +}) From 2740d4ce3d5f0840693426d155c217d946b90e40 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 3 Nov 2023 19:12:37 -0500 Subject: [PATCH 12/79] can call train's run from web --- examples/train-web/train/Cargo.toml | 7 ++++++- examples/train-web/train/build-for-web.sh | 17 +++++++++++++++++ examples/train-web/train/src/lib.rs | 20 +++++++++----------- examples/train-web/web/src/counter.ts | 5 +++++ 4 files changed, 37 insertions(+), 12 deletions(-) create mode 100755 examples/train-web/train/build-for-web.sh diff --git a/examples/train-web/train/Cargo.toml b/examples/train-web/train/Cargo.toml index 0aaad18c7c..168cff30aa 100644 --- a/examples/train-web/train/Cargo.toml +++ b/examples/train-web/train/Cargo.toml @@ -3,6 +3,11 @@ name = "train" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +crate-type = ["cdylib"] [dependencies] +wasm-bindgen = { version = "0.2.88" } +log = { workspace = true } +console_error_panic_hook = "0.1.7" +console_log = { version = "1", features = ["color"] } diff --git a/examples/train-web/train/build-for-web.sh b/examples/train-web/train/build-for-web.sh new file mode 100755 index 0000000000..af36a53eba --- /dev/null +++ b/examples/train-web/train/build-for-web.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +# Add wasm32 target for compiler. +rustup target add wasm32-unknown-unknown + +if ! command -v wasm-pack &>/dev/null; then + echo "wasm-pack could not be found. Installing ..." + cargo install wasm-pack + exit +fi + +# Set optimization flags +export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 --cfg web_sys_unstable_apis" + +# Run wasm pack tool to build JS wrapper files and copy wasm to pkg directory. +mkdir -p pkg +wasm-pack build --out-dir pkg --release --target web --no-default-features diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index 7d12d9af81..188a46d832 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -1,14 +1,12 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right -} +use wasm_bindgen::prelude::*; -#[cfg(test)] -mod tests { - use super::*; +#[wasm_bindgen(start)] +pub fn start() { + console_error_panic_hook::set_once(); + console_log::init().expect("Error initializing logger"); +} - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } +#[wasm_bindgen] +pub fn run() { + log::info!("Hello from Rust"); } diff --git a/examples/train-web/web/src/counter.ts b/examples/train-web/web/src/counter.ts index 164395d512..b0b092b2c2 100644 --- a/examples/train-web/web/src/counter.ts +++ b/examples/train-web/web/src/counter.ts @@ -1,6 +1,11 @@ +import wasm, { run } from 'train' + +await wasm() + export function setupCounter(element: HTMLButtonElement) { let counter = 0 const setCounter = (count: number) => { + run() counter = count element.innerHTML = `count is ${counter}` } From cc423e645e672ef1688dc11f9bef556d20836dc9 Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 4 Nov 2023 11:35:35 -0500 Subject: [PATCH 13/79] add train.rs --- examples/train-web/train/Cargo.toml | 7 ++ examples/train-web/train/src/lib.rs | 4 + examples/train-web/train/src/train.rs | 131 ++++++++++++++++++++++++++ 3 files changed, 142 insertions(+) create mode 100644 examples/train-web/train/src/train.rs diff --git a/examples/train-web/train/Cargo.toml b/examples/train-web/train/Cargo.toml index 168cff30aa..4c7f514cfb 100644 --- a/examples/train-web/train/Cargo.toml +++ b/examples/train-web/train/Cargo.toml @@ -11,3 +11,10 @@ wasm-bindgen = { version = "0.2.88" } log = { workspace = true } console_error_panic_hook = "0.1.7" console_log = { version = "1", features = ["color"] } +burn = { path = "../../../burn", default-features = false, features = [ + "autodiff", + "ndarray-no-std", + "train-minimal", + "wasm-sync", +] } +serde = { workspace = true } diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index 188a46d832..40f0a1dcc5 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -1,5 +1,8 @@ +use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; use wasm_bindgen::prelude::*; +mod train; + #[wasm_bindgen(start)] pub fn start() { console_error_panic_hook::set_once(); @@ -9,4 +12,5 @@ pub fn start() { #[wasm_bindgen] pub fn run() { log::info!("Hello from Rust"); + train::run::>>(NdArrayDevice::Cpu); } diff --git a/examples/train-web/train/src/train.rs b/examples/train-web/train/src/train.rs new file mode 100644 index 0000000000..5031daf01c --- /dev/null +++ b/examples/train-web/train/src/train.rs @@ -0,0 +1,131 @@ +use burn::{ + config::Config, + module::Module, + nn::{ + conv::{Conv2d, Conv2dConfig}, + pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, + Dropout, DropoutConfig, Linear, LinearConfig, ReLU, + }, + optim::AdamConfig, + tensor::backend::{AutodiffBackend, Backend}, + train::{ + renderer::{MetricState, MetricsRenderer, TrainingProgress}, + ClassificationOutput, LearnerBuilder, + }, +}; + +#[derive(Config, Debug)] +pub struct ModelConfig { + num_classes: usize, + hidden_size: usize, + #[config(default = "0.5")] + dropout: f64, +} + +#[derive(Config)] +pub struct MnistTrainingConfig { + #[config(default = 10)] + pub num_epochs: usize, + #[config(default = 64)] + pub batch_size: usize, + #[config(default = 4)] + pub num_workers: usize, + #[config(default = 42)] + pub seed: u64, + #[config(default = 1e-4)] + pub lr: f64, + pub model: ModelConfig, + pub optimizer: AdamConfig, +} + +struct CustomRenderer {} + +impl MetricsRenderer for CustomRenderer { + fn update_train(&mut self, _state: MetricState) {} + + fn update_valid(&mut self, _state: MetricState) {} + + fn render_train(&mut self, item: TrainingProgress) { + dbg!(item); + } + + fn render_valid(&mut self, item: TrainingProgress) { + dbg!(item); + } +} + +#[derive(Module, Debug)] +pub struct Model { + conv1: Conv2d, + conv2: Conv2d, + pool: AdaptiveAvgPool2d, + dropout: Dropout, + linear1: Linear, + linear2: Linear, + activation: ReLU, +} + +impl ModelConfig { + /// Returns the initialized model. + pub fn init(&self) -> Model { + Model { + conv1: Conv2dConfig::new([1, 8], [3, 3]).init(), + conv2: Conv2dConfig::new([8, 16], [3, 3]).init(), + pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), + activation: ReLU::new(), + linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(), + linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(), + dropout: DropoutConfig::new(self.dropout).init(), + } + } +} + +pub fn run(device: B::Device) { + // Create the configuration. + let config_model = ModelConfig::new(10, 1024); + let config_optimizer = AdamConfig::new(); + let config = MnistTrainingConfig::new(config_model, config_optimizer); + + B::seed(config.seed); + + // Create the model and optimizer. + let model = config.model.init(); + let optim = config.optimizer.init(); + + // Create the batcher. + // let batcher_train = MNISTBatcher::::new(device.clone()); + // let batcher_valid = MNISTBatcher::::new(device.clone()); + + // Create the dataloaders. + // let dataloader_train = DataLoaderBuilder::new(batcher_train) + // .batch_size(config.batch_size) + // .shuffle(config.seed) + // .num_workers(config.num_workers) + // .build(MNISTDataset::train()); + // + // let dataloader_test = DataLoaderBuilder::new(batcher_valid) + // .batch_size(config.batch_size) + // .shuffle(config.seed) + // .num_workers(config.num_workers) + // .build(MNISTDataset::test()); + + // artifact dir does not need to be provided when log_to_file is false + let builder: LearnerBuilder< + B, + ClassificationOutput, + ClassificationOutput<::InnerBackend>, + _, + _, + _, + > = LearnerBuilder::new("") + .devices(vec![device]) + .num_epochs(config.num_epochs) + .renderer(CustomRenderer {}) + .log_to_file(false); + // can be used to interrupt training + let _interrupter = builder.interrupter(); + + let _learner = builder.build(model, optim, config.lr); + + // let _model_trained = learner.fit(dataloader_train, dataloader_test); +} From 40f4140e982b1392648b671087935ee3d47762b8 Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 4 Nov 2023 12:09:26 -0500 Subject: [PATCH 14/79] make dev friendly --- examples/train-web/train/build-for-web.sh | 11 +++++++++-- examples/train-web/train/dev.sh | 3 +++ examples/train-web/web/src/counter.ts | 2 ++ 3 files changed, 14 insertions(+), 2 deletions(-) create mode 100755 examples/train-web/train/dev.sh diff --git a/examples/train-web/train/build-for-web.sh b/examples/train-web/train/build-for-web.sh index af36a53eba..7f62254ba1 100755 --- a/examples/train-web/train/build-for-web.sh +++ b/examples/train-web/train/build-for-web.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -eo pipefail # https://stackoverflow.com/a/2871034 + # Add wasm32 target for compiler. rustup target add wasm32-unknown-unknown @@ -10,8 +12,13 @@ if ! command -v wasm-pack &>/dev/null; then fi # Set optimization flags -export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 --cfg web_sys_unstable_apis" +if [[ $1 == "release" ]]; then + export RUSTFLAGS="-C lto=fat -C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 --cfg web_sys_unstable_apis" +else + # sets $1 to "dev" + set -- dev +fi # Run wasm pack tool to build JS wrapper files and copy wasm to pkg directory. mkdir -p pkg -wasm-pack build --out-dir pkg --release --target web --no-default-features +wasm-pack build --out-dir pkg --$1 --target web --no-default-features diff --git a/examples/train-web/train/dev.sh b/examples/train-web/train/dev.sh new file mode 100755 index 0000000000..221ab49e51 --- /dev/null +++ b/examples/train-web/train/dev.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +cargo watch -- ./build-for-web.sh diff --git a/examples/train-web/web/src/counter.ts b/examples/train-web/web/src/counter.ts index b0b092b2c2..af2b245d46 100644 --- a/examples/train-web/web/src/counter.ts +++ b/examples/train-web/web/src/counter.ts @@ -1,5 +1,7 @@ import wasm, { run } from 'train' +Error.stackTraceLimit = 30 + await wasm() export function setupCounter(element: HTMLButtonElement) { From a77829dfd295caa0fbb6b69aa9a999ee7fece973 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 9 Nov 2023 07:48:06 -0600 Subject: [PATCH 15/79] implemented spawn using webworkers --- Cargo.toml | 2 +- burn-train/Cargo.toml | 30 +++++----- burn-train/src/lib.rs | 2 + burn-train/src/metric/store/client.rs | 9 +-- burn-train/src/util.rs | 61 ++++++++++++++++++++ burn/Cargo.toml | 12 +++- examples/train-web/readme.md | 14 +++++ examples/train-web/train/.cargo/config.toml | 6 ++ examples/train-web/train/.gitignore | 1 + examples/train-web/train/Cargo.toml | 11 ++-- examples/train-web/train/rust-toolchain.toml | 2 + examples/train-web/train/src/lib.rs | 15 ++++- examples/train-web/web/src/counter.ts | 4 +- examples/train-web/web/src/worker.ts | 40 +++++++++++++ examples/train-web/web/vite.config.ts | 13 +++++ 15 files changed, 194 insertions(+), 28 deletions(-) create mode 100644 burn-train/src/util.rs create mode 100644 examples/train-web/readme.md create mode 100644 examples/train-web/train/.cargo/config.toml create mode 100644 examples/train-web/train/.gitignore create mode 100644 examples/train-web/train/rust-toolchain.toml create mode 100644 examples/train-web/web/src/worker.ts diff --git a/Cargo.toml b/Cargo.toml index eba95dda4a..66d7d86bf1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,7 +67,7 @@ thiserror = "1.0.49" tracing-appender = "0.2.2" tracing-core = "0.1.31" tracing-subscriber = "0.3.17" -wasm-bindgen = "0.2.87" +wasm-bindgen = "=0.2.88" wasm-bindgen-futures = "0.4.37" wasm-logger = "0.2.0" diff --git a/burn-train/Cargo.toml b/burn-train/Cargo.toml index 8440363a38..2d0e00612e 100644 --- a/burn-train/Cargo.toml +++ b/burn-train/Cargo.toml @@ -12,20 +12,14 @@ version = "0.11.0" [features] default = ["metrics", "tui"] -metrics = [ - "nvml-wrapper", - "sysinfo", - "systemstat" -] -tui = [ - "ratatui", - "crossterm" -] +metrics = ["nvml-wrapper", "sysinfo", "systemstat"] +tui = ["ratatui", "crossterm"] +browser = ["js-sys", "web-sys", "wasm-bindgen"] [dependencies] -burn-core = {path = "../burn-core", version = "0.11.0" } +burn-core = { path = "../burn-core", version = "0.11.0" } -log = {workspace = true} +log = { workspace = true } tracing-subscriber.workspace = true tracing-appender.workspace = true tracing-core.workspace = true @@ -40,8 +34,16 @@ ratatui = { version = "0.23", optional = true, features = ["all-widgets"] } crossterm = { version = "0.27", optional = true } # Utilities -derive-new = {workspace = true} -serde = {workspace = true, features = ["std", "derive"]} +derive-new = { workspace = true } +serde = { workspace = true, features = ["std", "derive"] } + +js-sys = { version = "0.3.64", optional = true } +web-sys = { version = "0.3.65", optional = true, features = [ + "Worker", + "WorkerOptions", + "WorkerType", +] } +wasm-bindgen = { workspace = true, optional = true } [dev-dependencies] -burn-ndarray = {path = "../burn-ndarray", version = "0.11.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.11.0" } diff --git a/burn-train/src/lib.rs b/burn-train/src/lib.rs index cb4f39785a..f4bc430ee3 100644 --- a/burn-train/src/lib.rs +++ b/burn-train/src/lib.rs @@ -2,6 +2,8 @@ //! A library for training neural networks using the burn crate. +pub mod util; + #[macro_use] extern crate derive_new; diff --git a/burn-train/src/metric/store/client.rs b/burn-train/src/metric/store/client.rs index 74ba83ab74..b018456caa 100644 --- a/burn-train/src/metric/store/client.rs +++ b/burn-train/src/metric/store/client.rs @@ -1,11 +1,12 @@ use super::EventStore; use super::{Aggregate, Direction, Event, Split}; -use std::{sync::mpsc, thread::JoinHandle}; +use crate::util; +use std::sync::mpsc; /// Type that allows to communicate with an [event store](EventStore). pub struct EventStoreClient { sender: mpsc::Sender, - handler: Option>, + handler: Option Result<(), ()>>>, } impl EventStoreClient { @@ -17,7 +18,7 @@ impl EventStoreClient { let (sender, receiver) = mpsc::channel(); let thread = WorkerThread::new(store, receiver); - let handler = std::thread::spawn(move || thread.run()); + let handler = util::spawn(move || thread.run()); let handler = Some(handler); Self { sender, handler } @@ -153,7 +154,7 @@ impl Drop for EventStoreClient { let handler = self.handler.take(); if let Some(handler) = handler { - handler.join().expect("The event store thread should stop."); + handler().expect("The event store thread should stop."); } } } diff --git a/burn-train/src/util.rs b/burn-train/src/util.rs new file mode 100644 index 0000000000..e4784a9c83 --- /dev/null +++ b/burn-train/src/util.rs @@ -0,0 +1,61 @@ +#![allow(missing_docs)] + +#[cfg(feature = "browser")] +use wasm_bindgen::prelude::*; + +#[cfg(not(feature = "browser"))] +pub fn spawn(f: F) -> Box Result<(), ()>> +where + F: FnOnce() -> (), + F: Send + 'static, +{ + let handle = std::thread::spawn(f); + Box::new(move || handle.join().map_err(|_| ())) +} + +// High level description at https://www.tweag.io/blog/2022-11-24-wasm-threads-and-messages/ +// Mostly copied from https://github.com/tweag/rust-wasm-threads/blob/main/shared-memory/src/lib.rs +#[cfg(feature = "browser")] +pub fn spawn(f: F) -> Box Result<(), ()>> +where + F: FnOnce() -> (), + F: Send + 'static, +{ + let mut worker_options = web_sys::WorkerOptions::new(); + worker_options.type_(web_sys::WorkerType::Module); + // Double-boxing because `dyn FnOnce` is unsized and so `Box` has + let w = web_sys::Worker::new_with_options( + WORKER_URL + .get() + .expect("You must first call `init` with the worker's url."), + &worker_options, + ) + .expect(&format!("Error initializing worker at {:?}", WORKER_URL)); + // an undefined layout (although I think in practice its a pointer and a length?). + let ptr = Box::into_raw(Box::new(Box::new(f) as Box)); + + // See `worker.js` for the format of this message. + let msg: js_sys::Array = [ + &wasm_bindgen::module(), + &wasm_bindgen::memory(), + &JsValue::from(ptr as u32), + ] + .into_iter() + .collect(); + if let Err(e) = w.post_message(&msg) { + // We expect the worker to deallocate the box, but if there was an error then + // we'll do it ourselves. + let _ = unsafe { Box::from_raw(ptr) }; + panic!("Error initializing worker during post_message: {:?}", e) + } else { + Box::new(move || Ok(w.terminate())) + } +} + +#[cfg(feature = "browser")] +static WORKER_URL: std::sync::OnceLock = std::sync::OnceLock::new(); + +#[cfg(feature = "browser")] +pub fn init(worker_url: String) -> Result<(), String> { + WORKER_URL.set(worker_url) +} diff --git a/burn/Cargo.toml b/burn/Cargo.toml index 0605066e38..1fb432b11e 100644 --- a/burn/Cargo.toml +++ b/burn/Cargo.toml @@ -21,6 +21,8 @@ train = ["burn-train/default", "autodiff", "dataset"] # Useful when targeting WASM and not using WGPU. wasm-sync = ["burn-core/wasm-sync"] +browser = ["burn-train/browser"] + ## Include nothing train-minimal = ["burn-train"] @@ -62,4 +64,12 @@ burn-core = { path = "../burn-core", version = "0.11.0", default-features = fals burn-train = { path = "../burn-train", version = "0.11.0", optional = true, default-features = false } [package.metadata.docs.rs] -features = ["dataset", "default", "std", "train", "train-tui", "train-metrics", "dataset-sqlite"] +features = [ + "dataset", + "default", + "std", + "train", + "train-tui", + "train-metrics", + "dataset-sqlite", +] diff --git a/examples/train-web/readme.md b/examples/train-web/readme.md new file mode 100644 index 0000000000..70b5d78a8f --- /dev/null +++ b/examples/train-web/readme.md @@ -0,0 +1,14 @@ +## Getting Started + +Your `wasm-bindgen-cli` version must *exactly* match the `wasm-bindgen` version in [Cargo.toml](../../Cargo.toml) since `wasm-bindgen-cli` is implicitly used by `wasm-pack`. + +For example, run `cargo install --version 0.2.88 wasm-bindgen-cli --force`. The version in this example command is not guaranteed to be up to date! + +Install [PNPM](https://pnpm.io/). + +Then in separate terminals: + +1. `cd train && dev.sh` +2. `cd web && pnpm i && pnpm dev` + +Any changes to `/train` or `burn` should trigger a recompilation. When a new binary is generated, `web` will automatically refresh the page. diff --git a/examples/train-web/train/.cargo/config.toml b/examples/train-web/train/.cargo/config.toml new file mode 100644 index 0000000000..0b10571c21 --- /dev/null +++ b/examples/train-web/train/.cargo/config.toml @@ -0,0 +1,6 @@ +[unstable] +build-std = ['std', 'panic_abort'] + +[build] +target = "wasm32-unknown-unknown" +rustflags = '-Ctarget-feature=+atomics,+bulk-memory,+mutable-globals' diff --git a/examples/train-web/train/.gitignore b/examples/train-web/train/.gitignore new file mode 100644 index 0000000000..5fff1d9c18 --- /dev/null +++ b/examples/train-web/train/.gitignore @@ -0,0 +1 @@ +pkg diff --git a/examples/train-web/train/Cargo.toml b/examples/train-web/train/Cargo.toml index 4c7f514cfb..01b3f86604 100644 --- a/examples/train-web/train/Cargo.toml +++ b/examples/train-web/train/Cargo.toml @@ -7,14 +7,15 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] -wasm-bindgen = { version = "0.2.88" } +wasm-bindgen = { workspace = true } log = { workspace = true } console_error_panic_hook = "0.1.7" console_log = { version = "1", features = ["color"] } burn = { path = "../../../burn", default-features = false, features = [ - "autodiff", - "ndarray-no-std", - "train-minimal", - "wasm-sync", + "autodiff", + "ndarray-no-std", + "train-minimal", + "wasm-sync", + "browser", ] } serde = { workspace = true } diff --git a/examples/train-web/train/rust-toolchain.toml b/examples/train-web/train/rust-toolchain.toml new file mode 100644 index 0000000000..f29c21eaa5 --- /dev/null +++ b/examples/train-web/train/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly-2023-11-05" diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index 40f0a1dcc5..3efc7f9e26 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -1,4 +1,7 @@ -use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray}; +use burn::{ + backend::{ndarray::NdArrayDevice, Autodiff, NdArray}, + train::util, +}; use wasm_bindgen::prelude::*; mod train; @@ -10,7 +13,15 @@ pub fn start() { } #[wasm_bindgen] -pub fn run() { +pub fn run(worker_url: String) -> Result<(), String> { log::info!("Hello from Rust"); + util::init(worker_url)?; train::run::>>(NdArrayDevice::Cpu); + Ok(()) +} + +#[wasm_bindgen] +pub fn child_entry_point(ptr: u32) { + let work = unsafe { Box::from_raw(ptr as *mut Box) }; + (*work)(); } diff --git a/examples/train-web/web/src/counter.ts b/examples/train-web/web/src/counter.ts index af2b245d46..a82ed2db5b 100644 --- a/examples/train-web/web/src/counter.ts +++ b/examples/train-web/web/src/counter.ts @@ -1,13 +1,15 @@ import wasm, { run } from 'train' +import workerUrl from './worker.ts?url' +// @ts-ignore https://github.com/rustwasm/console_error_panic_hook#errorstacktracelimit Error.stackTraceLimit = 30 await wasm() +run(workerUrl) export function setupCounter(element: HTMLButtonElement) { let counter = 0 const setCounter = (count: number) => { - run() counter = count element.innerHTML = `count is ${counter}` } diff --git a/examples/train-web/web/src/worker.ts b/examples/train-web/web/src/worker.ts new file mode 100644 index 0000000000..fbc481f6d0 --- /dev/null +++ b/examples/train-web/web/src/worker.ts @@ -0,0 +1,40 @@ +// Mostly copied from https://github.com/tweag/rust-wasm-threads/blob/main/shared-memory/worker.js + +import trainWasmUrl from 'train/train_bg.wasm?url' +import wasm, { child_entry_point } from 'train' + +self.onmessage = async (event) => { + // We expect a message with three elements [module, memory, ptr], where: + // - module is a WebAssembly.Module + // - memory is the WebAssembly.Memory object that the main thread is using + // (and we want to use it too). + // - ptr is the pointer (within `memory`) of the function that we want to execute. + // + // The auto-generated `wasm_bindgen` function takes two *optional* arguments. + // The first is the module (you can pass in a url, like we did in "index.js", + // or a module object); the second is the memory block to use, and if you + // don't provide one (like we didn't in "index.js") then a new one will be + // allocated for you. + let init = await wasm(trainWasmUrl, event.data[1]).catch((err) => { + // Propagate to main `onerror`: + setTimeout(() => { + throw err + }) + // Rethrow to keep promise rejected and prevent execution of further commands: + throw err + }) + + child_entry_point(event.data[2]) + + // Clean up thread resources. Depending on what you're doing with the thread, this might + // not be what you want. (For example, if the thread spawned some javascript tasks + // and exited, this is going to cancel those tasks.) But if you're using threads in the + // usual native way (where you spin one up to do some work until it finisheds) then + // you'll want to clean up the thread's resources. + + // @ts-ignore -- can take 0 args, the typescript is incorrect https://github.com/rustwasm/wasm-bindgen/blob/962f580d6f5def2df4223588de96feeb30ce7572/crates/threads-xform/src/lib.rs#L394-L395 + // Free memory (stack, thread-locals) held (in the wasm linear memory) by the thread. + init.__wbindgen_thread_destroy() + // Tell the browser to stop the thread. + close() +} diff --git a/examples/train-web/web/vite.config.ts b/examples/train-web/web/vite.config.ts index d713a8fbff..d44a8a3f2f 100644 --- a/examples/train-web/web/vite.config.ts +++ b/examples/train-web/web/vite.config.ts @@ -1,6 +1,19 @@ import { defineConfig } from 'vite' export default defineConfig({ + plugins: [ + { + // https://gist.github.com/mizchi/afcc5cf233c9e6943720fde4b4579a2b + name: 'isolation', + configureServer(server) { + server.middlewares.use((_req, res, next) => { + res.setHeader('Cross-Origin-Opener-Policy', 'same-origin') + res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp') + next() + }) + }, + }, + ], server: { fs: { // Allow serving files from one level up to the project root From 2e074e781102365a6e002ad0e91b70a6229ad035 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 9 Nov 2023 07:09:31 -0600 Subject: [PATCH 16/79] update notices --- NOTICES.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/NOTICES.md b/NOTICES.md index f0f65e0203..b486e8e637 100644 --- a/NOTICES.md +++ b/NOTICES.md @@ -245,4 +245,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. +## Rust threads in the browser +**Source**: https://github.com/tweag/rust-wasm-threads/tree/main/shared-memory + +MIT License + +Copyright (c) 2023 Matthew Toohey and Modus Create LLC + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. From 1cf8dd18e4321843a0fedae993d74cb4b86183be Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 9 Nov 2023 16:11:20 -0600 Subject: [PATCH 17/79] Arc => Rc --- burn-train/src/checkpoint/strategy/metric.rs | 4 ++-- burn-train/src/learner/base.rs | 3 ++- burn-train/src/learner/builder.rs | 4 ++-- burn-train/src/learner/early_stopping.rs | 4 ++-- burn-train/src/metric/processor/full.rs | 6 +++--- burn-train/src/metric/processor/minimal.rs | 4 ++-- 6 files changed, 13 insertions(+), 12 deletions(-) diff --git a/burn-train/src/checkpoint/strategy/metric.rs b/burn-train/src/checkpoint/strategy/metric.rs index f2aa58efeb..7a1cd6085e 100644 --- a/burn-train/src/checkpoint/strategy/metric.rs +++ b/burn-train/src/checkpoint/strategy/metric.rs @@ -76,7 +76,7 @@ mod tests { }, TestBackend, }; - use std::sync::Arc; + use std::rc::Rc; use super::*; @@ -93,7 +93,7 @@ mod tests { store.register_logger_train(InMemoryMetricLogger::default()); // Register the loss metric. metrics.register_train_metric_numeric(LossMetric::::new()); - let store = Arc::new(EventStoreClient::new(store)); + let store = Rc::new(EventStoreClient::new(store)); let mut processor = MinimalEventProcessor::new(metrics, store.clone()); // Two points for the first epoch. Mean 0.75 diff --git a/burn-train/src/learner/base.rs b/burn-train/src/learner/base.rs index 55a5515ef7..545f429070 100644 --- a/burn-train/src/learner/base.rs +++ b/burn-train/src/learner/base.rs @@ -6,6 +6,7 @@ use burn_core::lr_scheduler::LrScheduler; use burn_core::module::Module; use burn_core::optim::Optimizer; use burn_core::tensor::backend::Backend; +use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -24,7 +25,7 @@ pub struct Learner { pub(crate) interrupter: TrainingInterrupter, pub(crate) early_stopping: Option>, pub(crate) event_processor: LC::EventProcessor, - pub(crate) event_store: Arc, + pub(crate) event_store: Rc, } #[derive(new)] diff --git a/burn-train/src/learner/builder.rs b/burn-train/src/learner/builder.rs index 99e6d90bea..ac8568b467 100644 --- a/burn-train/src/learner/builder.rs +++ b/burn-train/src/learner/builder.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::rc::Rc; use super::log::install_file_logger; use super::Learner; @@ -312,7 +312,7 @@ where )); } - let event_store = Arc::new(EventStoreClient::new(self.event_store)); + let event_store = Rc::new(EventStoreClient::new(self.event_store)); let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone()); let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| { diff --git a/burn-train/src/learner/early_stopping.rs b/burn-train/src/learner/early_stopping.rs index 641d49551b..334756d885 100644 --- a/burn-train/src/learner/early_stopping.rs +++ b/burn-train/src/learner/early_stopping.rs @@ -104,7 +104,7 @@ impl MetricEarlyStoppingStrategy { #[cfg(test)] mod tests { - use std::sync::Arc; + use std::{rc::Rc, sync::Arc}; use crate::{ logger::InMemoryMetricLogger, @@ -188,7 +188,7 @@ mod tests { store.register_logger_train(InMemoryMetricLogger::default()); metrics.register_train_metric_numeric(LossMetric::::new()); - let store = Arc::new(EventStoreClient::new(store)); + let store = Rc::new(EventStoreClient::new(store)); let mut processor = MinimalEventProcessor::new(metrics, store.clone()); let mut epoch = 1; diff --git a/burn-train/src/metric/processor/full.rs b/burn-train/src/metric/processor/full.rs index b25870dfb4..9f76588c88 100644 --- a/burn-train/src/metric/processor/full.rs +++ b/burn-train/src/metric/processor/full.rs @@ -1,7 +1,7 @@ use super::{Event, EventProcessor, Metrics}; use crate::metric::store::EventStoreClient; use crate::renderer::{MetricState, MetricsRenderer}; -use std::sync::Arc; +use std::rc::Rc; /// An [event processor](EventProcessor) that handles: /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). @@ -9,14 +9,14 @@ use std::sync::Arc; pub struct FullEventProcessor { metrics: Metrics, renderer: Box, - store: Arc, + store: Rc, } impl FullEventProcessor { pub(crate) fn new( metrics: Metrics, renderer: Box, - store: Arc, + store: Rc, ) -> Self { Self { metrics, diff --git a/burn-train/src/metric/processor/minimal.rs b/burn-train/src/metric/processor/minimal.rs index bb60713e45..e95d2e8b46 100644 --- a/burn-train/src/metric/processor/minimal.rs +++ b/burn-train/src/metric/processor/minimal.rs @@ -1,13 +1,13 @@ use super::{Event, EventProcessor, Metrics}; use crate::metric::store::EventStoreClient; -use std::sync::Arc; +use std::rc::Rc; /// An [event processor](EventProcessor) that handles: /// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). #[derive(new)] pub(crate) struct MinimalEventProcessor { metrics: Metrics, - store: Arc, + store: Rc, } impl EventProcessor for MinimalEventProcessor { From f688a3d95e84b1bff879efaf14906d77ee021413 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 9 Nov 2023 16:45:13 -0600 Subject: [PATCH 18/79] clippy --- burn-train/src/util.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/burn-train/src/util.rs b/burn-train/src/util.rs index e4784a9c83..22dfcf391f 100644 --- a/burn-train/src/util.rs +++ b/burn-train/src/util.rs @@ -6,7 +6,7 @@ use wasm_bindgen::prelude::*; #[cfg(not(feature = "browser"))] pub fn spawn(f: F) -> Box Result<(), ()>> where - F: FnOnce() -> (), + F: FnOnce(), F: Send + 'static, { let handle = std::thread::spawn(f); @@ -18,7 +18,7 @@ where #[cfg(feature = "browser")] pub fn spawn(f: F) -> Box Result<(), ()>> where - F: FnOnce() -> (), + F: FnOnce(), F: Send + 'static, { let mut worker_options = web_sys::WorkerOptions::new(); @@ -30,7 +30,7 @@ where .expect("You must first call `init` with the worker's url."), &worker_options, ) - .expect(&format!("Error initializing worker at {:?}", WORKER_URL)); + .unwrap_or_else(|_| panic!("Error initializing worker at {:?}", WORKER_URL)); // an undefined layout (although I think in practice its a pointer and a length?). let ptr = Box::into_raw(Box::new(Box::new(f) as Box)); @@ -48,7 +48,10 @@ where let _ = unsafe { Box::from_raw(ptr) }; panic!("Error initializing worker during post_message: {:?}", e) } else { - Box::new(move || Ok(w.terminate())) + Box::new(move || { + w.terminate(); + Ok(()) + }) } } From 13f26c9ab8f96b2f5da363439e075c770d28208c Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 9 Nov 2023 16:45:23 -0600 Subject: [PATCH 19/79] fix CI --- xtask/src/runchecks.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index cf0c3364d5..d7b13b79a2 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -275,7 +275,9 @@ fn std_checks() { } // Test each workspace - cargo_test(["--workspace"].into()); + cargo_test(["--workspace", "--exclude", "burn-train"].into()); + // Separate out `burn-train` to prevent feature unification https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2:~:text=When%20building%20multiple,separate%20cargo%20invocations. + cargo_test(["-p", "burn-train", "--lib"].into()); // Test burn-dataset features burn_dataset_features_std(); From 11cb46bac6b1bcba404957127aaa6e5b41f9843b Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 13 Nov 2023 07:56:18 -0600 Subject: [PATCH 20/79] pnpm i sql.js --- examples/train-web/web/package.json | 2 ++ examples/train-web/web/pnpm-lock.yaml | 31 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/examples/train-web/web/package.json b/examples/train-web/web/package.json index 632a4ba0d1..b13e35ebb6 100644 --- a/examples/train-web/web/package.json +++ b/examples/train-web/web/package.json @@ -9,11 +9,13 @@ "preview": "vite preview" }, "devDependencies": { + "@types/sql.js": "^1.4.9", "prettier": "3.0.3", "typescript": "^5.2.2", "vite": "^4.5.0" }, "dependencies": { + "sql.js": "^1.8.0", "train": "link:../train/pkg" } } diff --git a/examples/train-web/web/pnpm-lock.yaml b/examples/train-web/web/pnpm-lock.yaml index 682004e490..55074d7bde 100644 --- a/examples/train-web/web/pnpm-lock.yaml +++ b/examples/train-web/web/pnpm-lock.yaml @@ -5,11 +5,17 @@ settings: excludeLinksFromLockfile: false dependencies: + sql.js: + specifier: ^1.8.0 + version: 1.8.0 train: specifier: link:../train/pkg version: link:../train/pkg devDependencies: + '@types/sql.js': + specifier: ^1.4.9 + version: 1.4.9 prettier: specifier: 3.0.3 version: 3.0.3 @@ -220,6 +226,23 @@ packages: dev: true optional: true + /@types/emscripten@1.39.10: + resolution: {integrity: sha512-TB/6hBkYQJxsZHSqyeuO1Jt0AB/bW6G7rHt9g7lML7SOF6lbgcHvw/Lr+69iqN0qxgXLhWKScAon73JNnptuDw==} + dev: true + + /@types/node@20.9.0: + resolution: {integrity: sha512-nekiGu2NDb1BcVofVcEKMIwzlx4NjHlcjhoxxKBNLtz15Y1z7MYf549DFvkHSId02Ax6kGwWntIBPC3l/JZcmw==} + dependencies: + undici-types: 5.26.5 + dev: true + + /@types/sql.js@1.4.9: + resolution: {integrity: sha512-ep8b36RKHlgWPqjNG9ToUrPiwkhwh0AEzy883mO5Xnd+cL6VBH1EvSjBAAuxLUFF2Vn/moE3Me6v9E1Lo+48GQ==} + dependencies: + '@types/emscripten': 1.39.10 + '@types/node': 20.9.0 + dev: true + /esbuild@0.18.20: resolution: {integrity: sha512-ceqxoedUrcayh7Y7ZX6NdbbDzGROiyVBgC4PriJThBKSVPWnnFHZAkfI1lJT8QFkOwH4qOS2SJkS4wvpGl8BpA==} engines: {node: '>=12'} @@ -296,12 +319,20 @@ packages: engines: {node: '>=0.10.0'} dev: true + /sql.js@1.8.0: + resolution: {integrity: sha512-3HD8pSkZL+5YvYUI8nlvNILs61ALqq34xgmF+BHpqxe68yZIJ1H+sIVIODvni25+CcxHUxDyrTJUL0lE/m7afw==} + dev: false + /typescript@5.2.2: resolution: {integrity: sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==} engines: {node: '>=14.17'} hasBin: true dev: true + /undici-types@5.26.5: + resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} + dev: true + /vite@4.5.0: resolution: {integrity: sha512-ulr8rNLA6rkyFAlVWw2q5YJ91v098AFQ2R0PRFwPzREXOUJQPtFUG0t+/ZikhaOCDqFoDhN6/v8Sq0o4araFAw==} engines: {node: ^14.18.0 || >=16.0.0} From 5fa446351a631ea6f8338f519fe6d6b345becca9 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 13 Nov 2023 16:41:19 -0600 Subject: [PATCH 21/79] add postinstall to cp sql-wasm.wasm into assets --- examples/train-web/web/.gitignore | 2 ++ examples/train-web/web/package.json | 3 ++- examples/train-web/web/postinstall.sh | 4 ++++ 3 files changed, 8 insertions(+), 1 deletion(-) create mode 100755 examples/train-web/web/postinstall.sh diff --git a/examples/train-web/web/.gitignore b/examples/train-web/web/.gitignore index a547bf36d8..57b2f7ef67 100644 --- a/examples/train-web/web/.gitignore +++ b/examples/train-web/web/.gitignore @@ -22,3 +22,5 @@ dist-ssr *.njsproj *.sln *.sw? + +src/assets/* diff --git a/examples/train-web/web/package.json b/examples/train-web/web/package.json index b13e35ebb6..9d643bb8e5 100644 --- a/examples/train-web/web/package.json +++ b/examples/train-web/web/package.json @@ -6,7 +6,8 @@ "scripts": { "dev": "vite", "build": "tsc && vite build", - "preview": "vite preview" + "preview": "vite preview", + "postinstall": "./postinstall.sh" }, "devDependencies": { "@types/sql.js": "^1.4.9", diff --git a/examples/train-web/web/postinstall.sh b/examples/train-web/web/postinstall.sh new file mode 100755 index 0000000000..e80ed03f01 --- /dev/null +++ b/examples/train-web/web/postinstall.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash + +mkdir -p ./src/assets +cp ./node_modules/sql.js/dist/sql-wasm.wasm ./src/assets/sql-wasm.wasm From 94d156dc78472d74fc5729df391ea93d7a62542a Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 13 Nov 2023 16:56:45 -0600 Subject: [PATCH 22/79] can load mnist data in js and send to rust --- examples/train-web/train/src/lib.rs | 7 +++- examples/train-web/web/src/counter.ts | 9 ----- examples/train-web/web/src/main.ts | 5 +++ examples/train-web/web/src/train.ts | 52 +++++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 10 deletions(-) create mode 100644 examples/train-web/web/src/train.ts diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index 3efc7f9e26..fbf374482a 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -13,7 +13,12 @@ pub fn start() { } #[wasm_bindgen] -pub fn run(worker_url: String) -> Result<(), String> { +pub fn run( + worker_url: String, + labels: &[u8], + images: &[u8], + lengths: &[u16], +) -> Result<(), String> { log::info!("Hello from Rust"); util::init(worker_url)?; train::run::>>(NdArrayDevice::Cpu); diff --git a/examples/train-web/web/src/counter.ts b/examples/train-web/web/src/counter.ts index a82ed2db5b..164395d512 100644 --- a/examples/train-web/web/src/counter.ts +++ b/examples/train-web/web/src/counter.ts @@ -1,12 +1,3 @@ -import wasm, { run } from 'train' -import workerUrl from './worker.ts?url' - -// @ts-ignore https://github.com/rustwasm/console_error_panic_hook#errorstacktracelimit -Error.stackTraceLimit = 30 - -await wasm() -run(workerUrl) - export function setupCounter(element: HTMLButtonElement) { let counter = 0 const setCounter = (count: number) => { diff --git a/examples/train-web/web/src/main.ts b/examples/train-web/web/src/main.ts index 791547b0d3..af823defff 100644 --- a/examples/train-web/web/src/main.ts +++ b/examples/train-web/web/src/main.ts @@ -2,6 +2,7 @@ import './style.css' import typescriptLogo from './typescript.svg' import viteLogo from '/vite.svg' import { setupCounter } from './counter.ts' +import { setupTrain } from './train.ts' document.querySelector('#app')!.innerHTML = `
@@ -15,6 +16,9 @@ document.querySelector('#app')!.innerHTML = `
+
+ +

Click on the Vite and TypeScript logos to learn more

@@ -22,3 +26,4 @@ document.querySelector('#app')!.innerHTML = ` ` setupCounter(document.querySelector('#counter')!) +setupTrain(document.querySelector('#dbfile')!) diff --git a/examples/train-web/web/src/train.ts b/examples/train-web/web/src/train.ts new file mode 100644 index 0000000000..88c7a13862 --- /dev/null +++ b/examples/train-web/web/src/train.ts @@ -0,0 +1,52 @@ +import wasm, { run } from 'train' +import workerUrl from './worker.ts?url' +import initSqlJs, { type Database } from 'sql.js' +import sqliteWasmUrl from './assets/sql-wasm.wasm?url' + +// @ts-ignore https://github.com/rustwasm/console_error_panic_hook#errorstacktracelimit +Error.stackTraceLimit = 30 + +export function setupTrain(element: HTMLInputElement) { + element.onchange = async function () { + await wasm() + const db = await getDb(element.files![0]) + const trainQuery = db.prepare('select label, image_bytes from train') + const imageBytes: Uint8Array[] = [] + const labels: number[] = [] + const lengths: number[] = [] + while (trainQuery.step()) { + const row = trainQuery.getAsObject() + const label = row.label as number + labels.push(label) + const bytes = row.image_bytes as Uint8Array + imageBytes.push(bytes) + lengths.push(bytes.length) + } + run( + workerUrl, + Uint8Array.from(labels), + concat(imageBytes), + Uint16Array.from(lengths), + ) + } +} + +async function getDb(sqliteBuffer: File): Promise { + const sql = await initSqlJs({ + locateFile: () => sqliteWasmUrl, + }) + const buffer = await sqliteBuffer.arrayBuffer() + return new sql.Database(new Uint8Array(buffer)) +} + +// https://stackoverflow.com/a/59902602 +function concat(arrays: Uint8Array[]) { + const totalLength = arrays.reduce((acc, value) => acc + value.length, 0) + const result = new Uint8Array(totalLength) + let length = 0 + for (const array of arrays) { + result.set(array, length) + length += array.length + } + return result +} From 66a7ffd90203cfad8e800e31fbbb4b50897e31a6 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 13 Nov 2023 20:10:07 -0600 Subject: [PATCH 23/79] extract init --- burn-train/src/util.rs | 7 ++++++- examples/train-web/train/src/lib.rs | 13 ++++++------- examples/train-web/web/src/train.ts | 14 ++++++-------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/burn-train/src/util.rs b/burn-train/src/util.rs index 22dfcf391f..3d3d44ce82 100644 --- a/burn-train/src/util.rs +++ b/burn-train/src/util.rs @@ -60,5 +60,10 @@ static WORKER_URL: std::sync::OnceLock = std::sync::OnceLock::new(); #[cfg(feature = "browser")] pub fn init(worker_url: String) -> Result<(), String> { - WORKER_URL.set(worker_url) + WORKER_URL.set(worker_url).map_err(|worker_url| { + format!( + "You can only call `init` once. You tried to `init` with: {:?}", + worker_url + ) + }) } diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index fbf374482a..b589c33185 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -13,14 +13,13 @@ pub fn start() { } #[wasm_bindgen] -pub fn run( - worker_url: String, - labels: &[u8], - images: &[u8], - lengths: &[u16], -) -> Result<(), String> { +pub fn init(worker_url: String) -> Result<(), String> { + util::init(worker_url) +} + +#[wasm_bindgen] +pub fn run(labels: &[u8], images: &[u8], lengths: &[u16]) -> Result<(), String> { log::info!("Hello from Rust"); - util::init(worker_url)?; train::run::>>(NdArrayDevice::Cpu); Ok(()) } diff --git a/examples/train-web/web/src/train.ts b/examples/train-web/web/src/train.ts index 88c7a13862..2ffabd0cb4 100644 --- a/examples/train-web/web/src/train.ts +++ b/examples/train-web/web/src/train.ts @@ -1,4 +1,4 @@ -import wasm, { run } from 'train' +import wasm, { init, run } from 'train' import workerUrl from './worker.ts?url' import initSqlJs, { type Database } from 'sql.js' import sqliteWasmUrl from './assets/sql-wasm.wasm?url' @@ -6,9 +6,12 @@ import sqliteWasmUrl from './assets/sql-wasm.wasm?url' // @ts-ignore https://github.com/rustwasm/console_error_panic_hook#errorstacktracelimit Error.stackTraceLimit = 30 +wasm() + .then(() => init(workerUrl)) + .catch(console.error) + export function setupTrain(element: HTMLInputElement) { element.onchange = async function () { - await wasm() const db = await getDb(element.files![0]) const trainQuery = db.prepare('select label, image_bytes from train') const imageBytes: Uint8Array[] = [] @@ -22,12 +25,7 @@ export function setupTrain(element: HTMLInputElement) { imageBytes.push(bytes) lengths.push(bytes.length) } - run( - workerUrl, - Uint8Array.from(labels), - concat(imageBytes), - Uint16Array.from(lengths), - ) + run(Uint8Array.from(labels), concat(imageBytes), Uint16Array.from(lengths)) } } From 7735e4771dc14ff72f2badebbca183b57136df6a Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 14 Nov 2023 16:31:09 -0600 Subject: [PATCH 24/79] autoload/run mnist on pageload --- examples/train-web/readme.md | 2 + examples/train-web/web/.gitignore | 1 + examples/train-web/web/postinstall.sh | 1 + examples/train-web/web/src/train.ts | 56 ++++++++++++++++++--------- 4 files changed, 41 insertions(+), 19 deletions(-) diff --git a/examples/train-web/readme.md b/examples/train-web/readme.md index 70b5d78a8f..736b912544 100644 --- a/examples/train-web/readme.md +++ b/examples/train-web/readme.md @@ -6,6 +6,8 @@ For example, run `cargo install --version 0.2.88 wasm-bindgen-cli --force`. The Install [PNPM](https://pnpm.io/). +The [`postinstall.sh`](./web/postinstall.sh) script expects the mnist database to be at `~/.cache/burn-dataset/mnist.db`. Running `guide` will generate this file. Alternatively, you can download it from [Hugging Face](https://huggingface.co/datasets/mnist). + Then in separate terminals: 1. `cd train && dev.sh` diff --git a/examples/train-web/web/.gitignore b/examples/train-web/web/.gitignore index 57b2f7ef67..3f85cb98d2 100644 --- a/examples/train-web/web/.gitignore +++ b/examples/train-web/web/.gitignore @@ -24,3 +24,4 @@ dist-ssr *.sw? src/assets/* +public/mnist.db diff --git a/examples/train-web/web/postinstall.sh b/examples/train-web/web/postinstall.sh index e80ed03f01..e5c2250f43 100755 --- a/examples/train-web/web/postinstall.sh +++ b/examples/train-web/web/postinstall.sh @@ -2,3 +2,4 @@ mkdir -p ./src/assets cp ./node_modules/sql.js/dist/sql-wasm.wasm ./src/assets/sql-wasm.wasm +cp ~/.cache/burn-dataset/mnist.db ./public/mnist.db diff --git a/examples/train-web/web/src/train.ts b/examples/train-web/web/src/train.ts index 2ffabd0cb4..84167540ac 100644 --- a/examples/train-web/web/src/train.ts +++ b/examples/train-web/web/src/train.ts @@ -10,31 +10,49 @@ wasm() .then(() => init(workerUrl)) .catch(console.error) +const sqlJs = initSqlJs({ + locateFile: () => sqliteWasmUrl, +}) + +// just load and run mnist on pageload +fetch('./mnist.db') + .then((r) => r.arrayBuffer()) + .then(loadSqliteAndRun) + .catch(console.error) + export function setupTrain(element: HTMLInputElement) { element.onchange = async function () { - const db = await getDb(element.files![0]) - const trainQuery = db.prepare('select label, image_bytes from train') - const imageBytes: Uint8Array[] = [] - const labels: number[] = [] - const lengths: number[] = [] - while (trainQuery.step()) { - const row = trainQuery.getAsObject() - const label = row.label as number - labels.push(label) - const bytes = row.image_bytes as Uint8Array - imageBytes.push(bytes) - lengths.push(bytes.length) + const files = element.files + if (files != null) { + const file = files[0] + if (file != null) { + let ab = await file.arrayBuffer() + await loadSqliteAndRun(ab) + } } - run(Uint8Array.from(labels), concat(imageBytes), Uint16Array.from(lengths)) } } -async function getDb(sqliteBuffer: File): Promise { - const sql = await initSqlJs({ - locateFile: () => sqliteWasmUrl, - }) - const buffer = await sqliteBuffer.arrayBuffer() - return new sql.Database(new Uint8Array(buffer)) +async function loadSqliteAndRun(ab: ArrayBuffer) { + const db = await getDb(ab) + const trainQuery = db.prepare('select label, image_bytes from train') + const imageBytes: Uint8Array[] = [] + const labels: number[] = [] + const lengths: number[] = [] + while (trainQuery.step()) { + const row = trainQuery.getAsObject() + const label = row.label as number + labels.push(label) + const bytes = row.image_bytes as Uint8Array + imageBytes.push(bytes) + lengths.push(bytes.length) + } + run(Uint8Array.from(labels), concat(imageBytes), Uint16Array.from(lengths)) +} + +async function getDb(ab: ArrayBuffer): Promise { + const sql = await sqlJs + return new sql.Database(new Uint8Array(ab)) } // https://stackoverflow.com/a/59902602 From 2d5f8c00b7610433d0415295fbdccc3fdff1eb8d Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 14 Nov 2023 08:18:15 -0600 Subject: [PATCH 25/79] clean up resources --- examples/train-web/web/src/train.ts | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/train-web/web/src/train.ts b/examples/train-web/web/src/train.ts index 84167540ac..7cbecf207f 100644 --- a/examples/train-web/web/src/train.ts +++ b/examples/train-web/web/src/train.ts @@ -34,18 +34,23 @@ export function setupTrain(element: HTMLInputElement) { } async function loadSqliteAndRun(ab: ArrayBuffer) { - const db = await getDb(ab) - const trainQuery = db.prepare('select label, image_bytes from train') const imageBytes: Uint8Array[] = [] const labels: number[] = [] const lengths: number[] = [] - while (trainQuery.step()) { - const row = trainQuery.getAsObject() - const label = row.label as number - labels.push(label) - const bytes = row.image_bytes as Uint8Array - imageBytes.push(bytes) - lengths.push(bytes.length) + const db = await getDb(ab) + try { + const trainQuery = db.prepare('select label, image_bytes from train') + while (trainQuery.step()) { + const row = trainQuery.getAsObject() + const label = row.label as number + labels.push(label) + const bytes = row.image_bytes as Uint8Array + imageBytes.push(bytes) + lengths.push(bytes.length) + } + trainQuery.free() + } finally { + db.close() } run(Uint8Array.from(labels), concat(imageBytes), Uint16Array.from(lengths)) } From 3d863e7a12a5f66e0e0e4649622d9cdb2cdcc2dc Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 14 Nov 2023 19:54:04 -0600 Subject: [PATCH 26/79] can create a DataLoader --- examples/train-web/train/Cargo.toml | 1 + examples/train-web/train/src/data.rs | 46 +++++++++++++ examples/train-web/train/src/lib.rs | 4 +- examples/train-web/train/src/mnist.rs | 93 +++++++++++++++++++++++++++ examples/train-web/train/src/train.rs | 18 +++--- 5 files changed, 153 insertions(+), 9 deletions(-) create mode 100644 examples/train-web/train/src/data.rs create mode 100644 examples/train-web/train/src/mnist.rs diff --git a/examples/train-web/train/Cargo.toml b/examples/train-web/train/Cargo.toml index 01b3f86604..c7b6a65b31 100644 --- a/examples/train-web/train/Cargo.toml +++ b/examples/train-web/train/Cargo.toml @@ -19,3 +19,4 @@ burn = { path = "../../../burn", default-features = false, features = [ "browser", ] } serde = { workspace = true } +image = { version = "0.24.7", features = ["png"] } diff --git a/examples/train-web/train/src/data.rs b/examples/train-web/train/src/data.rs new file mode 100644 index 0000000000..3613940c1e --- /dev/null +++ b/examples/train-web/train/src/data.rs @@ -0,0 +1,46 @@ +use crate::mnist::MNISTItem; +use burn::{ + data::dataloader::batcher::Batcher, + tensor::{backend::Backend, Data, ElementConversion, Int, Tensor}, +}; + +pub struct MNISTBatcher { + device: B::Device, +} + +impl MNISTBatcher { + pub fn new(device: B::Device) -> Self { + Self { device } + } +} + +#[derive(Clone, Debug)] +pub struct MNISTBatch { + pub images: Tensor, + pub targets: Tensor, +} + +impl Batcher> for MNISTBatcher { + fn batch(&self, items: Vec) -> MNISTBatch { + let images = items + .iter() + .map(|item| Data::::from(item.image)) + .map(|data| Tensor::::from_data(data.convert())) + .map(|tensor| tensor.reshape([1, 28, 28])) + // normalize: make between [0,1] and make the mean = 0 and std = 1 + // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example + // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 + .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) + .collect(); + + let targets = items + .iter() + .map(|item| Tensor::::from_data([(item.label as i64).elem()])) + .collect(); + + let images = Tensor::cat(images, 0).to_device(&self.device); + let targets = Tensor::cat(targets, 0).to_device(&self.device); + + MNISTBatch { images, targets } + } +} diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index b589c33185..84f93a18b6 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -4,6 +4,8 @@ use burn::{ }; use wasm_bindgen::prelude::*; +mod data; +mod mnist; mod train; #[wasm_bindgen(start)] @@ -20,7 +22,7 @@ pub fn init(worker_url: String) -> Result<(), String> { #[wasm_bindgen] pub fn run(labels: &[u8], images: &[u8], lengths: &[u16]) -> Result<(), String> { log::info!("Hello from Rust"); - train::run::>>(NdArrayDevice::Cpu); + train::run::>>(NdArrayDevice::Cpu, labels, images, lengths); Ok(()) } diff --git a/examples/train-web/train/src/mnist.rs b/examples/train-web/train/src/mnist.rs new file mode 100644 index 0000000000..41fb40b0e8 --- /dev/null +++ b/examples/train-web/train/src/mnist.rs @@ -0,0 +1,93 @@ +use burn::data::dataset::transform::{Mapper, MapperDataset}; +use burn::data::dataset::{Dataset, InMemDataset}; + +use image; +use serde::{Deserialize, Serialize}; + +const WIDTH: usize = 28; +const HEIGHT: usize = 28; + +/// MNIST item. +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct MNISTItem { + /// Image as a 2D array of floats. + pub image: [[f32; WIDTH]; HEIGHT], + + /// Label of the image. + pub label: u8, +} + +#[derive(Deserialize, Debug, Clone)] +struct MNISTItemRaw { + pub image_bytes: Vec, + pub label: u8, +} + +struct BytesToImage; + +impl Mapper for BytesToImage { + /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image). + fn map(&self, item: &MNISTItemRaw) -> MNISTItem { + let image = image::load_from_memory(&item.image_bytes).unwrap(); + let image = image.as_luma8().unwrap(); + + // Ensure the image dimensions are correct. + debug_assert_eq!(image.dimensions(), (WIDTH as u32, HEIGHT as u32)); + + // Convert the image to a 2D array of floats. + let mut image_array = [[0f32; WIDTH]; HEIGHT]; + for (i, pixel) in image.as_raw().iter().enumerate() { + let x = i % WIDTH; + let y = i / HEIGHT; + image_array[y][x] = *pixel as f32; + } + + MNISTItem { + image: image_array, + label: item.label, + } + } +} + +type MappedDataset = MapperDataset, BytesToImage, MNISTItemRaw>; + +/// MNIST dataset from Huggingface. +/// +/// The data is downloaded from Huggingface and stored in a SQLite database. +pub struct MNISTDataset { + dataset: MappedDataset, +} + +impl Dataset for MNISTDataset { + fn get(&self, index: usize) -> Option { + self.dataset.get(index) + } + + fn len(&self) -> usize { + self.dataset.len() + } +} + +impl MNISTDataset { + /// Creates a new dataset. + pub fn new(labels: &[u8], images: &[u8], lengths: &[u16]) -> Self { + debug_assert!(labels.len() == lengths.len()); + let mut start = 0 as usize; + let raws = labels + .iter() + .zip(lengths.iter()) + .map(|(label, length_ptr)| { + let length = *length_ptr as usize; + let end = start + length; + let raw = MNISTItemRaw { + label: *label, + image_bytes: images[start..end].to_vec(), + }; + start = end; + raw + }) + .collect(); + let dataset = MapperDataset::new(InMemDataset::new(raws), BytesToImage); + Self { dataset } + } +} diff --git a/examples/train-web/train/src/train.rs b/examples/train-web/train/src/train.rs index 5031daf01c..e31066fbae 100644 --- a/examples/train-web/train/src/train.rs +++ b/examples/train-web/train/src/train.rs @@ -1,5 +1,7 @@ +use crate::{data::MNISTBatcher, mnist::MNISTDataset}; use burn::{ config::Config, + data::dataloader::DataLoaderBuilder, module::Module, nn::{ conv::{Conv2d, Conv2dConfig}, @@ -80,7 +82,7 @@ impl ModelConfig { } } -pub fn run(device: B::Device) { +pub fn run(device: B::Device, labels: &[u8], images: &[u8], lengths: &[u16]) { // Create the configuration. let config_model = ModelConfig::new(10, 1024); let config_optimizer = AdamConfig::new(); @@ -93,16 +95,16 @@ pub fn run(device: B::Device) { let optim = config.optimizer.init(); // Create the batcher. - // let batcher_train = MNISTBatcher::::new(device.clone()); + let batcher_train = MNISTBatcher::::new(device.clone()); // let batcher_valid = MNISTBatcher::::new(device.clone()); // Create the dataloaders. - // let dataloader_train = DataLoaderBuilder::new(batcher_train) - // .batch_size(config.batch_size) - // .shuffle(config.seed) - // .num_workers(config.num_workers) - // .build(MNISTDataset::train()); - // + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::new(labels, images, lengths)); + // let dataloader_test = DataLoaderBuilder::new(batcher_valid) // .batch_size(config.batch_size) // .shuffle(config.seed) From 0b0fd86fe1d8c85c5abe2c86a6f775d1a8baf155 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 14 Nov 2023 20:39:12 -0600 Subject: [PATCH 27/79] can load train and test --- examples/train-web/train/src/lib.rs | 19 +++++++++++++-- examples/train-web/train/src/train.rs | 24 ++++++++++++------- examples/train-web/web/src/train.ts | 34 +++++++++++++++++++++------ 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index 84f93a18b6..5503821866 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -20,9 +20,24 @@ pub fn init(worker_url: String) -> Result<(), String> { } #[wasm_bindgen] -pub fn run(labels: &[u8], images: &[u8], lengths: &[u16]) -> Result<(), String> { +pub fn run( + train_labels: &[u8], + train_images: &[u8], + train_lengths: &[u16], + test_labels: &[u8], + test_images: &[u8], + test_lengths: &[u16], +) -> Result<(), String> { log::info!("Hello from Rust"); - train::run::>>(NdArrayDevice::Cpu, labels, images, lengths); + train::run::>>( + NdArrayDevice::Cpu, + train_labels, + train_images, + train_lengths, + test_labels, + test_images, + test_lengths, + ); Ok(()) } diff --git a/examples/train-web/train/src/train.rs b/examples/train-web/train/src/train.rs index e31066fbae..ace3aa2ab6 100644 --- a/examples/train-web/train/src/train.rs +++ b/examples/train-web/train/src/train.rs @@ -82,7 +82,15 @@ impl ModelConfig { } } -pub fn run(device: B::Device, labels: &[u8], images: &[u8], lengths: &[u16]) { +pub fn run( + device: B::Device, + train_labels: &[u8], + train_images: &[u8], + train_lengths: &[u16], + test_labels: &[u8], + test_images: &[u8], + test_lengths: &[u16], +) { // Create the configuration. let config_model = ModelConfig::new(10, 1024); let config_optimizer = AdamConfig::new(); @@ -96,20 +104,20 @@ pub fn run(device: B::Device, labels: &[u8], images: &[u8], // Create the batcher. let batcher_train = MNISTBatcher::::new(device.clone()); - // let batcher_valid = MNISTBatcher::::new(device.clone()); + let batcher_valid = MNISTBatcher::::new(device.clone()); // Create the dataloaders. let dataloader_train = DataLoaderBuilder::new(batcher_train) .batch_size(config.batch_size) .shuffle(config.seed) .num_workers(config.num_workers) - .build(MNISTDataset::new(labels, images, lengths)); + .build(MNISTDataset::new(train_labels, train_images, train_lengths)); - // let dataloader_test = DataLoaderBuilder::new(batcher_valid) - // .batch_size(config.batch_size) - // .shuffle(config.seed) - // .num_workers(config.num_workers) - // .build(MNISTDataset::test()); + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MNISTDataset::new(test_labels, test_images, test_lengths)); // artifact dir does not need to be provided when log_to_file is false let builder: LearnerBuilder< diff --git a/examples/train-web/web/src/train.ts b/examples/train-web/web/src/train.ts index 7cbecf207f..53e3109698 100644 --- a/examples/train-web/web/src/train.ts +++ b/examples/train-web/web/src/train.ts @@ -34,25 +34,45 @@ export function setupTrain(element: HTMLInputElement) { } async function loadSqliteAndRun(ab: ArrayBuffer) { - const imageBytes: Uint8Array[] = [] - const labels: number[] = [] - const lengths: number[] = [] + const trainImages: Uint8Array[] = [] + const trainLabels: number[] = [] + const trainLengths: number[] = [] + const testImages: Uint8Array[] = [] + const testLabels: number[] = [] + const testLengths: number[] = [] const db = await getDb(ab) try { const trainQuery = db.prepare('select label, image_bytes from train') while (trainQuery.step()) { const row = trainQuery.getAsObject() const label = row.label as number - labels.push(label) + trainLabels.push(label) const bytes = row.image_bytes as Uint8Array - imageBytes.push(bytes) - lengths.push(bytes.length) + trainImages.push(bytes) + trainLengths.push(bytes.length) } trainQuery.free() + const testQuery = db.prepare('select label, image_bytes from test') + while (testQuery.step()) { + const row = testQuery.getAsObject() + const label = row.label as number + testLabels.push(label) + const bytes = row.image_bytes as Uint8Array + testImages.push(bytes) + testLengths.push(bytes.length) + } + testQuery.free() } finally { db.close() } - run(Uint8Array.from(labels), concat(imageBytes), Uint16Array.from(lengths)) + run( + Uint8Array.from(trainLabels), + concat(trainImages), + Uint16Array.from(trainLengths), + Uint8Array.from(testLabels), + concat(testImages), + Uint16Array.from(testLengths), + ) } async function getDb(ab: ArrayBuffer): Promise { From 6331edeebd66c058f00d09613d987291895896cd Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 14 Nov 2023 20:49:19 -0600 Subject: [PATCH 28/79] copied over model --- examples/train-web/train/src/lib.rs | 1 + examples/train-web/train/src/model.rs | 83 +++++++++++++++++++++++++++ examples/train-web/train/src/train.rs | 45 +-------------- 3 files changed, 86 insertions(+), 43 deletions(-) create mode 100644 examples/train-web/train/src/model.rs diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index 5503821866..c2f905c61d 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -6,6 +6,7 @@ use wasm_bindgen::prelude::*; mod data; mod mnist; +mod model; mod train; #[wasm_bindgen(start)] diff --git a/examples/train-web/train/src/model.rs b/examples/train-web/train/src/model.rs new file mode 100644 index 0000000000..665612b50b --- /dev/null +++ b/examples/train-web/train/src/model.rs @@ -0,0 +1,83 @@ +use burn::{ + config::Config, + module::Module, + nn::{ + conv::{Conv2d, Conv2dConfig}, + pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, + Dropout, DropoutConfig, Linear, LinearConfig, ReLU, + }, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Model { + conv1: Conv2d, + conv2: Conv2d, + pool: AdaptiveAvgPool2d, + dropout: Dropout, + linear1: Linear, + linear2: Linear, + activation: ReLU, +} + +#[derive(Config, Debug)] +pub struct ModelConfig { + num_classes: usize, + hidden_size: usize, + #[config(default = "0.5")] + dropout: f64, +} + +impl ModelConfig { + /// Returns the initialized model. + pub fn init(&self) -> Model { + Model { + conv1: Conv2dConfig::new([1, 8], [3, 3]).init(), + conv2: Conv2dConfig::new([8, 16], [3, 3]).init(), + pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), + activation: ReLU::new(), + linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(), + linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(), + dropout: DropoutConfig::new(self.dropout).init(), + } + } + /// Returns the initialized model using the recorded weights. + pub fn init_with(&self, record: ModelRecord) -> Model { + Model { + conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1), + conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2), + pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), + activation: ReLU::new(), + linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1), + linear2: LinearConfig::new(self.hidden_size, self.num_classes) + .init_with(record.linear2), + dropout: DropoutConfig::new(self.dropout).init(), + } + } +} + +impl Model { + /// # Shapes + /// - Images [batch_size, height, width] + /// - Output [batch_size, class_prob] + pub fn forward(&self, images: Tensor) -> Tensor { + let [batch_size, height, width] = images.dims(); + + // Create a channel. + let x = images.reshape([batch_size, 1, height, width]); + + let x = self.conv1.forward(x); // [batch_size, 8, _, _] + let x = self.dropout.forward(x); + let x = self.conv2.forward(x); // [batch_size, 16, _, _] + let x = self.dropout.forward(x); + let x = self.activation.forward(x); + + let x = self.pool.forward(x); // [batch_size, 16, 8, 8] + let x = x.reshape([batch_size, 16 * 8 * 8]); + let x = self.linear1.forward(x); + let x = self.dropout.forward(x); + let x = self.activation.forward(x); + + self.linear2.forward(x) // [batch_size, num_classes] + } +} diff --git a/examples/train-web/train/src/train.rs b/examples/train-web/train/src/train.rs index ace3aa2ab6..10d0667b2c 100644 --- a/examples/train-web/train/src/train.rs +++ b/examples/train-web/train/src/train.rs @@ -1,29 +1,14 @@ -use crate::{data::MNISTBatcher, mnist::MNISTDataset}; +use crate::{data::MNISTBatcher, mnist::MNISTDataset, model::ModelConfig}; use burn::{ config::Config, data::dataloader::DataLoaderBuilder, - module::Module, - nn::{ - conv::{Conv2d, Conv2dConfig}, - pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, - Dropout, DropoutConfig, Linear, LinearConfig, ReLU, - }, optim::AdamConfig, - tensor::backend::{AutodiffBackend, Backend}, + tensor::backend::AutodiffBackend, train::{ renderer::{MetricState, MetricsRenderer, TrainingProgress}, ClassificationOutput, LearnerBuilder, }, }; - -#[derive(Config, Debug)] -pub struct ModelConfig { - num_classes: usize, - hidden_size: usize, - #[config(default = "0.5")] - dropout: f64, -} - #[derive(Config)] pub struct MnistTrainingConfig { #[config(default = 10)] @@ -56,32 +41,6 @@ impl MetricsRenderer for CustomRenderer { } } -#[derive(Module, Debug)] -pub struct Model { - conv1: Conv2d, - conv2: Conv2d, - pool: AdaptiveAvgPool2d, - dropout: Dropout, - linear1: Linear, - linear2: Linear, - activation: ReLU, -} - -impl ModelConfig { - /// Returns the initialized model. - pub fn init(&self) -> Model { - Model { - conv1: Conv2dConfig::new([1, 8], [3, 3]).init(), - conv2: Conv2dConfig::new([8, 16], [3, 3]).init(), - pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), - activation: ReLU::new(), - linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(), - linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(), - dropout: DropoutConfig::new(self.dropout).init(), - } - } -} - pub fn run( device: B::Device, train_labels: &[u8], From 469e270b19d39262c6660c855f1839a9721dc96c Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 15 Nov 2023 09:08:10 -0600 Subject: [PATCH 29/79] copy over guide's `training.rs`, more or less --- examples/train-web/train/src/lib.rs | 7 +- examples/train-web/train/src/model.rs | 13 ---- examples/train-web/train/src/train.rs | 104 ++++++++++++++++++-------- 3 files changed, 78 insertions(+), 46 deletions(-) diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index c2f905c61d..3c0698420e 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -1,5 +1,7 @@ +use crate::{model::ModelConfig, train::TrainingConfig}; use burn::{ backend::{ndarray::NdArrayDevice, Autodiff, NdArray}, + optim::AdamConfig, train::util, }; use wasm_bindgen::prelude::*; @@ -30,7 +32,10 @@ pub fn run( test_lengths: &[u16], ) -> Result<(), String> { log::info!("Hello from Rust"); - train::run::>>( + let config = TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()); + train::train::>>( + "", + config, NdArrayDevice::Cpu, train_labels, train_images, diff --git a/examples/train-web/train/src/model.rs b/examples/train-web/train/src/model.rs index 665612b50b..cce581157d 100644 --- a/examples/train-web/train/src/model.rs +++ b/examples/train-web/train/src/model.rs @@ -41,19 +41,6 @@ impl ModelConfig { dropout: DropoutConfig::new(self.dropout).init(), } } - /// Returns the initialized model using the recorded weights. - pub fn init_with(&self, record: ModelRecord) -> Model { - Model { - conv1: Conv2dConfig::new([1, 8], [3, 3]).init_with(record.conv1), - conv2: Conv2dConfig::new([8, 16], [3, 3]).init_with(record.conv2), - pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), - activation: ReLU::new(), - linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init_with(record.linear1), - linear2: LinearConfig::new(self.hidden_size, self.num_classes) - .init_with(record.linear2), - dropout: DropoutConfig::new(self.dropout).init(), - } - } } impl Model { diff --git a/examples/train-web/train/src/train.rs b/examples/train-web/train/src/train.rs index 10d0667b2c..b7328274c6 100644 --- a/examples/train-web/train/src/train.rs +++ b/examples/train-web/train/src/train.rs @@ -1,16 +1,58 @@ -use crate::{data::MNISTBatcher, mnist::MNISTDataset, model::ModelConfig}; +use crate::{ + data::{MNISTBatch, MNISTBatcher}, + mnist::MNISTDataset, + model::{Model, ModelConfig}, +}; use burn::{ + self, config::Config, data::dataloader::DataLoaderBuilder, + module::Module, + nn::loss::CrossEntropyLoss, optim::AdamConfig, - tensor::backend::AutodiffBackend, + record::CompactRecorder, + tensor::{ + backend::{AutodiffBackend, Backend}, + Int, Tensor, + }, train::{ + metric::{AccuracyMetric, LossMetric}, renderer::{MetricState, MetricsRenderer, TrainingProgress}, - ClassificationOutput, LearnerBuilder, + ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep, }, }; + +impl Model { + pub fn forward_classification( + &self, + images: Tensor, + targets: Tensor, + ) -> ClassificationOutput { + let output = self.forward(images); + let loss = CrossEntropyLoss::default().forward(output.clone(), targets.clone()); + + ClassificationOutput::new(loss, output, targets) + } +} + +impl TrainStep, ClassificationOutput> for Model { + fn step(&self, batch: MNISTBatch) -> TrainOutput> { + let item = self.forward_classification(batch.images, batch.targets); + + TrainOutput::new(self, item.loss.backward(), item) + } +} + +impl ValidStep, ClassificationOutput> for Model { + fn step(&self, batch: MNISTBatch) -> ClassificationOutput { + self.forward_classification(batch.images, batch.targets) + } +} + #[derive(Config)] -pub struct MnistTrainingConfig { +pub struct TrainingConfig { + pub model: ModelConfig, + pub optimizer: AdamConfig, #[config(default = 10)] pub num_epochs: usize, #[config(default = 64)] @@ -19,10 +61,8 @@ pub struct MnistTrainingConfig { pub num_workers: usize, #[config(default = 42)] pub seed: u64, - #[config(default = 1e-4)] - pub lr: f64, - pub model: ModelConfig, - pub optimizer: AdamConfig, + #[config(default = 1.0e-4)] + pub learning_rate: f64, } struct CustomRenderer {} @@ -41,7 +81,9 @@ impl MetricsRenderer for CustomRenderer { } } -pub fn run( +pub fn train( + artifact_dir: &str, + config: TrainingConfig, device: B::Device, train_labels: &[u8], train_images: &[u8], @@ -50,18 +92,13 @@ pub fn run( test_images: &[u8], test_lengths: &[u16], ) { - // Create the configuration. - let config_model = ModelConfig::new(10, 1024); - let config_optimizer = AdamConfig::new(); - let config = MnistTrainingConfig::new(config_model, config_optimizer); + // std::fs::create_dir_all(artifact_dir).ok(); + // config + // .save(format!("{artifact_dir}/config.json")) + // .expect("Save without error"); B::seed(config.seed); - // Create the model and optimizer. - let model = config.model.init(); - let optim = config.optimizer.init(); - - // Create the batcher. let batcher_train = MNISTBatcher::::new(device.clone()); let batcher_valid = MNISTBatcher::::new(device.clone()); @@ -78,23 +115,26 @@ pub fn run( .num_workers(config.num_workers) .build(MNISTDataset::new(test_labels, test_images, test_lengths)); - // artifact dir does not need to be provided when log_to_file is false - let builder: LearnerBuilder< - B, - ClassificationOutput, - ClassificationOutput<::InnerBackend>, - _, - _, - _, - > = LearnerBuilder::new("") + let learner = LearnerBuilder::new(artifact_dir) + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + // .with_file_checkpointer(CompactRecorder::new()) + .log_to_file(false) .devices(vec![device]) .num_epochs(config.num_epochs) .renderer(CustomRenderer {}) - .log_to_file(false); - // can be used to interrupt training - let _interrupter = builder.interrupter(); + .num_epochs(config.num_epochs) + .build( + config.model.init::(), + config.optimizer.init(), + config.learning_rate, + ); - let _learner = builder.build(model, optim, config.lr); + let model_trained = learner.fit(dataloader_train, dataloader_test); - // let _model_trained = learner.fit(dataloader_train, dataloader_test); + model_trained + .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) + .expect("Failed to save trained model"); } From 939f3d0d12372a64b5b8ba952d037deb5276f3e9 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 15 Nov 2023 11:49:34 -0600 Subject: [PATCH 30/79] docs --- examples/train-web/readme.md | 4 +++- examples/train-web/train/src/mnist.rs | 2 ++ examples/train-web/web/src/train.ts | 6 ++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/train-web/readme.md b/examples/train-web/readme.md index 736b912544..9244941dff 100644 --- a/examples/train-web/readme.md +++ b/examples/train-web/readme.md @@ -6,7 +6,9 @@ For example, run `cargo install --version 0.2.88 wasm-bindgen-cli --force`. The Install [PNPM](https://pnpm.io/). -The [`postinstall.sh`](./web/postinstall.sh) script expects the mnist database to be at `~/.cache/burn-dataset/mnist.db`. Running `guide` will generate this file. Alternatively, you can download it from [Hugging Face](https://huggingface.co/datasets/mnist). +Install [cargo-watch](https://crates.io/crates/cargo-watch). + +The [`postinstall.sh`](./web/postinstall.sh) script expects the mnist database to be at `~/.cache/burn-dataset/mnist.db`. Running `burn/examples/guide` will generate this file. Alternatively, you can download it from [Hugging Face](https://huggingface.co/datasets/mnist). Then in separate terminals: diff --git a/examples/train-web/train/src/mnist.rs b/examples/train-web/train/src/mnist.rs index 41fb40b0e8..b6f74ebf93 100644 --- a/examples/train-web/train/src/mnist.rs +++ b/examples/train-web/train/src/mnist.rs @@ -71,6 +71,8 @@ impl Dataset for MNISTDataset { impl MNISTDataset { /// Creates a new dataset. pub fn new(labels: &[u8], images: &[u8], lengths: &[u16]) -> Self { + // Decoding is here. + // Encoding is done at `examples/train-web/web/src/train.ts`. debug_assert!(labels.len() == lengths.len()); let mut start = 0 as usize; let raws = labels diff --git a/examples/train-web/web/src/train.ts b/examples/train-web/web/src/train.ts index 53e3109698..aec82c9d5c 100644 --- a/examples/train-web/web/src/train.ts +++ b/examples/train-web/web/src/train.ts @@ -34,6 +34,12 @@ export function setupTrain(element: HTMLInputElement) { } async function loadSqliteAndRun(ab: ArrayBuffer) { + // Images are an array of arrays. + // We can't send an array of arrays to Wasm. + // So instead we merge images into a single large array and + // use another array, `lengths`, to keep track of the image size. + // Encoding is done here. + // Decoding is done at `burn/examples/train-web/train/src/mnist.rs`. const trainImages: Uint8Array[] = [] const trainLabels: number[] = [] const trainLengths: number[] = [] From 437cf004d265ab3be2e84ea6cd8ff60cac451779 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 15 Nov 2023 15:40:34 -0600 Subject: [PATCH 31/79] assert is valid png --- examples/train-web/train/src/mnist.rs | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/examples/train-web/train/src/mnist.rs b/examples/train-web/train/src/mnist.rs index b6f74ebf93..a86c99f259 100644 --- a/examples/train-web/train/src/mnist.rs +++ b/examples/train-web/train/src/mnist.rs @@ -73,7 +73,12 @@ impl MNISTDataset { pub fn new(labels: &[u8], images: &[u8], lengths: &[u16]) -> Self { // Decoding is here. // Encoding is done at `examples/train-web/web/src/train.ts`. - debug_assert!(labels.len() == lengths.len()); + assert!( + labels.len() == lengths.len(), + "`labels` has {} elements and `lengths` has {} elements, but they should be equal in size.", + labels.len(), + lengths.len(), + ); let mut start = 0 as usize; let raws = labels .iter() @@ -81,9 +86,20 @@ impl MNISTDataset { .map(|(label, length_ptr)| { let length = *length_ptr as usize; let end = start + length; + let image_bytes = images[start..end].to_vec(); + // Assert that the incoming data is a valid PNG. + // Really just here to ensure the decoding process isn't off by one. + debug_assert_eq!( + image_bytes[0..8], + vec![137, 80, 78, 71, 13, 10, 26, 10] // Starts with bytes from the spec http://www.libpng.org/pub/png/spec/1.2/PNG-Structure.html + ); + debug_assert_eq!( + image_bytes[image_bytes.len() - 8..], + vec![73, 69, 78, 68, 174, 66, 96, 130] // Ends with the IEND chunk. I don't think the spec specifies the last 4 bytes, but `mnist.db`'s images all end with it. + ); let raw = MNISTItemRaw { label: *label, - image_bytes: images[start..end].to_vec(), + image_bytes, }; start = end; raw From e3c84a678624c1088b7ee8c4df1d26988c24f627 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 22 Nov 2023 08:20:53 -0600 Subject: [PATCH 32/79] add pool from https://github.com/rustwasm/wasm-bindgen/blob/main/examples/raytrace-parallel/src/pool.rs --- burn-train/src/pool.rs | 219 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 burn-train/src/pool.rs diff --git a/burn-train/src/pool.rs b/burn-train/src/pool.rs new file mode 100644 index 0000000000..37ef4518c0 --- /dev/null +++ b/burn-train/src/pool.rs @@ -0,0 +1,219 @@ +// Silences warnings from the compiler about Work.func and child_entry_point +// being unused when the target is not wasm. +#![cfg_attr(not(target_arch = "wasm32"), allow(dead_code))] + +//! A small module that's intended to provide an example of creating a pool of +//! web workers which can be used to execute `rayon`-style work. + +use std::cell::RefCell; +use std::rc::Rc; +use wasm_bindgen::prelude::*; +use web_sys::{DedicatedWorkerGlobalScope, MessageEvent}; +use web_sys::{ErrorEvent, Event, Worker}; + +#[wasm_bindgen] +pub struct WorkerPool { + state: Rc, +} + +struct PoolState { + workers: RefCell>, + callback: Closure, +} + +struct Work { + func: Box, +} + +#[wasm_bindgen] +impl WorkerPool { + /// Creates a new `WorkerPool` which immediately creates `initial` workers. + /// + /// The pool created here can be used over a long period of time, and it + /// will be initially primed with `initial` workers. Currently workers are + /// never released or gc'd until the whole pool is destroyed. + /// + /// # Errors + /// + /// Returns any error that may happen while a JS web worker is created and a + /// message is sent to it. + #[wasm_bindgen(constructor)] + pub fn new(initial: usize) -> Result { + let pool = WorkerPool { + state: Rc::new(PoolState { + workers: RefCell::new(Vec::with_capacity(initial)), + callback: Closure::new(|event: Event| { + console_log!("unhandled event: {}", event.type_()); + crate::logv(&event); + }), + }), + }; + for _ in 0..initial { + let worker = pool.spawn()?; + pool.state.push(worker); + } + + Ok(pool) + } + + /// Unconditionally spawns a new worker + /// + /// The worker isn't registered with this `WorkerPool` but is capable of + /// executing work for this wasm module. + /// + /// # Errors + /// + /// Returns any error that may happen while a JS web worker is created and a + /// message is sent to it. + fn spawn(&self) -> Result { + console_log!("spawning new worker"); + // TODO: what do do about `./worker.js`: + // + // * the path is only known by the bundler. How can we, as a + // library, know what's going on? + // * How do we not fetch a script N times? It internally then + // causes another script to get fetched N times... + let worker = Worker::new("./worker.js")?; + + // With a worker spun up send it the module/memory so it can start + // instantiating the wasm module. Later it might receive further + // messages about code to run on the wasm module. + let array = js_sys::Array::new(); + array.push(&wasm_bindgen::module()); + array.push(&wasm_bindgen::memory()); + worker.post_message(&array)?; + + Ok(worker) + } + + /// Fetches a worker from this pool, spawning one if necessary. + /// + /// This will attempt to pull an already-spawned web worker from our cache + /// if one is available, otherwise it will spawn a new worker and return the + /// newly spawned worker. + /// + /// # Errors + /// + /// Returns any error that may happen while a JS web worker is created and a + /// message is sent to it. + fn worker(&self) -> Result { + match self.state.workers.borrow_mut().pop() { + Some(worker) => Ok(worker), + None => self.spawn(), + } + } + + /// Executes the work `f` in a web worker, spawning a web worker if + /// necessary. + /// + /// This will acquire a web worker and then send the closure `f` to the + /// worker to execute. The worker won't be usable for anything else while + /// `f` is executing, and no callbacks are registered for when the worker + /// finishes. + /// + /// # Errors + /// + /// Returns any error that may happen while a JS web worker is created and a + /// message is sent to it. + fn execute(&self, f: impl FnOnce() + Send + 'static) -> Result { + let worker = self.worker()?; + let work = Box::new(Work { func: Box::new(f) }); + let ptr = Box::into_raw(work); + match worker.post_message(&JsValue::from(ptr as u32)) { + Ok(()) => Ok(worker), + Err(e) => { + unsafe { + drop(Box::from_raw(ptr)); + } + Err(e) + } + } + } + + /// Configures an `onmessage` callback for the `worker` specified for the + /// web worker to be reclaimed and re-inserted into this pool when a message + /// is received. + /// + /// Currently this `WorkerPool` abstraction is intended to execute one-off + /// style work where the work itself doesn't send any notifications and + /// whatn it's done the worker is ready to execute more work. This method is + /// used for all spawned workers to ensure that when the work is finished + /// the worker is reclaimed back into this pool. + fn reclaim_on_message(&self, worker: Worker) { + let state = Rc::downgrade(&self.state); + let worker2 = worker.clone(); + let reclaim_slot = Rc::new(RefCell::new(None)); + let slot2 = reclaim_slot.clone(); + let reclaim = Closure::::new(move |event: Event| { + if let Some(error) = event.dyn_ref::() { + console_log!("error in worker: {}", error.message()); + // TODO: this probably leaks memory somehow? It's sort of + // unclear what to do about errors in workers right now. + return; + } + + // If this is a completion event then can deallocate our own + // callback by clearing out `slot2` which contains our own closure. + if let Some(_msg) = event.dyn_ref::() { + if let Some(state) = state.upgrade() { + state.push(worker2.clone()); + } + *slot2.borrow_mut() = None; + return; + } + + console_log!("unhandled event: {}", event.type_()); + crate::logv(&event); + // TODO: like above, maybe a memory leak here? + }); + worker.set_onmessage(Some(reclaim.as_ref().unchecked_ref())); + *reclaim_slot.borrow_mut() = Some(reclaim); + } +} + +impl WorkerPool { + /// Executes `f` in a web worker. + /// + /// This pool manages a set of web workers to draw from, and `f` will be + /// spawned quickly into one if the worker is idle. If no idle workers are + /// available then a new web worker will be spawned. + /// + /// Once `f` returns the worker assigned to `f` is automatically reclaimed + /// by this `WorkerPool`. This method provides no method of learning when + /// `f` completes, and for that you'll need to use `run_notify`. + /// + /// # Errors + /// + /// If an error happens while spawning a web worker or sending a message to + /// a web worker, that error is returned. + pub fn run(&self, f: impl FnOnce() + Send + 'static) -> Result<(), JsValue> { + let worker = self.execute(f)?; + self.reclaim_on_message(worker); + Ok(()) + } +} + +impl PoolState { + fn push(&self, worker: Worker) { + worker.set_onmessage(Some(self.callback.as_ref().unchecked_ref())); + worker.set_onerror(Some(self.callback.as_ref().unchecked_ref())); + let mut workers = self.workers.borrow_mut(); + for prev in workers.iter() { + let prev: &JsValue = prev; + let worker: &JsValue = &worker; + assert!(prev != worker); + } + workers.push(worker); + } +} + +/// Entry point invoked by `worker.js`, a bit of a hack but see the "TODO" above +/// about `worker.js` in general. +#[wasm_bindgen] +pub fn child_entry_point(ptr: u32) -> Result<(), JsValue> { + let ptr = unsafe { Box::from_raw(ptr as *mut Work) }; + let global = js_sys::global().unchecked_into::(); + (ptr.func)(); + global.post_message(&JsValue::undefined())?; + Ok(()) +} From 9a355f28843eb7d0b99300bcd65291e81a7bb828 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 24 Nov 2023 08:05:06 -0600 Subject: [PATCH 33/79] pool works --- burn-train/Cargo.toml | 3 + burn-train/src/lib.rs | 4 + burn-train/src/pool.rs | 167 +++++++++++++-------------- burn-train/src/util.rs | 56 +++------ examples/train-web/train/src/lib.rs | 10 +- examples/train-web/web/src/train.ts | 11 +- examples/train-web/web/src/worker.ts | 10 +- 7 files changed, 120 insertions(+), 141 deletions(-) diff --git a/burn-train/Cargo.toml b/burn-train/Cargo.toml index 2d0e00612e..e7517ba103 100644 --- a/burn-train/Cargo.toml +++ b/burn-train/Cargo.toml @@ -19,6 +19,7 @@ browser = ["js-sys", "web-sys", "wasm-bindgen"] [dependencies] burn-core = { path = "../burn-core", version = "0.11.0" } +lazy_static = "1.4.0" log = { workspace = true } tracing-subscriber.workspace = true tracing-appender.workspace = true @@ -42,6 +43,8 @@ web-sys = { version = "0.3.65", optional = true, features = [ "Worker", "WorkerOptions", "WorkerType", + "MessageEvent", + "ErrorEvent", ] } wasm-bindgen = { workspace = true, optional = true } diff --git a/burn-train/src/lib.rs b/burn-train/src/lib.rs index f4bc430ee3..50ec3a3316 100644 --- a/burn-train/src/lib.rs +++ b/burn-train/src/lib.rs @@ -2,6 +2,7 @@ //! A library for training neural networks using the burn crate. +pub mod pool; pub mod util; #[macro_use] @@ -27,3 +28,6 @@ pub use learner::*; #[cfg(test)] pub(crate) type TestBackend = burn_ndarray::NdArray; + +#[macro_use] +extern crate lazy_static; diff --git a/burn-train/src/pool.rs b/burn-train/src/pool.rs index 37ef4518c0..d18cdc4392 100644 --- a/burn-train/src/pool.rs +++ b/burn-train/src/pool.rs @@ -1,19 +1,31 @@ -// Silences warnings from the compiler about Work.func and child_entry_point -// being unused when the target is not wasm. -#![cfg_attr(not(target_arch = "wasm32"), allow(dead_code))] - //! A small module that's intended to provide an example of creating a pool of //! web workers which can be used to execute `rayon`-style work. -use std::cell::RefCell; -use std::rc::Rc; +use crate::util::WORKER_URL; +use log::info; +use std::borrow::BorrowMut; +use std::{cell::RefCell, sync::Arc}; use wasm_bindgen::prelude::*; -use web_sys::{DedicatedWorkerGlobalScope, MessageEvent}; -use web_sys::{ErrorEvent, Event, Worker}; +use web_sys::{ErrorEvent, Event, MessageEvent, Worker}; + +// "This is only safe because wasm is currently single-threaded." https://github.com/rustwasm/wasm-bindgen/issues/1505#issuecomment-489300331 +unsafe impl Send for PoolState {} +unsafe impl Sync for PoolState {} + +lazy_static! { + pub static ref WORKER_POOL: WorkerPool = WorkerPool { + state: Arc::new(PoolState { + workers: RefCell::new(vec!()), + callback: Closure::new(|event: Event| { + info!("unhandled event: {:?}", &event); + }) + }) + }; +} #[wasm_bindgen] pub struct WorkerPool { - state: Rc, + state: Arc, } struct PoolState { @@ -21,41 +33,33 @@ struct PoolState { callback: Closure, } -struct Work { - func: Box, +/// Creates a new `WorkerPool` which immediately creates `initial` workers. +/// +/// The pool created here can be used over a long period of time, and it +/// will be initially primed with `initial` workers. Currently workers are +/// never released or gc'd until the whole pool is destroyed. +/// +/// # Errors +/// +/// Returns any error that may happen while a JS web worker is created and a +/// message is sent to it. +pub fn init(initial: usize) -> Result<(), JsValue> { + let pool = WorkerPool { + state: Arc::new(PoolState { + workers: RefCell::new(Vec::with_capacity(initial)), + callback: Closure::new(|event: Event| { + info!("unhandled event: {:?}", &event); + }), + }), + }; + + let workers: Result<_, _> = [0..initial].map(|_| pool.spawn()).into_iter().collect(); + WORKER_POOL.state.workers.replace(workers?); + Ok(()) } #[wasm_bindgen] impl WorkerPool { - /// Creates a new `WorkerPool` which immediately creates `initial` workers. - /// - /// The pool created here can be used over a long period of time, and it - /// will be initially primed with `initial` workers. Currently workers are - /// never released or gc'd until the whole pool is destroyed. - /// - /// # Errors - /// - /// Returns any error that may happen while a JS web worker is created and a - /// message is sent to it. - #[wasm_bindgen(constructor)] - pub fn new(initial: usize) -> Result { - let pool = WorkerPool { - state: Rc::new(PoolState { - workers: RefCell::new(Vec::with_capacity(initial)), - callback: Closure::new(|event: Event| { - console_log!("unhandled event: {}", event.type_()); - crate::logv(&event); - }), - }), - }; - for _ in 0..initial { - let worker = pool.spawn()?; - pool.state.push(worker); - } - - Ok(pool) - } - /// Unconditionally spawns a new worker /// /// The worker isn't registered with this `WorkerPool` but is capable of @@ -66,24 +70,14 @@ impl WorkerPool { /// Returns any error that may happen while a JS web worker is created and a /// message is sent to it. fn spawn(&self) -> Result { - console_log!("spawning new worker"); - // TODO: what do do about `./worker.js`: - // - // * the path is only known by the bundler. How can we, as a - // library, know what's going on? - // * How do we not fetch a script N times? It internally then - // causes another script to get fetched N times... - let worker = Worker::new("./worker.js")?; - - // With a worker spun up send it the module/memory so it can start - // instantiating the wasm module. Later it might receive further - // messages about code to run on the wasm module. - let array = js_sys::Array::new(); - array.push(&wasm_bindgen::module()); - array.push(&wasm_bindgen::memory()); - worker.post_message(&array)?; - - Ok(worker) + let mut worker_options = web_sys::WorkerOptions::new(); + worker_options.type_(web_sys::WorkerType::Module); + web_sys::Worker::new_with_options( + WORKER_URL + .get() + .expect("You must first call `init` with the worker's url."), + &worker_options, + ) } /// Fetches a worker from this pool, spawning one if necessary. @@ -117,16 +111,26 @@ impl WorkerPool { /// message is sent to it. fn execute(&self, f: impl FnOnce() + Send + 'static) -> Result { let worker = self.worker()?; - let work = Box::new(Work { func: Box::new(f) }); - let ptr = Box::into_raw(work); - match worker.post_message(&JsValue::from(ptr as u32)) { - Ok(()) => Ok(worker), - Err(e) => { - unsafe { - drop(Box::from_raw(ptr)); - } - Err(e) - } + + // Double-boxing because `dyn FnOnce` is unsized and so `Box` has + // an undefined layout (although I think in practice its a pointer and a length?). + let ptr = Box::into_raw(Box::new(Box::new(f) as Box)); + + // See `worker.ts` for the format of this message. + let msg: js_sys::Array = [ + &wasm_bindgen::module(), + &wasm_bindgen::memory(), + &JsValue::from(ptr as u32), + ] + .into_iter() + .collect(); + if let Err(e) = worker.post_message(&msg) { + // We expect the worker to deallocate the box, but if there was an error then + // we'll do it ourselves. + let _ = unsafe { Box::from_raw(ptr) }; + Err(format!("Error initializing worker during post_message: {:?}", e).into()) + } else { + Ok(worker) } } @@ -140,13 +144,13 @@ impl WorkerPool { /// used for all spawned workers to ensure that when the work is finished /// the worker is reclaimed back into this pool. fn reclaim_on_message(&self, worker: Worker) { - let state = Rc::downgrade(&self.state); + let state = Arc::downgrade(&self.state); let worker2 = worker.clone(); - let reclaim_slot = Rc::new(RefCell::new(None)); - let slot2 = reclaim_slot.clone(); + let mut reclaim_slot = Arc::new(RefCell::new(None)); + let mut slot2 = reclaim_slot.clone(); let reclaim = Closure::::new(move |event: Event| { if let Some(error) = event.dyn_ref::() { - console_log!("error in worker: {}", error.message()); + info!("error in worker: {}", error.message()); // TODO: this probably leaks memory somehow? It's sort of // unclear what to do about errors in workers right now. return; @@ -158,16 +162,15 @@ impl WorkerPool { if let Some(state) = state.upgrade() { state.push(worker2.clone()); } - *slot2.borrow_mut() = None; + *slot2.borrow_mut() = Arc::new(RefCell::new(None)); return; } - console_log!("unhandled event: {}", event.type_()); - crate::logv(&event); + info!("unhandled event: {:?}", &event); // TODO: like above, maybe a memory leak here? }); worker.set_onmessage(Some(reclaim.as_ref().unchecked_ref())); - *reclaim_slot.borrow_mut() = Some(reclaim); + *reclaim_slot.borrow_mut() = Arc::new(RefCell::new(Some(reclaim))); } } @@ -207,13 +210,9 @@ impl PoolState { } } -/// Entry point invoked by `worker.js`, a bit of a hack but see the "TODO" above -/// about `worker.js` in general. +/// Entry point invoked by `worker.js` #[wasm_bindgen] -pub fn child_entry_point(ptr: u32) -> Result<(), JsValue> { - let ptr = unsafe { Box::from_raw(ptr as *mut Work) }; - let global = js_sys::global().unchecked_into::(); - (ptr.func)(); - global.post_message(&JsValue::undefined())?; - Ok(()) +pub fn child_entry_point(ptr: u32) { + let work = unsafe { Box::from_raw(ptr as *mut Box) }; + (*work)(); } diff --git a/burn-train/src/util.rs b/burn-train/src/util.rs index 3d3d44ce82..bb86945820 100644 --- a/burn-train/src/util.rs +++ b/burn-train/src/util.rs @@ -1,5 +1,7 @@ #![allow(missing_docs)] +#[cfg(feature = "browser")] +use crate::pool::WORKER_POOL; #[cfg(feature = "browser")] use wasm_bindgen::prelude::*; @@ -21,49 +23,23 @@ where F: FnOnce(), F: Send + 'static, { - let mut worker_options = web_sys::WorkerOptions::new(); - worker_options.type_(web_sys::WorkerType::Module); - // Double-boxing because `dyn FnOnce` is unsized and so `Box` has - let w = web_sys::Worker::new_with_options( - WORKER_URL - .get() - .expect("You must first call `init` with the worker's url."), - &worker_options, - ) - .unwrap_or_else(|_| panic!("Error initializing worker at {:?}", WORKER_URL)); - // an undefined layout (although I think in practice its a pointer and a length?). - let ptr = Box::into_raw(Box::new(Box::new(f) as Box)); - - // See `worker.js` for the format of this message. - let msg: js_sys::Array = [ - &wasm_bindgen::module(), - &wasm_bindgen::memory(), - &JsValue::from(ptr as u32), - ] - .into_iter() - .collect(); - if let Err(e) = w.post_message(&msg) { - // We expect the worker to deallocate the box, but if there was an error then - // we'll do it ourselves. - let _ = unsafe { Box::from_raw(ptr) }; - panic!("Error initializing worker during post_message: {:?}", e) - } else { - Box::new(move || { - w.terminate(); - Ok(()) - }) - } + WORKER_POOL.run(f).unwrap(); + Box::new(|| Ok(())) } #[cfg(feature = "browser")] -static WORKER_URL: std::sync::OnceLock = std::sync::OnceLock::new(); +pub static WORKER_URL: std::sync::OnceLock = std::sync::OnceLock::new(); #[cfg(feature = "browser")] -pub fn init(worker_url: String) -> Result<(), String> { - WORKER_URL.set(worker_url).map_err(|worker_url| { - format!( - "You can only call `init` once. You tried to `init` with: {:?}", - worker_url - ) - }) +pub fn init(worker_url: String, worker_count: usize) -> Result<(), JsValue> { + WORKER_URL + .set(worker_url) + .map_err(|worker_url| -> JsValue { + format!( + "You can only call `init` once. You tried to `init` with: {:?}", + worker_url + ) + .into() + })?; + crate::pool::init(worker_count) } diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index 3c0698420e..0c043da09e 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -18,8 +18,8 @@ pub fn start() { } #[wasm_bindgen] -pub fn init(worker_url: String) -> Result<(), String> { - util::init(worker_url) +pub fn init(worker_url: String, worker_count: usize) -> Result<(), JsValue> { + util::init(worker_url, worker_count) } #[wasm_bindgen] @@ -46,9 +46,3 @@ pub fn run( ); Ok(()) } - -#[wasm_bindgen] -pub fn child_entry_point(ptr: u32) { - let work = unsafe { Box::from_raw(ptr as *mut Box) }; - (*work)(); -} diff --git a/examples/train-web/web/src/train.ts b/examples/train-web/web/src/train.ts index aec82c9d5c..ba2f8b4bda 100644 --- a/examples/train-web/web/src/train.ts +++ b/examples/train-web/web/src/train.ts @@ -6,10 +6,6 @@ import sqliteWasmUrl from './assets/sql-wasm.wasm?url' // @ts-ignore https://github.com/rustwasm/console_error_panic_hook#errorstacktracelimit Error.stackTraceLimit = 30 -wasm() - .then(() => init(workerUrl)) - .catch(console.error) - const sqlJs = initSqlJs({ locateFile: () => sqliteWasmUrl, }) @@ -34,6 +30,9 @@ export function setupTrain(element: HTMLInputElement) { } async function loadSqliteAndRun(ab: ArrayBuffer) { + await wasm() + await init(workerUrl, navigator.hardwareConcurrency) + await sleep(1000) // the workers need time to spin up. TODO, post an init message and await a response. Also maybe move worker construction to Javascript. // Images are an array of arrays. // We can't send an array of arrays to Wasm. // So instead we merge images into a single large array and @@ -97,3 +96,7 @@ function concat(arrays: Uint8Array[]) { } return result } + +async function sleep(ms: number): Promise { + return await new Promise((resolve) => setTimeout(resolve, ms)) +} diff --git a/examples/train-web/web/src/worker.ts b/examples/train-web/web/src/worker.ts index fbc481f6d0..15465f22e9 100644 --- a/examples/train-web/web/src/worker.ts +++ b/examples/train-web/web/src/worker.ts @@ -32,9 +32,9 @@ self.onmessage = async (event) => { // usual native way (where you spin one up to do some work until it finisheds) then // you'll want to clean up the thread's resources. - // @ts-ignore -- can take 0 args, the typescript is incorrect https://github.com/rustwasm/wasm-bindgen/blob/962f580d6f5def2df4223588de96feeb30ce7572/crates/threads-xform/src/lib.rs#L394-L395 - // Free memory (stack, thread-locals) held (in the wasm linear memory) by the thread. - init.__wbindgen_thread_destroy() - // Tell the browser to stop the thread. - close() + // // @ts-ignore -- can take 0 args, the typescript is incorrect https://github.com/rustwasm/wasm-bindgen/blob/962f580d6f5def2df4223588de96feeb30ce7572/crates/threads-xform/src/lib.rs#L394-L395 + // // Free memory (stack, thread-locals) held (in the wasm linear memory) by the thread. + // init.__wbindgen_thread_destroy() + // // Tell the browser to stop the thread. + // close() } From 70e1360b00d989918ec55b8cc08c212d457000a1 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 29 Nov 2023 15:55:33 -0600 Subject: [PATCH 34/79] fix conflict --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index bb8620ecc1..1aa4993047 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,7 @@ thiserror = "1.0.50" tracing-appender = "0.2.3" tracing-core = "0.1.32" tracing-subscriber = "0.3.18" -wasm-bindgen = "0.2.88" +wasm-bindgen = "=0.2.88" wasm-bindgen-futures = "0.4.38" wasm-logger = "0.2.0" wasm-timer = "0.2.5" From a1f582869a99dd023eaef9377c5182f89c47f639 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 30 Nov 2023 17:08:02 -0600 Subject: [PATCH 35/79] it builds --- Cargo.toml | 2 +- burn-train/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1aa4993047..88711e80ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,7 @@ thiserror = "1.0.50" tracing-appender = "0.2.3" tracing-core = "0.1.32" tracing-subscriber = "0.3.18" -wasm-bindgen = "=0.2.88" +wasm-bindgen = "=0.2.89" wasm-bindgen-futures = "0.4.38" wasm-logger = "0.2.0" wasm-timer = "0.2.5" diff --git a/burn-train/Cargo.toml b/burn-train/Cargo.toml index e7517ba103..5df5e84067 100644 --- a/burn-train/Cargo.toml +++ b/burn-train/Cargo.toml @@ -39,7 +39,7 @@ derive-new = { workspace = true } serde = { workspace = true, features = ["std", "derive"] } js-sys = { version = "0.3.64", optional = true } -web-sys = { version = "0.3.65", optional = true, features = [ +web-sys = { version = "0.3.64", optional = true, features = [ "Worker", "WorkerOptions", "WorkerType", From 9f2285e1d846c7ace352e3224552618e7eb87564 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 1 Dec 2023 11:10:33 -0600 Subject: [PATCH 36/79] add rayon and wasm-bindgen-rayon --- burn-train/Cargo.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/burn-train/Cargo.toml b/burn-train/Cargo.toml index 5df5e84067..84a510211a 100644 --- a/burn-train/Cargo.toml +++ b/burn-train/Cargo.toml @@ -47,6 +47,8 @@ web-sys = { version = "0.3.64", optional = true, features = [ "ErrorEvent", ] } wasm-bindgen = { workspace = true, optional = true } +wasm-bindgen-rayon = "1.0.3" +rayon = { workspace = true } [dev-dependencies] burn-ndarray = { path = "../burn-ndarray", version = "0.11.0" } From 6d926e806d7a02b103677f09988b5225fe908977 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 1 Dec 2023 11:37:30 -0600 Subject: [PATCH 37/79] replace with rayon --- burn-train/src/lib.rs | 1 - burn-train/src/pool.rs | 218 --------------------------- burn-train/src/util.rs | 25 +-- examples/train-web/train/src/lib.rs | 5 - examples/train-web/web/src/train.ts | 5 +- examples/train-web/web/src/worker.ts | 40 ----- 6 files changed, 4 insertions(+), 290 deletions(-) delete mode 100644 burn-train/src/pool.rs delete mode 100644 examples/train-web/web/src/worker.ts diff --git a/burn-train/src/lib.rs b/burn-train/src/lib.rs index 50ec3a3316..fb4dc9b3a0 100644 --- a/burn-train/src/lib.rs +++ b/burn-train/src/lib.rs @@ -2,7 +2,6 @@ //! A library for training neural networks using the burn crate. -pub mod pool; pub mod util; #[macro_use] diff --git a/burn-train/src/pool.rs b/burn-train/src/pool.rs deleted file mode 100644 index d18cdc4392..0000000000 --- a/burn-train/src/pool.rs +++ /dev/null @@ -1,218 +0,0 @@ -//! A small module that's intended to provide an example of creating a pool of -//! web workers which can be used to execute `rayon`-style work. - -use crate::util::WORKER_URL; -use log::info; -use std::borrow::BorrowMut; -use std::{cell::RefCell, sync::Arc}; -use wasm_bindgen::prelude::*; -use web_sys::{ErrorEvent, Event, MessageEvent, Worker}; - -// "This is only safe because wasm is currently single-threaded." https://github.com/rustwasm/wasm-bindgen/issues/1505#issuecomment-489300331 -unsafe impl Send for PoolState {} -unsafe impl Sync for PoolState {} - -lazy_static! { - pub static ref WORKER_POOL: WorkerPool = WorkerPool { - state: Arc::new(PoolState { - workers: RefCell::new(vec!()), - callback: Closure::new(|event: Event| { - info!("unhandled event: {:?}", &event); - }) - }) - }; -} - -#[wasm_bindgen] -pub struct WorkerPool { - state: Arc, -} - -struct PoolState { - workers: RefCell>, - callback: Closure, -} - -/// Creates a new `WorkerPool` which immediately creates `initial` workers. -/// -/// The pool created here can be used over a long period of time, and it -/// will be initially primed with `initial` workers. Currently workers are -/// never released or gc'd until the whole pool is destroyed. -/// -/// # Errors -/// -/// Returns any error that may happen while a JS web worker is created and a -/// message is sent to it. -pub fn init(initial: usize) -> Result<(), JsValue> { - let pool = WorkerPool { - state: Arc::new(PoolState { - workers: RefCell::new(Vec::with_capacity(initial)), - callback: Closure::new(|event: Event| { - info!("unhandled event: {:?}", &event); - }), - }), - }; - - let workers: Result<_, _> = [0..initial].map(|_| pool.spawn()).into_iter().collect(); - WORKER_POOL.state.workers.replace(workers?); - Ok(()) -} - -#[wasm_bindgen] -impl WorkerPool { - /// Unconditionally spawns a new worker - /// - /// The worker isn't registered with this `WorkerPool` but is capable of - /// executing work for this wasm module. - /// - /// # Errors - /// - /// Returns any error that may happen while a JS web worker is created and a - /// message is sent to it. - fn spawn(&self) -> Result { - let mut worker_options = web_sys::WorkerOptions::new(); - worker_options.type_(web_sys::WorkerType::Module); - web_sys::Worker::new_with_options( - WORKER_URL - .get() - .expect("You must first call `init` with the worker's url."), - &worker_options, - ) - } - - /// Fetches a worker from this pool, spawning one if necessary. - /// - /// This will attempt to pull an already-spawned web worker from our cache - /// if one is available, otherwise it will spawn a new worker and return the - /// newly spawned worker. - /// - /// # Errors - /// - /// Returns any error that may happen while a JS web worker is created and a - /// message is sent to it. - fn worker(&self) -> Result { - match self.state.workers.borrow_mut().pop() { - Some(worker) => Ok(worker), - None => self.spawn(), - } - } - - /// Executes the work `f` in a web worker, spawning a web worker if - /// necessary. - /// - /// This will acquire a web worker and then send the closure `f` to the - /// worker to execute. The worker won't be usable for anything else while - /// `f` is executing, and no callbacks are registered for when the worker - /// finishes. - /// - /// # Errors - /// - /// Returns any error that may happen while a JS web worker is created and a - /// message is sent to it. - fn execute(&self, f: impl FnOnce() + Send + 'static) -> Result { - let worker = self.worker()?; - - // Double-boxing because `dyn FnOnce` is unsized and so `Box` has - // an undefined layout (although I think in practice its a pointer and a length?). - let ptr = Box::into_raw(Box::new(Box::new(f) as Box)); - - // See `worker.ts` for the format of this message. - let msg: js_sys::Array = [ - &wasm_bindgen::module(), - &wasm_bindgen::memory(), - &JsValue::from(ptr as u32), - ] - .into_iter() - .collect(); - if let Err(e) = worker.post_message(&msg) { - // We expect the worker to deallocate the box, but if there was an error then - // we'll do it ourselves. - let _ = unsafe { Box::from_raw(ptr) }; - Err(format!("Error initializing worker during post_message: {:?}", e).into()) - } else { - Ok(worker) - } - } - - /// Configures an `onmessage` callback for the `worker` specified for the - /// web worker to be reclaimed and re-inserted into this pool when a message - /// is received. - /// - /// Currently this `WorkerPool` abstraction is intended to execute one-off - /// style work where the work itself doesn't send any notifications and - /// whatn it's done the worker is ready to execute more work. This method is - /// used for all spawned workers to ensure that when the work is finished - /// the worker is reclaimed back into this pool. - fn reclaim_on_message(&self, worker: Worker) { - let state = Arc::downgrade(&self.state); - let worker2 = worker.clone(); - let mut reclaim_slot = Arc::new(RefCell::new(None)); - let mut slot2 = reclaim_slot.clone(); - let reclaim = Closure::::new(move |event: Event| { - if let Some(error) = event.dyn_ref::() { - info!("error in worker: {}", error.message()); - // TODO: this probably leaks memory somehow? It's sort of - // unclear what to do about errors in workers right now. - return; - } - - // If this is a completion event then can deallocate our own - // callback by clearing out `slot2` which contains our own closure. - if let Some(_msg) = event.dyn_ref::() { - if let Some(state) = state.upgrade() { - state.push(worker2.clone()); - } - *slot2.borrow_mut() = Arc::new(RefCell::new(None)); - return; - } - - info!("unhandled event: {:?}", &event); - // TODO: like above, maybe a memory leak here? - }); - worker.set_onmessage(Some(reclaim.as_ref().unchecked_ref())); - *reclaim_slot.borrow_mut() = Arc::new(RefCell::new(Some(reclaim))); - } -} - -impl WorkerPool { - /// Executes `f` in a web worker. - /// - /// This pool manages a set of web workers to draw from, and `f` will be - /// spawned quickly into one if the worker is idle. If no idle workers are - /// available then a new web worker will be spawned. - /// - /// Once `f` returns the worker assigned to `f` is automatically reclaimed - /// by this `WorkerPool`. This method provides no method of learning when - /// `f` completes, and for that you'll need to use `run_notify`. - /// - /// # Errors - /// - /// If an error happens while spawning a web worker or sending a message to - /// a web worker, that error is returned. - pub fn run(&self, f: impl FnOnce() + Send + 'static) -> Result<(), JsValue> { - let worker = self.execute(f)?; - self.reclaim_on_message(worker); - Ok(()) - } -} - -impl PoolState { - fn push(&self, worker: Worker) { - worker.set_onmessage(Some(self.callback.as_ref().unchecked_ref())); - worker.set_onerror(Some(self.callback.as_ref().unchecked_ref())); - let mut workers = self.workers.borrow_mut(); - for prev in workers.iter() { - let prev: &JsValue = prev; - let worker: &JsValue = &worker; - assert!(prev != worker); - } - workers.push(worker); - } -} - -/// Entry point invoked by `worker.js` -#[wasm_bindgen] -pub fn child_entry_point(ptr: u32) { - let work = unsafe { Box::from_raw(ptr as *mut Box) }; - (*work)(); -} diff --git a/burn-train/src/util.rs b/burn-train/src/util.rs index bb86945820..5a3ab1838c 100644 --- a/burn-train/src/util.rs +++ b/burn-train/src/util.rs @@ -1,10 +1,3 @@ -#![allow(missing_docs)] - -#[cfg(feature = "browser")] -use crate::pool::WORKER_POOL; -#[cfg(feature = "browser")] -use wasm_bindgen::prelude::*; - #[cfg(not(feature = "browser"))] pub fn spawn(f: F) -> Box Result<(), ()>> where @@ -23,23 +16,9 @@ where F: FnOnce(), F: Send + 'static, { - WORKER_POOL.run(f).unwrap(); + rayon::spawn(f); Box::new(|| Ok(())) } #[cfg(feature = "browser")] -pub static WORKER_URL: std::sync::OnceLock = std::sync::OnceLock::new(); - -#[cfg(feature = "browser")] -pub fn init(worker_url: String, worker_count: usize) -> Result<(), JsValue> { - WORKER_URL - .set(worker_url) - .map_err(|worker_url| -> JsValue { - format!( - "You can only call `init` once. You tried to `init` with: {:?}", - worker_url - ) - .into() - })?; - crate::pool::init(worker_count) -} +pub use wasm_bindgen_rayon::init_thread_pool; diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index 0c043da09e..ceda2676fe 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -17,11 +17,6 @@ pub fn start() { console_log::init().expect("Error initializing logger"); } -#[wasm_bindgen] -pub fn init(worker_url: String, worker_count: usize) -> Result<(), JsValue> { - util::init(worker_url, worker_count) -} - #[wasm_bindgen] pub fn run( train_labels: &[u8], diff --git a/examples/train-web/web/src/train.ts b/examples/train-web/web/src/train.ts index ba2f8b4bda..131ae174b3 100644 --- a/examples/train-web/web/src/train.ts +++ b/examples/train-web/web/src/train.ts @@ -1,5 +1,4 @@ -import wasm, { init, run } from 'train' -import workerUrl from './worker.ts?url' +import wasm, { run, initThreadPool } from 'train' import initSqlJs, { type Database } from 'sql.js' import sqliteWasmUrl from './assets/sql-wasm.wasm?url' @@ -31,7 +30,7 @@ export function setupTrain(element: HTMLInputElement) { async function loadSqliteAndRun(ab: ArrayBuffer) { await wasm() - await init(workerUrl, navigator.hardwareConcurrency) + await initThreadPool(navigator.hardwareConcurrency) await sleep(1000) // the workers need time to spin up. TODO, post an init message and await a response. Also maybe move worker construction to Javascript. // Images are an array of arrays. // We can't send an array of arrays to Wasm. diff --git a/examples/train-web/web/src/worker.ts b/examples/train-web/web/src/worker.ts deleted file mode 100644 index 15465f22e9..0000000000 --- a/examples/train-web/web/src/worker.ts +++ /dev/null @@ -1,40 +0,0 @@ -// Mostly copied from https://github.com/tweag/rust-wasm-threads/blob/main/shared-memory/worker.js - -import trainWasmUrl from 'train/train_bg.wasm?url' -import wasm, { child_entry_point } from 'train' - -self.onmessage = async (event) => { - // We expect a message with three elements [module, memory, ptr], where: - // - module is a WebAssembly.Module - // - memory is the WebAssembly.Memory object that the main thread is using - // (and we want to use it too). - // - ptr is the pointer (within `memory`) of the function that we want to execute. - // - // The auto-generated `wasm_bindgen` function takes two *optional* arguments. - // The first is the module (you can pass in a url, like we did in "index.js", - // or a module object); the second is the memory block to use, and if you - // don't provide one (like we didn't in "index.js") then a new one will be - // allocated for you. - let init = await wasm(trainWasmUrl, event.data[1]).catch((err) => { - // Propagate to main `onerror`: - setTimeout(() => { - throw err - }) - // Rethrow to keep promise rejected and prevent execution of further commands: - throw err - }) - - child_entry_point(event.data[2]) - - // Clean up thread resources. Depending on what you're doing with the thread, this might - // not be what you want. (For example, if the thread spawned some javascript tasks - // and exited, this is going to cancel those tasks.) But if you're using threads in the - // usual native way (where you spin one up to do some work until it finisheds) then - // you'll want to clean up the thread's resources. - - // // @ts-ignore -- can take 0 args, the typescript is incorrect https://github.com/rustwasm/wasm-bindgen/blob/962f580d6f5def2df4223588de96feeb30ce7572/crates/threads-xform/src/lib.rs#L394-L395 - // // Free memory (stack, thread-locals) held (in the wasm linear memory) by the thread. - // init.__wbindgen_thread_destroy() - // // Tell the browser to stop the thread. - // close() -} From 803587092ebf4447390aad243366cb406362e136 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 1 Dec 2023 12:15:25 -0600 Subject: [PATCH 38/79] fix conflicts --- burn/Cargo.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/burn/Cargo.toml b/burn/Cargo.toml index 13a69ba63e..578afa123c 100644 --- a/burn/Cargo.toml +++ b/burn/Cargo.toml @@ -27,6 +27,8 @@ metrics = ["burn-train?/metrics"] # Useful when targeting WASM and not using WGPU. wasm-sync = ["burn-core/wasm-sync"] +browser = ["burn-train/browser"] + # Datasets dataset = ["burn-core/dataset"] From 7b2787167ca6509432497053da5ec3b24077ccba Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 1 Dec 2023 12:16:04 -0600 Subject: [PATCH 39/79] nix license notice --- NOTICES.md | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/NOTICES.md b/NOTICES.md index b486e8e637..eae90cd7b3 100644 --- a/NOTICES.md +++ b/NOTICES.md @@ -244,29 +244,3 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - -## Rust threads in the browser - -**Source**: https://github.com/tweag/rust-wasm-threads/tree/main/shared-memory - -MIT License - -Copyright (c) 2023 Matthew Toohey and Modus Create LLC - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. From da36016512cab7b25238b8d7f7245563a81a0913 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 1 Dec 2023 13:42:27 -0600 Subject: [PATCH 40/79] fix conflicts --- burn-train/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/burn-train/Cargo.toml b/burn-train/Cargo.toml index 9e30a3034e..02171fd02f 100644 --- a/burn-train/Cargo.toml +++ b/burn-train/Cargo.toml @@ -14,6 +14,7 @@ version = "0.11.0" default = ["metrics", "tui"] metrics = ["nvml-wrapper", "sysinfo", "systemstat"] tui = ["ratatui", "crossterm"] +browser = ["js-sys", "web-sys", "wasm-bindgen"] [dependencies] burn-core = { path = "../burn-core", version = "0.11.0" } From 9aa04399c7da9b00dfc408bf102dab62eaa61f05 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 1 Dec 2023 14:31:41 -0600 Subject: [PATCH 41/79] fix --- burn-train/Cargo.toml | 4 +++- burn-train/src/lib.rs | 3 --- examples/train-web/train/Cargo.toml | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/burn-train/Cargo.toml b/burn-train/Cargo.toml index 02171fd02f..4f92c4713c 100644 --- a/burn-train/Cargo.toml +++ b/burn-train/Cargo.toml @@ -17,7 +17,9 @@ tui = ["ratatui", "crossterm"] browser = ["js-sys", "web-sys", "wasm-bindgen"] [dependencies] -burn-core = { path = "../burn-core", version = "0.11.0" } +burn-core = { path = "../burn-core", version = "0.11.0", default-features = false, features = [ + "std", +] } log = { workspace = true } tracing-subscriber.workspace = true diff --git a/burn-train/src/lib.rs b/burn-train/src/lib.rs index fb4dc9b3a0..f4bc430ee3 100644 --- a/burn-train/src/lib.rs +++ b/burn-train/src/lib.rs @@ -27,6 +27,3 @@ pub use learner::*; #[cfg(test)] pub(crate) type TestBackend = burn_ndarray::NdArray; - -#[macro_use] -extern crate lazy_static; diff --git a/examples/train-web/train/Cargo.toml b/examples/train-web/train/Cargo.toml index c7b6a65b31..cfbdb78a46 100644 --- a/examples/train-web/train/Cargo.toml +++ b/examples/train-web/train/Cargo.toml @@ -13,8 +13,8 @@ console_error_panic_hook = "0.1.7" console_log = { version = "1", features = ["color"] } burn = { path = "../../../burn", default-features = false, features = [ "autodiff", - "ndarray-no-std", - "train-minimal", + "train", + "ndarray", "wasm-sync", "browser", ] } From 051b41fb48a1d83f7c9ea0df4b86ad9faa72f394 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 1 Dec 2023 15:52:22 -0600 Subject: [PATCH 42/79] fix? --- burn-train/Cargo.toml | 6 +++--- examples/train-web/train/rust-toolchain.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/burn-train/Cargo.toml b/burn-train/Cargo.toml index 4f92c4713c..34ba6ab5a5 100644 --- a/burn-train/Cargo.toml +++ b/burn-train/Cargo.toml @@ -14,7 +14,7 @@ version = "0.11.0" default = ["metrics", "tui"] metrics = ["nvml-wrapper", "sysinfo", "systemstat"] tui = ["ratatui", "crossterm"] -browser = ["js-sys", "web-sys", "wasm-bindgen"] +browser = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-rayon", "rayon"] [dependencies] burn-core = { path = "../burn-core", version = "0.11.0", default-features = false, features = [ @@ -48,8 +48,8 @@ web-sys = { version = "0.3.64", optional = true, features = [ "ErrorEvent", ] } wasm-bindgen = { workspace = true, optional = true } -wasm-bindgen-rayon = "1.0.3" -rayon = { workspace = true } +wasm-bindgen-rayon = { version = "1.0.3", optional = true } +rayon = { workspace = true, optional = true } [dev-dependencies] burn-ndarray = { path = "../burn-ndarray", version = "0.11.0" } diff --git a/examples/train-web/train/rust-toolchain.toml b/examples/train-web/train/rust-toolchain.toml index f29c21eaa5..68496f835b 100644 --- a/examples/train-web/train/rust-toolchain.toml +++ b/examples/train-web/train/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2023-11-05" +channel = "nightly-2023-07-01" From 7a8372cb0a486ec03a57df7539c047b36833b27a Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 1 Dec 2023 16:13:01 -0600 Subject: [PATCH 43/79] fix?? --- xtask/src/runchecks.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 98badcf8de..bf9aff12cd 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -275,10 +275,19 @@ fn std_checks() { // Build each workspace if disable_wgpu { - cargo_build(["--workspace", "--exclude=xtask", "--exclude=burn-wgpu"].into()); + cargo_build( + [ + "--workspace", + "--exclude=burn-train", + "--exclude=xtask", + "--exclude=burn-wgpu", + ] + .into(), + ); } else { - cargo_build(["--workspace", "--exclude=xtask"].into()); + cargo_build(["--workspace", "--exclude=burn-train", "--exclude=xtask"].into()); } + cargo_build(["-p", "burn-train", "--lib"].into()); // Produce documentation for each workspace cargo_doc(["--workspace"].into()); From 2c9dc069d93e6e46d6cf46b6bc47cae3641a4c0c Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 2 Dec 2023 12:51:59 -0600 Subject: [PATCH 44/79] runchecks runs each package separately --- Cargo.toml | 1 - burn-train/src/learner/early_stopping.rs | 2 +- burn-train/src/util.rs | 4 +- xtask/src/runchecks.rs | 174 +++++++++++++++++------ 4 files changed, 132 insertions(+), 49 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 88711e80ad..3b3854c3ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,4 @@ [workspace] -# Try # require version 2 to avoid "feature" additiveness for dev-dependencies # https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2 resolver = "2" diff --git a/burn-train/src/learner/early_stopping.rs b/burn-train/src/learner/early_stopping.rs index 334756d885..d69a2eff0d 100644 --- a/burn-train/src/learner/early_stopping.rs +++ b/burn-train/src/learner/early_stopping.rs @@ -104,7 +104,7 @@ impl MetricEarlyStoppingStrategy { #[cfg(test)] mod tests { - use std::{rc::Rc, sync::Arc}; + use std::rc::Rc; use crate::{ logger::InMemoryMetricLogger, diff --git a/burn-train/src/util.rs b/burn-train/src/util.rs index 5a3ab1838c..4ecfeb82db 100644 --- a/burn-train/src/util.rs +++ b/burn-train/src/util.rs @@ -1,3 +1,5 @@ +#![allow(missing_docs)] + #[cfg(not(feature = "browser"))] pub fn spawn(f: F) -> Box Result<(), ()>> where @@ -8,8 +10,6 @@ where Box::new(move || handle.join().map_err(|_| ())) } -// High level description at https://www.tweag.io/blog/2022-11-24-wasm-threads-and-messages/ -// Mostly copied from https://github.com/tweag/rust-wasm-threads/blob/main/shared-memory/src/lib.rs #[cfg(feature = "browser")] pub fn spawn(f: F) -> Box Result<(), ()>> where diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index bf9aff12cd..7784e64075 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -13,6 +13,41 @@ use std::time::Instant; const WASM32_TARGET: &str = "wasm32-unknown-unknown"; const ARM_TARGET: &str = "thumbv7m-none-eabi"; +const PACKAGES: &[&str] = &[ + "backend-comparison", + "burn-autodiff", + "burn-candle", + "burn-common", + "burn-compute", + "burn-core", + "burn-dataset", + "burn-derive", + "burn-fusion", + "burn-import", + "burn-import/onnx-tests", + "burn-ndarray", + "burn-no-std-tests", + "burn-tch", + "burn-tensor-testgen", + "burn-tensor", + "burn-train", + "burn-wgpu", + "burn", + "examples/custom-renderer", + "examples/custom-training-loop", + "examples/custom-wgpu-kernel", + "examples/guide", + "examples/image-classification-web", + "examples/mnist-inference-web", + "examples/mnist", + "examples/named-tensor", + "examples/onnx-inference", + "examples/text-classification", + "examples/text-generation", + "examples/train-web/train", + "xtask", +]; + // Handle child process fn handle_child_process(mut child: Child, error: &str) { // Wait for the child process to finish @@ -54,66 +89,103 @@ fn rustup(command: &str, target: &str) { } // Define and run a cargo command -fn run_cargo(command: &str, params: Params, error: &str) { - // Print cargo command - println!("\ncargo {} {}\n", command, params); - - // Run cargo - let cargo = Command::new("cargo") - .env("CARGO_INCREMENTAL", "0") - .arg(command) - .args(params.params) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect(error); +fn run_cargo(command: &str, params: Params, error: &str, exclude: Vec<&str>) { + // Separate out packages to prevent feature unification https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2:~:text=When%20building%20multiple,separate%20cargo%20invocations. + let packages = if params.params.contains(&String::from("-p")) { + vec!["."] + } else { + PACKAGES + .iter() + .filter(|p| !exclude.contains(p)) + .cloned() + .collect::>() + }; + + for p in packages { + // Print cargo command + println!("\n{} $ cargo {} {}\n", p, command, params); + // Run cargo + let cargo = Command::new("cargo") + .env("CARGO_INCREMENTAL", "0") + .current_dir(p) + .arg(command) + .args(¶ms.params) + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect(error); - // Handle cargo child process - handle_child_process(cargo, "Failed to wait for cargo child process"); + // Handle cargo child process + handle_child_process(cargo, "Failed to wait for cargo child process"); + } } // Run cargo build command -fn cargo_build(params: Params) { +fn cargo_build_excluding(params: Params, exclude: Vec<&str>) { // Run cargo build run_cargo( "build", params + "--color=always", "Failed to run cargo build", + exclude, ); } +// Run cargo build command +fn cargo_build(params: Params) { + cargo_build_excluding(params, [].into()) +} + // Run cargo install command -fn cargo_install(params: Params) { +fn cargo_install_excluding(params: Params, exclude: Vec<&str>) { // Run cargo install run_cargo( "install", params + "--color=always", "Failed to run cargo install", + exclude, ); } +// Run cargo install command +fn cargo_install(params: Params) { + cargo_install_excluding(params, [].into()) +} + // Run cargo test command -fn cargo_test(params: Params) { +fn cargo_test_excluding(params: Params, exclude: Vec<&str>) { // Run cargo test run_cargo( "test", params + "--color=always" + "--" + "--color=always", "Failed to run cargo test", + exclude, ); } +// Run cargo test command +fn cargo_test(params: Params) { + cargo_test_excluding(params, [].into()) +} + // Run cargo fmt command -fn cargo_fmt() { +fn cargo_fmt_excluding(exclude: Vec<&str>) { // Run cargo fmt run_cargo( "fmt", ["--check", "--all", "--", "--color=always"].into(), "Failed to run cargo fmt", + exclude, ); } +// Run cargo fmt command +fn cargo_fmt() { + cargo_fmt_excluding([].into()) +} + // Run cargo clippy command -fn cargo_clippy() { +fn cargo_clippy_excluding(exclude: Vec<&str>) { if std::env::var("CI_RUN").is_ok() { return; } @@ -122,13 +194,23 @@ fn cargo_clippy() { "clippy", ["--color=always", "--all-targets", "--", "-D", "warnings"].into(), "Failed to run cargo clippy", + exclude, ); } // Run cargo doc command -fn cargo_doc(params: Params) { +fn cargo_doc_excluding(params: Params, exclude: Vec<&str>) { // Run cargo doc - run_cargo("doc", params + "--color=always", "Failed to run cargo doc"); + run_cargo( + "doc", + params + "--no-deps" + "--color=always", + "Failed to run cargo doc", + exclude, + ); +} + +fn cargo_doc(params: Params) { + cargo_doc_excluding(params, [].into()) } // Build and test a crate in a no_std environment @@ -271,36 +353,38 @@ fn std_checks() { cargo_fmt(); // Check clippy lints - cargo_clippy(); - - // Build each workspace - if disable_wgpu { - cargo_build( - [ - "--workspace", - "--exclude=burn-train", - "--exclude=xtask", - "--exclude=burn-wgpu", - ] - .into(), - ); - } else { - cargo_build(["--workspace", "--exclude=burn-train", "--exclude=xtask"].into()); - } - cargo_build(["-p", "burn-train", "--lib"].into()); + cargo_clippy_excluding( + [ + "burn-no-std-tests", + "examples/image-classification-web", + "examples/mnist-inference-web", + "examples/train-web/train", + ] + .into(), + ); - // Produce documentation for each workspace - cargo_doc(["--workspace"].into()); + // Build each package + let mut exclude = if disable_wgpu { + vec!["burn-wgpu"] + } else { + vec![] + }; + exclude.push("examples/train-web/train"); + cargo_build_excluding([].into(), exclude); + + // Produce documentation for each package + cargo_doc_excluding( + [].into(), + ["burn-no-std-tests", "examples/train-web/train"].into(), + ); // Setup code coverage if is_coverage { setup_coverage(); } - // Test each workspace - cargo_test(["--workspace", "--exclude", "burn-train"].into()); - // Separate out `burn-train` to prevent feature unification https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2:~:text=When%20building%20multiple,separate%20cargo%20invocations. - cargo_test(["-p", "burn-train", "--lib"].into()); + // Test each package + cargo_test_excluding([].into(), ["examples/train-web/train"].into()); // Test burn-candle with accelerate (macOS only) #[cfg(target_os = "macos")] From 1ee5cef1594b67dc10c1a6a215d0d826b6122e4c Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 2 Dec 2023 13:57:31 -0600 Subject: [PATCH 45/79] skip train-web --- xtask/src/runchecks.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 7784e64075..f8370a5667 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -439,7 +439,8 @@ fn check_examples() { if !path.is_dir() { return; } - if path.file_name().unwrap().to_str().unwrap() == "notebook" { + let dirname = path.file_name().unwrap().to_str().unwrap(); + if dirname == "notebook" || dirname == "train-web" { // not a crate return; } From 3f52a7816996c3074ee898c505b9a6d1b7ca941e Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 2 Dec 2023 18:43:56 -0600 Subject: [PATCH 46/79] fix deps --- burn-train/Cargo.toml | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/burn-train/Cargo.toml b/burn-train/Cargo.toml index 675e3773d3..f3d4a9469d 100644 --- a/burn-train/Cargo.toml +++ b/burn-train/Cargo.toml @@ -11,13 +11,20 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-train" version = "0.11.0" [features] -default = ["metrics", "tui"] +default = ["metrics", "tui", "burn-core/default", "burn-core/dataset"] metrics = ["nvml-wrapper", "sysinfo", "systemstat"] tui = ["ratatui", "crossterm"] -browser = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-rayon", "rayon"] +browser = [ + "js-sys", + "web-sys", + "wasm-bindgen", + "wasm-bindgen-rayon", + "rayon", + "burn-core/std", +] [dependencies] -burn-core = { path = "../burn-core", version = "0.11.0", features = ["dataset"] } +burn-core = { path = "../burn-core", version = "0.11.0", default-features = false } log = { workspace = true } tracing-subscriber.workspace = true From f188b7a8ffed45a96cf4c2fad18d5ab52c693d61 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 16 Nov 2023 09:22:39 -0600 Subject: [PATCH 47/79] converted train to a worker (as to not block main thread) --- examples/train-web/web/src/counter.ts | 9 ------ examples/train-web/web/src/eventHandlers.ts | 32 +++++++++++++++++++++ examples/train-web/web/src/main.ts | 11 +++++-- examples/train-web/web/src/train.ts | 23 +++++---------- 4 files changed, 47 insertions(+), 28 deletions(-) delete mode 100644 examples/train-web/web/src/counter.ts create mode 100644 examples/train-web/web/src/eventHandlers.ts diff --git a/examples/train-web/web/src/counter.ts b/examples/train-web/web/src/counter.ts deleted file mode 100644 index 164395d512..0000000000 --- a/examples/train-web/web/src/counter.ts +++ /dev/null @@ -1,9 +0,0 @@ -export function setupCounter(element: HTMLButtonElement) { - let counter = 0 - const setCounter = (count: number) => { - counter = count - element.innerHTML = `count is ${counter}` - } - element.addEventListener('click', () => setCounter(counter + 1)) - setCounter(0) -} diff --git a/examples/train-web/web/src/eventHandlers.ts b/examples/train-web/web/src/eventHandlers.ts new file mode 100644 index 0000000000..4eb0e94cfc --- /dev/null +++ b/examples/train-web/web/src/eventHandlers.ts @@ -0,0 +1,32 @@ +import Train from './train.ts?worker' + +export function setupCounter(element: HTMLButtonElement) { + let counter = 0 + const setCounter = (count: number) => { + counter = count + element.innerHTML = `count is ${counter}` + } + element.addEventListener('click', () => setCounter(counter + 1)) + setCounter(0) +} + +export function setupTrain( + element: HTMLButtonElement, + fileElement: HTMLInputElement, +) { + element.addEventListener('click', async () => { + element.innerHTML = `Training...` + element.disabled = true + const train = new Train() + const files = fileElement.files + if (files != null) { + const file = files[0] + if (file != null) { + let ab = await file.arrayBuffer() + train.postMessage(ab, [ab]) + return + } + } + train.postMessage('autotrain') + }) +} diff --git a/examples/train-web/web/src/main.ts b/examples/train-web/web/src/main.ts index af823defff..2aa251579f 100644 --- a/examples/train-web/web/src/main.ts +++ b/examples/train-web/web/src/main.ts @@ -1,8 +1,7 @@ import './style.css' import typescriptLogo from './typescript.svg' import viteLogo from '/vite.svg' -import { setupCounter } from './counter.ts' -import { setupTrain } from './train.ts' +import { setupCounter, setupTrain } from './eventHandlers.ts' document.querySelector('#app')!.innerHTML = `
@@ -16,6 +15,9 @@ document.querySelector('#app')!.innerHTML = `
+
+ +
@@ -26,4 +28,7 @@ document.querySelector('#app')!.innerHTML = ` ` setupCounter(document.querySelector('#counter')!) -setupTrain(document.querySelector('#dbfile')!) +setupTrain( + document.querySelector('#train')!, + document.querySelector('#dbfile')!, +) diff --git a/examples/train-web/web/src/train.ts b/examples/train-web/web/src/train.ts index 131ae174b3..b23c605390 100644 --- a/examples/train-web/web/src/train.ts +++ b/examples/train-web/web/src/train.ts @@ -9,22 +9,13 @@ const sqlJs = initSqlJs({ locateFile: () => sqliteWasmUrl, }) -// just load and run mnist on pageload -fetch('./mnist.db') - .then((r) => r.arrayBuffer()) - .then(loadSqliteAndRun) - .catch(console.error) - -export function setupTrain(element: HTMLInputElement) { - element.onchange = async function () { - const files = element.files - if (files != null) { - const file = files[0] - if (file != null) { - let ab = await file.arrayBuffer() - await loadSqliteAndRun(ab) - } - } +self.onmessage = async (event) => { + if (event.data === 'autotrain') { + let db = await fetch('/mnist.db') + let ab = await db.arrayBuffer() + loadSqliteAndRun(ab) + } else if (event.data instanceof ArrayBuffer) { + loadSqliteAndRun(event.data) } } From 37a748e6eb68903a5d507ce55189aa725c0a4963 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 16 Nov 2023 09:52:17 -0600 Subject: [PATCH 48/79] add autotrain checkbox --- examples/train-web/web/src/eventHandlers.ts | 18 ++++++++++++++++++ examples/train-web/web/src/main.ts | 8 +++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/examples/train-web/web/src/eventHandlers.ts b/examples/train-web/web/src/eventHandlers.ts index 4eb0e94cfc..fbeee690a5 100644 --- a/examples/train-web/web/src/eventHandlers.ts +++ b/examples/train-web/web/src/eventHandlers.ts @@ -30,3 +30,21 @@ export function setupTrain( train.postMessage('autotrain') }) } + +export function setupAutotrain( + element: HTMLInputElement, + trainElement: HTMLButtonElement, +) { + const autotrain = localStorage.getItem('autotrain') + element.checked = autotrain != null + if (element.checked) { + trainElement.click() + } + element.addEventListener('click', () => { + if (element.checked) { + localStorage.setItem('autotrain', 'true') + } else { + localStorage.removeItem('autotrain') + } + }) +} diff --git a/examples/train-web/web/src/main.ts b/examples/train-web/web/src/main.ts index 2aa251579f..f69f158fea 100644 --- a/examples/train-web/web/src/main.ts +++ b/examples/train-web/web/src/main.ts @@ -1,7 +1,7 @@ import './style.css' import typescriptLogo from './typescript.svg' import viteLogo from '/vite.svg' -import { setupCounter, setupTrain } from './eventHandlers.ts' +import { setupAutotrain, setupCounter, setupTrain } from './eventHandlers.ts' document.querySelector('#app')!.innerHTML = `
@@ -21,6 +21,8 @@ document.querySelector('#app')!.innerHTML = `
+ +

Click on the Vite and TypeScript logos to learn more

@@ -32,3 +34,7 @@ setupTrain( document.querySelector('#train')!, document.querySelector('#dbfile')!, ) +setupAutotrain( + document.querySelector('#autotrain')!, + document.querySelector('#train')!, +) From fa63f6d49005653d531677a106fadf40cc856bfa Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 16 Nov 2023 11:16:39 -0600 Subject: [PATCH 49/79] MetricsRenderer logs --- examples/train-web/train/src/train.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/train-web/train/src/train.rs b/examples/train-web/train/src/train.rs index b7328274c6..00d694cb06 100644 --- a/examples/train-web/train/src/train.rs +++ b/examples/train-web/train/src/train.rs @@ -21,6 +21,7 @@ use burn::{ ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep, }, }; +use log::info; impl Model { pub fn forward_classification( @@ -68,16 +69,20 @@ pub struct TrainingConfig { struct CustomRenderer {} impl MetricsRenderer for CustomRenderer { - fn update_train(&mut self, _state: MetricState) {} + fn update_train(&mut self, state: MetricState) { + info!("Updated Training: {:?}", state); + } - fn update_valid(&mut self, _state: MetricState) {} + fn update_valid(&mut self, state: MetricState) { + info!("Updated Validation: {:?}", state); + } fn render_train(&mut self, item: TrainingProgress) { - dbg!(item); + info!("Training Progress: {:?}", item); } fn render_valid(&mut self, item: TrainingProgress) { - dbg!(item); + info!("Validation Progress: {:?}", item); } } From f8ab1d41286df2634c5f1a07a259e3068399985b Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 2 Dec 2023 20:31:27 -0600 Subject: [PATCH 50/79] add to_bytes --- burn-core/src/module/base.rs | 9 +++++++++ examples/train-web/train/src/lib.rs | 5 ++--- examples/train-web/train/src/train.rs | 8 ++++---- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/burn-core/src/module/base.rs b/burn-core/src/module/base.rs index ba98cee0cf..3bae2c2c7d 100644 --- a/burn-core/src/module/base.rs +++ b/burn-core/src/module/base.rs @@ -173,6 +173,15 @@ pub trait Module: Clone + Send + Sync + core::fmt::Debug { recorder.record(record, file_path.into()) } + /// Return the module using [BytesRecorder](crate::record::BytesRecorder). + fn to_bytes( + self, + recorder: &BR, + ) -> Result, crate::record::RecorderError> { + let record = Self::into_record(self); + recorder.record(record, ()) + } + #[cfg(feature = "std")] /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder). /// diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index ceda2676fe..cfdf23eb05 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -25,7 +25,7 @@ pub fn run( test_labels: &[u8], test_images: &[u8], test_lengths: &[u16], -) -> Result<(), String> { +) -> Vec { log::info!("Hello from Rust"); let config = TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()); train::train::>>( @@ -38,6 +38,5 @@ pub fn run( test_labels, test_images, test_lengths, - ); - Ok(()) + ) } diff --git a/examples/train-web/train/src/train.rs b/examples/train-web/train/src/train.rs index 00d694cb06..f1f84efebf 100644 --- a/examples/train-web/train/src/train.rs +++ b/examples/train-web/train/src/train.rs @@ -10,7 +10,7 @@ use burn::{ module::Module, nn::loss::CrossEntropyLoss, optim::AdamConfig, - record::CompactRecorder, + record::{BinBytesRecorder, FullPrecisionSettings}, tensor::{ backend::{AutodiffBackend, Backend}, Int, Tensor, @@ -96,7 +96,7 @@ pub fn train( test_labels: &[u8], test_images: &[u8], test_lengths: &[u16], -) { +) -> Vec { // std::fs::create_dir_all(artifact_dir).ok(); // config // .save(format!("{artifact_dir}/config.json")) @@ -140,6 +140,6 @@ pub fn train( let model_trained = learner.fit(dataloader_train, dataloader_test); model_trained - .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) - .expect("Failed to save trained model"); + .to_bytes(&BinBytesRecorder::::default()) + .expect("Failed to serialize model") } From 9e9dfb68c8519b55a82fbdbc76a949f3dbd13b26 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 16 Nov 2023 12:56:14 -0600 Subject: [PATCH 51/79] comment out off by one check --- examples/train-web/train/src/mnist.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/train-web/train/src/mnist.rs b/examples/train-web/train/src/mnist.rs index a86c99f259..d4729eadf3 100644 --- a/examples/train-web/train/src/mnist.rs +++ b/examples/train-web/train/src/mnist.rs @@ -89,14 +89,14 @@ impl MNISTDataset { let image_bytes = images[start..end].to_vec(); // Assert that the incoming data is a valid PNG. // Really just here to ensure the decoding process isn't off by one. - debug_assert_eq!( - image_bytes[0..8], - vec![137, 80, 78, 71, 13, 10, 26, 10] // Starts with bytes from the spec http://www.libpng.org/pub/png/spec/1.2/PNG-Structure.html - ); - debug_assert_eq!( - image_bytes[image_bytes.len() - 8..], - vec![73, 69, 78, 68, 174, 66, 96, 130] // Ends with the IEND chunk. I don't think the spec specifies the last 4 bytes, but `mnist.db`'s images all end with it. - ); + // debug_assert_eq!( + // image_bytes[0..8], + // vec![137, 80, 78, 71, 13, 10, 26, 10] // Starts with bytes from the spec http://www.libpng.org/pub/png/spec/1.2/PNG-Structure.html + // ); + // debug_assert_eq!( + // image_bytes[image_bytes.len() - 8..], + // vec![73, 69, 78, 68, 174, 66, 96, 130] // Ends with the IEND chunk. I don't think the spec specifies the last 4 bytes, but `mnist.db`'s images all end with it. + // ); let raw = MNISTItemRaw { label: *label, image_bytes, From 0a9a4673650ac80b26999711a7103331bf661cb8 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 3 Dec 2023 16:56:04 -0600 Subject: [PATCH 52/79] don't hardcode dirs --- .github/workflows/test.yml | 6 +++ xtask/Cargo.toml | 1 + xtask/src/runchecks.rs | 84 +++++++++++++------------------------- 3 files changed, 35 insertions(+), 56 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f79770a622..9a2e58d9b0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -63,6 +63,12 @@ jobs: - name: checkout uses: actions/checkout@v4 + - name: install nightly rust + uses: dtolnay/rust-toolchain@master + with: + components: rustfmt, clippy + toolchain: nightly-2023-07-01 + - name: install rust uses: dtolnay/rust-toolchain@master with: diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index d92e7bb24b..04d1743d54 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -9,3 +9,4 @@ license = "MIT OR Apache-2.0" [dependencies] anyhow = "1.0.75" clap = { version = "4.4.8", features = ["derive"] } +glob = "0.3.1" diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index f8370a5667..e48fdb5e9b 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -4,50 +4,32 @@ //! //! It is also used to check that the code is formatted correctly and passes clippy. +use glob::glob; use std::env; use std::process::{Child, Command, Stdio}; use std::str; use std::time::Instant; +fn get_package_dirs(glob_str: &str) -> Vec { + glob(glob_str) + .expect("Failed to read glob pattern") + .filter_map(|dir| { + let dir = dir.unwrap(); + let dir = dir.as_path().to_str().unwrap(); + if dir == "Cargo.toml" { + None + } else { + let length = dir.len() - "/Cargo.toml".len(); + Some(dir.to_string()[..length].to_string()) + } + }) + .collect() +} + // Targets constants const WASM32_TARGET: &str = "wasm32-unknown-unknown"; const ARM_TARGET: &str = "thumbv7m-none-eabi"; -const PACKAGES: &[&str] = &[ - "backend-comparison", - "burn-autodiff", - "burn-candle", - "burn-common", - "burn-compute", - "burn-core", - "burn-dataset", - "burn-derive", - "burn-fusion", - "burn-import", - "burn-import/onnx-tests", - "burn-ndarray", - "burn-no-std-tests", - "burn-tch", - "burn-tensor-testgen", - "burn-tensor", - "burn-train", - "burn-wgpu", - "burn", - "examples/custom-renderer", - "examples/custom-training-loop", - "examples/custom-wgpu-kernel", - "examples/guide", - "examples/image-classification-web", - "examples/mnist-inference-web", - "examples/mnist", - "examples/named-tensor", - "examples/onnx-inference", - "examples/text-classification", - "examples/text-generation", - "examples/train-web/train", - "xtask", -]; - // Handle child process fn handle_child_process(mut child: Child, error: &str) { // Wait for the child process to finish @@ -92,12 +74,12 @@ fn rustup(command: &str, target: &str) { fn run_cargo(command: &str, params: Params, error: &str, exclude: Vec<&str>) { // Separate out packages to prevent feature unification https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2:~:text=When%20building%20multiple,separate%20cargo%20invocations. let packages = if params.params.contains(&String::from("-p")) { - vec!["."] + vec![String::from(".")] } else { - PACKAGES + get_package_dirs("./**/Cargo.toml") .iter() - .filter(|p| !exclude.contains(p)) - .cloned() + .filter(|p| !exclude.contains(&p.as_str())) + .map(String::from) .collect::>() }; @@ -432,25 +414,15 @@ fn check_typos() { fn check_examples() { println!("Checking examples compile \n\n"); - std::fs::read_dir("examples").unwrap().for_each(|dir| { - let dir = dir.unwrap(); - let path = dir.path(); - // Skip if not a directory - if !path.is_dir() { - return; - } - let dirname = path.file_name().unwrap().to_str().unwrap(); - if dirname == "notebook" || dirname == "train-web" { - // not a crate - return; - } - let path = path.to_str().unwrap(); - println!("Checking {path} \n\n"); + get_package_dirs("./examples/**/Cargo.toml") + .iter() + .for_each(|dir| { + println!("Checking {dir} \n\n"); - let child = Command::new("cargo") - .arg("check") + let child = Command::new("cargo") + .arg("check") .arg("--examples") - .current_dir(dir.path()) + .current_dir(dir) .stdout(Stdio::inherit()) // Send stdout directly to terminal .stderr(Stdio::inherit()) // Send stderr directly to terminal .spawn() From cfffb93f5e17a1548271b0d6f288148707841c94 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 3 Dec 2023 17:00:18 -0600 Subject: [PATCH 53/79] oops --- examples/train-web/train/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/train-web/train/src/lib.rs b/examples/train-web/train/src/lib.rs index cfdf23eb05..d209244c9e 100644 --- a/examples/train-web/train/src/lib.rs +++ b/examples/train-web/train/src/lib.rs @@ -2,7 +2,6 @@ use crate::{model::ModelConfig, train::TrainingConfig}; use burn::{ backend::{ndarray::NdArrayDevice, Autodiff, NdArray}, optim::AdamConfig, - train::util, }; use wasm_bindgen::prelude::*; From d4106bb25b2a8c39f89242c4f9afc7d01b1aaf6a Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 3 Dec 2023 17:05:25 -0600 Subject: [PATCH 54/79] fmt --- xtask/src/runchecks.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index e48fdb5e9b..f8b0210e21 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -421,16 +421,16 @@ fn check_examples() { let child = Command::new("cargo") .arg("check") - .arg("--examples") + .arg("--examples") .current_dir(dir) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect("Failed to check examples"); - - // Handle typos child process - handle_child_process(child, "Failed to wait for examples child process"); - }); + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()) // Send stderr directly to terminal + .spawn() + .expect("Failed to check examples"); + + // Handle typos child process + handle_child_process(child, "Failed to wait for examples child process"); + }); } #[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)] From c84d171161134ecbcb2801f33419c1fa4cd11fd5 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 3 Dec 2023 20:02:37 -0600 Subject: [PATCH 55/79] exclude `examples/train-web/train` less --- xtask/src/runchecks.rs | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index f8b0210e21..913a2cab71 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -73,7 +73,7 @@ fn rustup(command: &str, target: &str) { // Define and run a cargo command fn run_cargo(command: &str, params: Params, error: &str, exclude: Vec<&str>) { // Separate out packages to prevent feature unification https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2:~:text=When%20building%20multiple,separate%20cargo%20invocations. - let packages = if params.params.contains(&String::from("-p")) { + let package_dirs = if params.params.contains(&String::from("-p")) { vec![String::from(".")] } else { get_package_dirs("./**/Cargo.toml") @@ -83,13 +83,29 @@ fn run_cargo(command: &str, params: Params, error: &str, exclude: Vec<&str>) { .collect::>() }; - for p in packages { + for package_dir in package_dirs { + let nightly = if package_dir == "examples/train-web/train" { + vec!["+nightly-2023-07-01"] + } else { + vec![] + }; // Print cargo command - println!("\n{} $ cargo {} {}\n", p, command, params); + println!( + "\n{} $ cargo{} {} {}\n", + package_dir, + if nightly.is_empty() { + "".to_string() + } else { + " ".to_string() + nightly[0] + }, + command, + params + ); // Run cargo let cargo = Command::new("cargo") .env("CARGO_INCREMENTAL", "0") - .current_dir(p) + .current_dir(package_dir) + .args(nightly) .arg(command) .args(¶ms.params) .stdout(Stdio::inherit()) // Send stdout directly to terminal @@ -346,19 +362,15 @@ fn std_checks() { ); // Build each package - let mut exclude = if disable_wgpu { + let exclude = if disable_wgpu { vec!["burn-wgpu"] } else { vec![] }; - exclude.push("examples/train-web/train"); cargo_build_excluding([].into(), exclude); // Produce documentation for each package - cargo_doc_excluding( - [].into(), - ["burn-no-std-tests", "examples/train-web/train"].into(), - ); + cargo_doc_excluding([].into(), ["burn-no-std-tests"].into()); // Setup code coverage if is_coverage { From 354150342460c9c2a7ddd7ee119e056a1dff34b9 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 11:59:58 -0600 Subject: [PATCH 56/79] I regret everything (in this file) --- xtask/src/runchecks.rs | 262 +++++++++++++++++------------------------ 1 file changed, 105 insertions(+), 157 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 913a2cab71..89c2249fd7 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -4,28 +4,15 @@ //! //! It is also used to check that the code is formatted correctly and passes clippy. -use glob::glob; +use crate::logging::init_logger; +use crate::utils::{format_duration, get_workspaces, WorkspaceMemberType}; +use crate::{endgroup, group}; use std::env; +use std::path::Path; use std::process::{Child, Command, Stdio}; use std::str; use std::time::Instant; -fn get_package_dirs(glob_str: &str) -> Vec { - glob(glob_str) - .expect("Failed to read glob pattern") - .filter_map(|dir| { - let dir = dir.unwrap(); - let dir = dir.as_path().to_str().unwrap(); - if dir == "Cargo.toml" { - None - } else { - let length = dir.len() - "/Cargo.toml".len(); - Some(dir.to_string()[..length].to_string()) - } - }) - .collect() -} - // Targets constants const WASM32_TARGET: &str = "wasm32-unknown-unknown"; const ARM_TARGET: &str = "thumbv7m-none-eabi"; @@ -46,7 +33,7 @@ fn handle_child_process(mut child: Child, error: &str) { // Run a command fn run_command(command: &str, args: &[&str], command_error: &str, child_error: &str) { // Format command - println!("{command} {}\n\n", args.join(" ")); + info!("{command} {}\n\n", args.join(" ")); // Run command as child process let command = Command::new(command) @@ -62,129 +49,94 @@ fn run_command(command: &str, args: &[&str], command_error: &str, child_error: & // Define and run rustup command fn rustup(command: &str, target: &str) { + group!("Rustup: {} add {}", command, target); run_command( "rustup", &[command, "add", target], "Failed to run rustup", "Failed to wait for rustup child process", - ) + ); + endgroup!(); } // Define and run a cargo command -fn run_cargo(command: &str, params: Params, error: &str, exclude: Vec<&str>) { - // Separate out packages to prevent feature unification https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2:~:text=When%20building%20multiple,separate%20cargo%20invocations. - let package_dirs = if params.params.contains(&String::from("-p")) { - vec![String::from(".")] - } else { - get_package_dirs("./**/Cargo.toml") - .iter() - .filter(|p| !exclude.contains(&p.as_str())) - .map(String::from) - .collect::>() - }; - - for package_dir in package_dirs { - let nightly = if package_dir == "examples/train-web/train" { - vec!["+nightly-2023-07-01"] - } else { - vec![] - }; - // Print cargo command - println!( - "\n{} $ cargo{} {} {}\n", - package_dir, - if nightly.is_empty() { - "".to_string() - } else { - " ".to_string() + nightly[0] - }, - command, - params - ); - // Run cargo - let cargo = Command::new("cargo") - .env("CARGO_INCREMENTAL", "0") - .current_dir(package_dir) - .args(nightly) - .arg(command) - .args(¶ms.params) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect(error); - - // Handle cargo child process - handle_child_process(cargo, "Failed to wait for cargo child process"); +fn run_cargo(command: &str, params: Params, error: &str) { + run_cargo_with_path::(command, params, None, error) +} + +// Define and run a cargo command with curr dir +fn run_cargo_with_path>( + command: &str, + params: Params, + path: Option

, + error: &str, +) { + // Print cargo command + info!("cargo {} {}\n", command, params); + + // Run cargo + let mut cargo = Command::new("cargo"); + cargo + .env("CARGO_INCREMENTAL", "0") + .arg(command) + .args(params.params) + .stdout(Stdio::inherit()) // Send stdout directly to terminal + .stderr(Stdio::inherit()); // Send stderr directly to terminal + + if let Some(path) = path { + cargo.current_dir(path); } + + let cargo_process = cargo.spawn().expect(error); + + // Handle cargo child process + handle_child_process(cargo_process, "Failed to wait for cargo child process"); } // Run cargo build command -fn cargo_build_excluding(params: Params, exclude: Vec<&str>) { +fn cargo_build(params: Params) { // Run cargo build run_cargo( "build", params + "--color=always", "Failed to run cargo build", - exclude, ); } -// Run cargo build command -fn cargo_build(params: Params) { - cargo_build_excluding(params, [].into()) -} - // Run cargo install command -fn cargo_install_excluding(params: Params, exclude: Vec<&str>) { +fn cargo_install(params: Params) { // Run cargo install run_cargo( "install", params + "--color=always", "Failed to run cargo install", - exclude, ); } -// Run cargo install command -fn cargo_install(params: Params) { - cargo_install_excluding(params, [].into()) -} - // Run cargo test command -fn cargo_test_excluding(params: Params, exclude: Vec<&str>) { +fn cargo_test(params: Params) { // Run cargo test run_cargo( "test", params + "--color=always" + "--" + "--color=always", "Failed to run cargo test", - exclude, ); } -// Run cargo test command -fn cargo_test(params: Params) { - cargo_test_excluding(params, [].into()) -} - // Run cargo fmt command -fn cargo_fmt_excluding(exclude: Vec<&str>) { - // Run cargo fmt +fn cargo_fmt() { + group!("Cargo: fmt"); run_cargo( "fmt", ["--check", "--all", "--", "--color=always"].into(), "Failed to run cargo fmt", - exclude, ); -} - -// Run cargo fmt command -fn cargo_fmt() { - cargo_fmt_excluding([].into()) + endgroup!(); } // Run cargo clippy command -fn cargo_clippy_excluding(exclude: Vec<&str>) { - if std::env::var("CI_RUN").is_ok() { +fn cargo_clippy() { + if std::env::var("CI").is_ok() { return; } // Run cargo clippy @@ -192,28 +144,18 @@ fn cargo_clippy_excluding(exclude: Vec<&str>) { "clippy", ["--color=always", "--all-targets", "--", "-D", "warnings"].into(), "Failed to run cargo clippy", - exclude, ); } // Run cargo doc command -fn cargo_doc_excluding(params: Params, exclude: Vec<&str>) { - // Run cargo doc - run_cargo( - "doc", - params + "--no-deps" + "--color=always", - "Failed to run cargo doc", - exclude, - ); -} - fn cargo_doc(params: Params) { - cargo_doc_excluding(params, [].into()) + // Run cargo doc + run_cargo("doc", params + "--color=always", "Failed to run cargo doc"); } // Build and test a crate in a no_std environment fn build_and_test_no_std(crate_name: &str, extra_args: [&str; N]) { - println!("\nRun checks for `{}` crate", crate_name); + group!("Checks: {} (no-std)", crate_name); // Run cargo build --no-default-features cargo_build(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); @@ -242,6 +184,8 @@ fn build_and_test_no_std(crate_name: &str, extra_args: [&str; N] ARM_TARGET, ]) + extra_args, ); + + endgroup!(); } // Setup code coverage @@ -281,8 +225,6 @@ fn run_grcov() { // Run no_std checks fn no_std_checks() { - println!("Checks for no_std environment...\n\n"); - // Install wasm32 target rustup("target", WASM32_TARGET); @@ -304,20 +246,22 @@ fn no_std_checks() { // Test burn-core with tch and wgpu backend fn burn_core_std() { - println!("\n\nRun checks for burn-core crate with tch and wgpu backend"); - // Run cargo test --features test-tch + group!("Test: burn-core (tch)"); cargo_test(["-p", "burn-core", "--features", "test-tch"].into()); + endgroup!(); // Run cargo test --features test-wgpu if std::env::var("DISABLE_WGPU").is_err() { + group!("Test: burn-core (wgpu)"); cargo_test(["-p", "burn-core", "--features", "test-wgpu"].into()); + endgroup!(); } } // Test burn-dataset features fn burn_dataset_features_std() { - println!("\n\nRun checks for burn-dataset features"); + group!("Checks: burn-dataset (all-features)"); // Run cargo build --all-features cargo_build(["-p", "burn-dataset", "--all-features"].into()); @@ -327,13 +271,17 @@ fn burn_dataset_features_std() { // Run cargo doc --all-features cargo_doc(["-p", "burn-dataset", "--all-features"].into()); + + endgroup!(); } // Test burn-candle with accelerate (macOS only) // Leverages the macOS Accelerate framework: https://developer.apple.com/documentation/accelerate #[cfg(target_os = "macos")] fn burn_candle_accelerate() { + group!("Checks: burn-candle (accelerate)"); cargo_test(["-p", "burn-candle", "--features", "accelerate"].into()); + endgroup!(); } fn std_checks() { @@ -345,40 +293,38 @@ fn std_checks() { let is_coverage = std::env::var("COVERAGE").is_ok(); let disable_wgpu = std::env::var("DISABLE_WGPU").is_ok(); - println!("Running std checks"); - // Check format cargo_fmt(); // Check clippy lints - cargo_clippy_excluding( - [ - "burn-no-std-tests", - "examples/image-classification-web", - "examples/mnist-inference-web", - "examples/train-web/train", - ] - .into(), - ); - - // Build each package - let exclude = if disable_wgpu { - vec!["burn-wgpu"] - } else { - vec![] - }; - cargo_build_excluding([].into(), exclude); + cargo_clippy(); - // Produce documentation for each package - cargo_doc_excluding([].into(), ["burn-no-std-tests"].into()); + // Produce documentation for each workspace + group!("Docs: workspaces"); + cargo_doc(["--workspace"].into()); + endgroup!(); // Setup code coverage if is_coverage { setup_coverage(); } - // Test each package - cargo_test_excluding([].into(), ["examples/train-web/train"].into()); + // Build & test each workspace + let workspaces = get_workspaces(WorkspaceMemberType::Crate); + for workspace in workspaces { + if disable_wgpu && workspace.name == "burn-wgpu" { + continue; + } + + if workspace.name == "burn-tch" { + continue; + } + + group!("Checks: {}", workspace.name); + cargo_build(Params::from(["-p", &workspace.name])); + cargo_test(Params::from(["-p", &workspace.name])); + endgroup!(); + } // Test burn-candle with accelerate (macOS only) #[cfg(target_os = "macos")] @@ -405,12 +351,12 @@ fn check_typos() { // Do not run cargo install on CI to speed up the computation. // Check whether the file has been installed on - if std::env::var("CI_RUN").is_err() && !typos_cli_path.exists() { + if std::env::var("CI").is_err() && !typos_cli_path.exists() { // Install typos-cli cargo_install(["typos-cli", "--version", "1.16.5"].into()); } - println!("Running typos check \n\n"); + info!("Running typos check \n\n"); // Run typos command as child process let typos = Command::new("typos") @@ -424,25 +370,21 @@ fn check_typos() { } fn check_examples() { - println!("Checking examples compile \n\n"); - - get_package_dirs("./examples/**/Cargo.toml") - .iter() - .for_each(|dir| { - println!("Checking {dir} \n\n"); - - let child = Command::new("cargo") - .arg("check") - .arg("--examples") - .current_dir(dir) - .stdout(Stdio::inherit()) // Send stdout directly to terminal - .stderr(Stdio::inherit()) // Send stderr directly to terminal - .spawn() - .expect("Failed to check examples"); - - // Handle typos child process - handle_child_process(child, "Failed to wait for examples child process"); - }); + let workspaces = get_workspaces(WorkspaceMemberType::Example); + for workspace in workspaces { + if workspace.name == "notebook" { + continue; + } + + group!("Checks: Example - {}", workspace.name); + run_cargo_with_path( + "check", + ["--examples"].into(), + Some(workspace.path), + "Failed to check example", + ); + endgroup!(); + } } #[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)] @@ -461,6 +403,9 @@ pub enum CheckType { } pub fn run(env: CheckType) -> anyhow::Result<()> { + // Setup logger + init_logger().init(); + // Start time measurement let start = Instant::now(); @@ -491,7 +436,10 @@ pub fn run(env: CheckType) -> anyhow::Result<()> { let duration = start.elapsed(); // Print duration - println!("Time elapsed for the current execution: {:?}", duration); + info!( + "\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m", + format_duration(&duration) + ); Ok(()) } From 15af6fd9fbcfacd7dc78a2daf62eb261f4ee4242 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 16:37:38 -0600 Subject: [PATCH 57/79] fix ci --- examples/train-web/train/Cargo.toml | 2 +- xtask/src/runchecks.rs | 70 ++++++++++++++++++++++++----- xtask/src/utils.rs | 28 +++++------- 3 files changed, 73 insertions(+), 27 deletions(-) diff --git a/examples/train-web/train/Cargo.toml b/examples/train-web/train/Cargo.toml index cfbdb78a46..456dd569a8 100644 --- a/examples/train-web/train/Cargo.toml +++ b/examples/train-web/train/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "train" +name = "train-web" version = "0.1.0" edition = "2021" diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 89c2249fd7..9558cbf6e0 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -71,13 +71,41 @@ fn run_cargo_with_path>( path: Option

, error: &str, ) { + let nightly = match path.as_ref() { + None => vec![], + Some(p) => { + if p.as_ref() + .to_str() + .unwrap() + .ends_with("examples/train-web/train") + { + vec!["+nightly-2023-07-01"] + } else { + vec![] + } + } + }; // Print cargo command - info!("cargo {} {}\n", command, params); + info!( + "{}cargo{} {} {}\n", + match path.as_ref() { + None => "".to_string(), + Some(p) => p.as_ref().to_str().unwrap().to_string() + " $ ", + }, + if nightly.is_empty() { + "".to_string() + } else { + " ".to_string() + nightly[0] + }, + command, + params + ); // Run cargo let mut cargo = Command::new("cargo"); cargo .env("CARGO_INCREMENTAL", "0") + .args(nightly) .arg(command) .args(params.params) .stdout(Stdio::inherit()) // Send stdout directly to terminal @@ -139,18 +167,39 @@ fn cargo_clippy() { if std::env::var("CI").is_ok() { return; } - // Run cargo clippy - run_cargo( - "clippy", - ["--color=always", "--all-targets", "--", "-D", "warnings"].into(), - "Failed to run cargo clippy", - ); + for workspace in get_workspaces(WorkspaceMemberType::Both) { + if vec![ + "burn-no-std-tests", + "image-classification-web", + "mnist-inference-web", + "train-web", + ] + .into_iter() + .collect::() + .contains(&workspace.name) + { + continue; + } + // Run cargo clippy + run_cargo_with_path( + "clippy", + ["--color=always", "--all-targets", "--", "-D", "warnings"].into(), + Some(workspace.path), + "Failed to run cargo clippy", + ); + } } // Run cargo doc command fn cargo_doc(params: Params) { - // Run cargo doc - run_cargo("doc", params + "--color=always", "Failed to run cargo doc"); + for workspace in get_workspaces(WorkspaceMemberType::Both) { + run_cargo_with_path( + "doc", + (params.clone() + ["--no-deps", "--color=always"]).into(), + Some(workspace.path), + "Failed to run cargo doc", + ); + } } // Build and test a crate in a no_std environment @@ -301,7 +350,7 @@ fn std_checks() { // Produce documentation for each workspace group!("Docs: workspaces"); - cargo_doc(["--workspace"].into()); + cargo_doc([].into()); endgroup!(); // Setup code coverage @@ -444,6 +493,7 @@ pub fn run(env: CheckType) -> anyhow::Result<()> { Ok(()) } +#[derive(Clone)] struct Params { params: Vec, } diff --git a/xtask/src/utils.rs b/xtask/src/utils.rs index 107852d678..6dcb709d68 100644 --- a/xtask/src/utils.rs +++ b/xtask/src/utils.rs @@ -4,6 +4,7 @@ use std::{process::Command, time::Duration}; pub(crate) enum WorkspaceMemberType { Crate, Example, + Both, } #[derive(Debug)] @@ -39,25 +40,20 @@ pub(crate) fn get_workspaces(w_type: WorkspaceMemberType) -> Vec - { - Some(WorkspaceMember::new( - workspace_name.to_string(), - workspace_path.to_string(), - )) - } - WorkspaceMemberType::Example - if workspace_name != "xtask" && workspace_path.contains("examples/") => - { - Some(WorkspaceMember::new( - workspace_name.to_string(), - workspace_path.to_string(), - )) - } + WorkspaceMemberType::Crate if !workspace_path.contains("examples/") => member, + WorkspaceMemberType::Example if workspace_path.contains("examples/") => member, + WorkspaceMemberType::Both => member, _ => None, } }) From 49e2c3849c03f90f17eb53a5007a4672f199e54a Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 17:19:59 -0600 Subject: [PATCH 58/79] rustup component add rust-src --toolchain nightly-2023-07-01-x86_64-unknown-linux-gnu --- xtask/src/runchecks.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 9558cbf6e0..060e0165c2 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -48,11 +48,11 @@ fn run_command(command: &str, args: &[&str], command_error: &str, child_error: & } // Define and run rustup command -fn rustup(command: &str, target: &str) { - group!("Rustup: {} add {}", command, target); +fn rustup(args: &[&str]) { + group!("rustup {}", args.join(" ")); run_command( "rustup", - &[command, "add", target], + args, "Failed to run rustup", "Failed to wait for rustup child process", ); @@ -240,7 +240,7 @@ fn build_and_test_no_std(crate_name: &str, extra_args: [&str; N] // Setup code coverage fn setup_coverage() { // Install llvm-tools-preview - rustup("component", "llvm-tools-preview"); + rustup(&["component", "add", "llvm-tools-preview"]); // Set coverage environment variables env::set_var("RUSTFLAGS", "-Cinstrument-coverage"); @@ -275,10 +275,10 @@ fn run_grcov() { // Run no_std checks fn no_std_checks() { // Install wasm32 target - rustup("target", WASM32_TARGET); + rustup(&["target", "add", WASM32_TARGET]); // Install ARM target - rustup("target", ARM_TARGET); + rustup(&["target", "add", ARM_TARGET]); // Run checks for the following crates build_and_test_no_std("burn", []); @@ -419,6 +419,13 @@ fn check_typos() { } fn check_examples() { + rustup(&[ + "component", + "add", + "rust-src", + "--toolchain", + "nightly-2023-07-01-x86_64-unknown-linux-gnu", + ]); let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { if workspace.name == "notebook" { From 519b7520e6c8bf9caab4898ff1d5d87505592293 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 17:47:54 -0600 Subject: [PATCH 59/79] 0 --- xtask/src/runchecks.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 060e0165c2..40eec8ccd7 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -419,13 +419,7 @@ fn check_typos() { } fn check_examples() { - rustup(&[ - "component", - "add", - "rust-src", - "--toolchain", - "nightly-2023-07-01-x86_64-unknown-linux-gnu", - ]); + rustup(&["default", "nightly-2023-07-01"]); let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { if workspace.name == "notebook" { From 6d8c7213cb16afcce4c9f03afbec3ae563f3e71c Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 17:47:54 -0600 Subject: [PATCH 60/79] Revert "0" This reverts commit 519b7520e6c8bf9caab4898ff1d5d87505592293. --- xtask/src/runchecks.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 40eec8ccd7..060e0165c2 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -419,7 +419,13 @@ fn check_typos() { } fn check_examples() { - rustup(&["default", "nightly-2023-07-01"]); + rustup(&[ + "component", + "add", + "rust-src", + "--toolchain", + "nightly-2023-07-01-x86_64-unknown-linux-gnu", + ]); let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { if workspace.name == "notebook" { From c025d358cf464ff332d8d484d73c891d3b718580 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 17:19:59 -0600 Subject: [PATCH 61/79] Revert "rustup component add rust-src --toolchain nightly-2023-07-01-x86_64-unknown-linux-gnu" This reverts commit 49e2c3849c03f90f17eb53a5007a4672f199e54a. --- xtask/src/runchecks.rs | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 060e0165c2..9558cbf6e0 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -48,11 +48,11 @@ fn run_command(command: &str, args: &[&str], command_error: &str, child_error: & } // Define and run rustup command -fn rustup(args: &[&str]) { - group!("rustup {}", args.join(" ")); +fn rustup(command: &str, target: &str) { + group!("Rustup: {} add {}", command, target); run_command( "rustup", - args, + &[command, "add", target], "Failed to run rustup", "Failed to wait for rustup child process", ); @@ -240,7 +240,7 @@ fn build_and_test_no_std(crate_name: &str, extra_args: [&str; N] // Setup code coverage fn setup_coverage() { // Install llvm-tools-preview - rustup(&["component", "add", "llvm-tools-preview"]); + rustup("component", "llvm-tools-preview"); // Set coverage environment variables env::set_var("RUSTFLAGS", "-Cinstrument-coverage"); @@ -275,10 +275,10 @@ fn run_grcov() { // Run no_std checks fn no_std_checks() { // Install wasm32 target - rustup(&["target", "add", WASM32_TARGET]); + rustup("target", WASM32_TARGET); // Install ARM target - rustup(&["target", "add", ARM_TARGET]); + rustup("target", ARM_TARGET); // Run checks for the following crates build_and_test_no_std("burn", []); @@ -419,13 +419,6 @@ fn check_typos() { } fn check_examples() { - rustup(&[ - "component", - "add", - "rust-src", - "--toolchain", - "nightly-2023-07-01-x86_64-unknown-linux-gnu", - ]); let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { if workspace.name == "notebook" { From b0fbeefbe8660fbee89247757b9d172642fe404a Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 18:07:26 -0600 Subject: [PATCH 62/79] =?UTF-8?q?can=20I=20just=20get=20a=20=E2=9C=94=20pl?= =?UTF-8?q?ease?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- xtask/src/runchecks.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 9558cbf6e0..259d7203cd 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -421,7 +421,7 @@ fn check_typos() { fn check_examples() { let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { - if workspace.name == "notebook" { + if workspace.name == "notebook" || workspace.name == "train-web" { continue; } From 57472fe9aaba2bf976a58a5644556f3e81e08719 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 17:19:59 -0600 Subject: [PATCH 63/79] Revert "Revert "rustup component add rust-src --toolchain nightly-2023-07-01-x86_64-unknown-linux-gnu"" This reverts commit c025d358cf464ff332d8d484d73c891d3b718580. --- xtask/src/runchecks.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 259d7203cd..aa98564be0 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -48,11 +48,11 @@ fn run_command(command: &str, args: &[&str], command_error: &str, child_error: & } // Define and run rustup command -fn rustup(command: &str, target: &str) { - group!("Rustup: {} add {}", command, target); +fn rustup(args: &[&str]) { + group!("rustup {}", args.join(" ")); run_command( "rustup", - &[command, "add", target], + args, "Failed to run rustup", "Failed to wait for rustup child process", ); @@ -240,7 +240,7 @@ fn build_and_test_no_std(crate_name: &str, extra_args: [&str; N] // Setup code coverage fn setup_coverage() { // Install llvm-tools-preview - rustup("component", "llvm-tools-preview"); + rustup(&["component", "add", "llvm-tools-preview"]); // Set coverage environment variables env::set_var("RUSTFLAGS", "-Cinstrument-coverage"); @@ -275,10 +275,10 @@ fn run_grcov() { // Run no_std checks fn no_std_checks() { // Install wasm32 target - rustup("target", WASM32_TARGET); + rustup(&["target", "add", WASM32_TARGET]); // Install ARM target - rustup("target", ARM_TARGET); + rustup(&["target", "add", ARM_TARGET]); // Run checks for the following crates build_and_test_no_std("burn", []); @@ -419,6 +419,13 @@ fn check_typos() { } fn check_examples() { + rustup(&[ + "component", + "add", + "rust-src", + "--toolchain", + "nightly-2023-07-01-x86_64-unknown-linux-gnu", + ]); let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { if workspace.name == "notebook" || workspace.name == "train-web" { From 8fa5a0396501a589e992441b2144630415c7e590 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 19:29:28 -0600 Subject: [PATCH 64/79] not always linux --- xtask/src/runchecks.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index aa98564be0..d46c06f0f6 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -419,13 +419,7 @@ fn check_typos() { } fn check_examples() { - rustup(&[ - "component", - "add", - "rust-src", - "--toolchain", - "nightly-2023-07-01-x86_64-unknown-linux-gnu", - ]); + rustup(&["+nightly-2023-07-01", "component", "add", "rust-src"]); let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { if workspace.name == "notebook" || workspace.name == "train-web" { From 1b6d619df99c73e38d118c673ded269b4f150b2c Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 19:50:54 -0600 Subject: [PATCH 65/79] 1 --- xtask/src/runchecks.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index d46c06f0f6..d456202fdc 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -79,6 +79,7 @@ fn run_cargo_with_path>( .unwrap() .ends_with("examples/train-web/train") { + rustup(&["+nightly-2023-07-01", "component", "add", "rust-src"]); vec!["+nightly-2023-07-01"] } else { vec![] From c8575e368a316cd296c80cb2d4c5f8436ffe867c Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 20:10:19 -0600 Subject: [PATCH 66/79] 2 --- xtask/src/runchecks.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index d456202fdc..4c383dcc5a 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -193,7 +193,7 @@ fn cargo_clippy() { // Run cargo doc command fn cargo_doc(params: Params) { - for workspace in get_workspaces(WorkspaceMemberType::Both) { + for workspace in get_workspaces(WorkspaceMemberType::Crate) { run_cargo_with_path( "doc", (params.clone() + ["--no-deps", "--color=always"]).into(), From de55ae8edf4573a27cbe40813248ea94c6a4d3b6 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 4 Dec 2023 18:07:26 -0600 Subject: [PATCH 67/79] =?UTF-8?q?Revert=20"can=20I=20just=20get=20a=20?= =?UTF-8?q?=E2=9C=94=20please"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit b0fbeefbe8660fbee89247757b9d172642fe404a. --- xtask/src/runchecks.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 4c383dcc5a..12632219fc 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -423,7 +423,7 @@ fn check_examples() { rustup(&["+nightly-2023-07-01", "component", "add", "rust-src"]); let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { - if workspace.name == "notebook" || workspace.name == "train-web" { + if workspace.name == "notebook" { continue; } From bebab5eaba213876c5de1fe1e2c078372b1850bc Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 5 Dec 2023 11:31:47 -0600 Subject: [PATCH 68/79] 3 --- xtask/src/runchecks.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 12632219fc..c57c52f673 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -430,7 +430,7 @@ fn check_examples() { group!("Checks: Example - {}", workspace.name); run_cargo_with_path( "check", - ["--examples"].into(), + [].into(), Some(workspace.path), "Failed to check example", ); From 64a11b5144e682cac3b35e5bf36bf0ff59a4e789 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 5 Dec 2023 11:50:43 -0600 Subject: [PATCH 69/79] 4 --- xtask/src/runchecks.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index c57c52f673..6efe16f0ef 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -421,6 +421,13 @@ fn check_typos() { fn check_examples() { rustup(&["+nightly-2023-07-01", "component", "add", "rust-src"]); + rustup(&[ + "component", + "add", + "rust-src", + "--toolchain", + "nightly-2023-07-01", + ]); let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { if workspace.name == "notebook" { From 944e53ad9772666950af0dbac46b9b53154188f5 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 5 Dec 2023 12:09:02 -0600 Subject: [PATCH 70/79] 5 --- xtask/src/runchecks.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 6efe16f0ef..6abc7a9b5f 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -79,7 +79,13 @@ fn run_cargo_with_path>( .unwrap() .ends_with("examples/train-web/train") { - rustup(&["+nightly-2023-07-01", "component", "add", "rust-src"]); + rustup(&[ + "component", + "add", + "rust-src", + "--toolchain", + "nightly-2023-07-01", + ]); vec!["+nightly-2023-07-01"] } else { vec![] @@ -420,7 +426,6 @@ fn check_typos() { } fn check_examples() { - rustup(&["+nightly-2023-07-01", "component", "add", "rust-src"]); rustup(&[ "component", "add", From 1bddd635c50a5c97447ba61bae5d616defe6ce91 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 5 Dec 2023 12:30:55 -0600 Subject: [PATCH 71/79] 6 --- xtask/src/runchecks.rs | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 6abc7a9b5f..47db57f2fe 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -59,6 +59,16 @@ fn rustup(args: &[&str]) { endgroup!(); } +fn add_nightly() { + rustup(&[ + "component", + "add", + "rust-src", + "--toolchain", + "nightly-2023-07-01", + ]) +} + // Define and run a cargo command fn run_cargo(command: &str, params: Params, error: &str) { run_cargo_with_path::(command, params, None, error) @@ -79,13 +89,7 @@ fn run_cargo_with_path>( .unwrap() .ends_with("examples/train-web/train") { - rustup(&[ - "component", - "add", - "rust-src", - "--toolchain", - "nightly-2023-07-01", - ]); + add_nightly(); vec!["+nightly-2023-07-01"] } else { vec![] @@ -426,13 +430,7 @@ fn check_typos() { } fn check_examples() { - rustup(&[ - "component", - "add", - "rust-src", - "--toolchain", - "nightly-2023-07-01", - ]); + add_nightly(); let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { if workspace.name == "notebook" { From 218e17d7d3a97801dc1f56c9cadb4e9950e94936 Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 23 Dec 2023 20:09:05 -0600 Subject: [PATCH 72/79] fix breaking changes --- examples/train-web/train/src/data.rs | 4 ++-- examples/train-web/train/src/model.rs | 10 +++++----- examples/train-web/train/src/train.rs | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/train-web/train/src/data.rs b/examples/train-web/train/src/data.rs index 3613940c1e..cc6d87e5cb 100644 --- a/examples/train-web/train/src/data.rs +++ b/examples/train-web/train/src/data.rs @@ -25,7 +25,7 @@ impl Batcher> for MNISTBatcher { let images = items .iter() .map(|item| Data::::from(item.image)) - .map(|data| Tensor::::from_data(data.convert())) + .map(|data| Tensor::::from_data_devauto(data.convert())) .map(|tensor| tensor.reshape([1, 28, 28])) // normalize: make between [0,1] and make the mean = 0 and std = 1 // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example @@ -35,7 +35,7 @@ impl Batcher> for MNISTBatcher { let targets = items .iter() - .map(|item| Tensor::::from_data([(item.label as i64).elem()])) + .map(|item| Tensor::::from_data_devauto([(item.label as i64).elem()])) .collect(); let images = Tensor::cat(images, 0).to_device(&self.device); diff --git a/examples/train-web/train/src/model.rs b/examples/train-web/train/src/model.rs index cce581157d..f7bcc96d9a 100644 --- a/examples/train-web/train/src/model.rs +++ b/examples/train-web/train/src/model.rs @@ -30,14 +30,14 @@ pub struct ModelConfig { impl ModelConfig { /// Returns the initialized model. - pub fn init(&self) -> Model { + pub fn init(&self, device: &B::Device) -> Model { Model { - conv1: Conv2dConfig::new([1, 8], [3, 3]).init(), - conv2: Conv2dConfig::new([8, 16], [3, 3]).init(), + conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device), + conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device), pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), activation: ReLU::new(), - linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(), - linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(), + linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device), + linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device), dropout: DropoutConfig::new(self.dropout).init(), } } diff --git a/examples/train-web/train/src/train.rs b/examples/train-web/train/src/train.rs index f1f84efebf..4ab5dc9c9d 100644 --- a/examples/train-web/train/src/train.rs +++ b/examples/train-web/train/src/train.rs @@ -127,12 +127,12 @@ pub fn train( .metric_valid_numeric(LossMetric::new()) // .with_file_checkpointer(CompactRecorder::new()) .log_to_file(false) - .devices(vec![device]) + .devices(vec![device.clone()]) .num_epochs(config.num_epochs) .renderer(CustomRenderer {}) .num_epochs(config.num_epochs) .build( - config.model.init::(), + config.model.init::(&device), config.optimizer.init(), config.learning_rate, ); From b96df97468dd935d54b2acfbb4726dd8649ed1f3 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 24 Dec 2023 10:18:09 -0600 Subject: [PATCH 73/79] make windows work? --- xtask/src/runchecks.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index a23e5cdd7c..7d210a8163 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -112,11 +112,13 @@ fn run_cargo_with_path>( params ); - // Run cargo - let mut cargo = Command::new("cargo"); + // Run cargo through rustup + let mut cargo = Command::new("rustup"); cargo .env("CARGO_INCREMENTAL", "0") + .arg("run") .args(nightly) + .arg("cargo") .arg(command) .args(params.params) .stdout(Stdio::inherit()) // Send stdout directly to terminal From d3cabca7e37c1d3b270cfba2c3fb12f039e38619 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 28 Jan 2024 16:49:03 -0600 Subject: [PATCH 74/79] reject CI changes --- .github/workflows/test.yml | 6 ------ xtask/src/runchecks.rs | 1 - 2 files changed, 7 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 12e6f9655a..2ada121e21 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -93,12 +93,6 @@ jobs: - name: checkout uses: actions/checkout@v4 - - name: install nightly rust - uses: dtolnay/rust-toolchain@master - with: - components: rustfmt, clippy - toolchain: nightly-2023-07-01 - - name: install rust uses: dtolnay/rust-toolchain@master with: diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 788ed6f978..36e566fbb5 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -381,7 +381,6 @@ fn check_typos() { } fn check_examples() { - add_nightly(); let workspaces = get_workspaces(WorkspaceMemberType::Example); for workspace in workspaces { if workspace.name == "notebook" { From 2d0a4887116465ecc130e11db1415e1043a66112 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 28 Jan 2024 16:58:25 -0600 Subject: [PATCH 75/79] less CI changes --- xtask/src/utils/workspace.rs | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/xtask/src/utils/workspace.rs b/xtask/src/utils/workspace.rs index a469689500..234bb8333f 100644 --- a/xtask/src/utils/workspace.rs +++ b/xtask/src/utils/workspace.rs @@ -5,7 +5,6 @@ use serde_json::Value; pub(crate) enum WorkspaceMemberType { Crate, Example, - Both, } #[derive(Debug)] @@ -41,10 +40,6 @@ pub(crate) fn get_workspaces(w_type: WorkspaceMemberType) -> Vec Vec member, - WorkspaceMemberType::Example if workspace_path.contains("examples/") => member, - WorkspaceMemberType::Both => member, + WorkspaceMemberType::Crate + if workspace_name != "xtask" && !workspace_path.contains("examples/") => + { + Some(WorkspaceMember::new( + workspace_name.to_string(), + workspace_path.to_string(), + )) + } + WorkspaceMemberType::Example + if workspace_name != "xtask" && workspace_path.contains("examples/") => + { + Some(WorkspaceMember::new( + workspace_name.to_string(), + workspace_path.to_string(), + )) + } _ => None, } }) From f053b4dbc797be4249a692d0fbb094a1e7a2c615 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 28 Jan 2024 18:58:38 -0600 Subject: [PATCH 76/79] bump wasm-bindgen --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 96685859db..9c5ad0aefa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,7 +86,7 @@ tokio = { version = "1.35.1", features = ["rt", "macros"] } tracing-appender = "0.2.3" tracing-core = "0.1.32" tracing-subscriber = "0.3.18" -wasm-bindgen = "=0.2.89" +wasm-bindgen = "=0.2.90" wasm-bindgen-futures = "0.4.38" wasm-logger = "0.2.0" wasm-timer = "0.2.5" From 116d3f53912e39479c843a6323ee76a2d7c0486c Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 28 Jan 2024 19:15:41 -0600 Subject: [PATCH 77/79] add generic --- burn-core/src/module/base.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/burn-core/src/module/base.rs b/burn-core/src/module/base.rs index 562a3012b9..16678fe8c6 100644 --- a/burn-core/src/module/base.rs +++ b/burn-core/src/module/base.rs @@ -178,7 +178,7 @@ pub trait Module: Clone + Send + Sync + core::fmt::Debug { } /// Return the module using [BytesRecorder](crate::record::BytesRecorder). - fn to_bytes( + fn to_bytes>( self, recorder: &BR, ) -> Result, crate::record::RecorderError> { From 3b53c3247501cdd9f70f91769b7d54ccbe94111d Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 29 Jan 2024 12:50:04 -0600 Subject: [PATCH 78/79] it builds --- examples/train-web/train/src/data.rs | 13 +++++++++---- examples/train-web/train/src/train.rs | 6 ++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/train-web/train/src/data.rs b/examples/train-web/train/src/data.rs index cc6d87e5cb..49f617bfcb 100644 --- a/examples/train-web/train/src/data.rs +++ b/examples/train-web/train/src/data.rs @@ -25,7 +25,7 @@ impl Batcher> for MNISTBatcher { let images = items .iter() .map(|item| Data::::from(item.image)) - .map(|data| Tensor::::from_data_devauto(data.convert())) + .map(|data| Tensor::::from_data(data.convert(), &self.device)) .map(|tensor| tensor.reshape([1, 28, 28])) // normalize: make between [0,1] and make the mean = 0 and std = 1 // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example @@ -35,11 +35,16 @@ impl Batcher> for MNISTBatcher { let targets = items .iter() - .map(|item| Tensor::::from_data_devauto([(item.label as i64).elem()])) + .map(|item| { + Tensor::::from_data( + Data::from([(item.label as i64).elem()]), + &self.device, + ) + }) .collect(); - let images = Tensor::cat(images, 0).to_device(&self.device); - let targets = Tensor::cat(targets, 0).to_device(&self.device); + let images = Tensor::cat(images, 0); + let targets = Tensor::cat(targets, 0); MNISTBatch { images, targets } } diff --git a/examples/train-web/train/src/train.rs b/examples/train-web/train/src/train.rs index 4ab5dc9c9d..77fbfd082f 100644 --- a/examples/train-web/train/src/train.rs +++ b/examples/train-web/train/src/train.rs @@ -8,7 +8,7 @@ use burn::{ config::Config, data::dataloader::DataLoaderBuilder, module::Module, - nn::loss::CrossEntropyLoss, + nn::loss::CrossEntropyLossConfig, optim::AdamConfig, record::{BinBytesRecorder, FullPrecisionSettings}, tensor::{ @@ -30,7 +30,9 @@ impl Model { targets: Tensor, ) -> ClassificationOutput { let output = self.forward(images); - let loss = CrossEntropyLoss::default().forward(output.clone(), targets.clone()); + let loss = CrossEntropyLossConfig::new() + .init(&output.device()) + .forward(output.clone(), targets.clone()); ClassificationOutput::new(loss, output, targets) } From a527fae7a6fd95ffb819dfd4c824763972e15a9a Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 29 Jan 2024 13:02:12 -0600 Subject: [PATCH 79/79] add AsyncTask(Boxed) --- burn-train/src/metric/store/client.rs | 7 ++++--- burn-train/src/util.rs | 30 +++++++++++++++++++++++---- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/burn-train/src/metric/store/client.rs b/burn-train/src/metric/store/client.rs index b018456caa..95a0b9cc95 100644 --- a/burn-train/src/metric/store/client.rs +++ b/burn-train/src/metric/store/client.rs @@ -1,12 +1,13 @@ use super::EventStore; use super::{Aggregate, Direction, Event, Split}; -use crate::util; +use crate::util::{self, AsyncTaskBoxed}; +use log::info; use std::sync::mpsc; /// Type that allows to communicate with an [event store](EventStore). pub struct EventStoreClient { sender: mpsc::Sender, - handler: Option Result<(), ()>>>, + handler: Option, } impl EventStoreClient { @@ -154,7 +155,7 @@ impl Drop for EventStoreClient { let handler = self.handler.take(); if let Some(handler) = handler { - handler().expect("The event store thread should stop."); + handler.join().expect("The event store thread should stop."); } } } diff --git a/burn-train/src/util.rs b/burn-train/src/util.rs index 4ecfeb82db..88847e3eb7 100644 --- a/burn-train/src/util.rs +++ b/burn-train/src/util.rs @@ -1,23 +1,45 @@ #![allow(missing_docs)] +pub trait AsyncTask { + fn join(self: Box) -> Result<(), ()>; +} + +pub type AsyncTaskBoxed = Box; + +struct Thread { + join: Box Result<(), ()>>, +} + +impl AsyncTask for Thread { + fn join(self: Box) -> Result<(), ()> { + (self.join)() + } +} + +impl Thread { + fn new(join: Box Result<(), ()>>) -> Self { + Thread { join } + } +} + #[cfg(not(feature = "browser"))] -pub fn spawn(f: F) -> Box Result<(), ()>> +pub fn spawn(f: F) -> AsyncTaskBoxed where F: FnOnce(), F: Send + 'static, { let handle = std::thread::spawn(f); - Box::new(move || handle.join().map_err(|_| ())) + Box::new(Thread::new(Box::new(move || handle.join().map_err(|_| ())))) } #[cfg(feature = "browser")] -pub fn spawn(f: F) -> Box Result<(), ()>> +pub fn spawn(f: F) -> AsyncTaskBoxed where F: FnOnce(), F: Send + 'static, { rayon::spawn(f); - Box::new(|| Ok(())) + Box::new(Thread::new(Box::new(|| Ok(())))) } #[cfg(feature = "browser")]