Add WebGPU + dtypes docs
xenova committed Oct 22, 2024
1 parent 14d0360 commit 270eee8
Showing 3 changed files with 222 additions and 1 deletion.
6 changes: 5 additions & 1 deletion docs/source/_toctree.yml
```diff
@@ -23,10 +23,14 @@
       title: Server-side Inference in Node.js
   title: Tutorials
 - sections:
+  - local: guides/webgpu
+    title: Running models on WebGPU
+  - local: guides/dtypes
+    title: Using quantized models (dtypes)
   - local: guides/private
     title: Accessing Private/Gated Models
   - local: guides/node-audio-processing
-    title: Server-side Audio Processing in Node.js
+    title: Server-side Audio Processing
   title: Developer Guides
 - sections:
   - local: api/transformers
```
130 changes: 130 additions & 0 deletions docs/source/guides/dtypes.md
# Using quantized models (dtypes)

Before Transformers.js v3, the `quantized` option specified whether to use a quantized (q8) or full-precision (fp32) variant of the model, by setting `quantized` to `true` or `false`, respectively. Now, the `dtype` parameter lets you select from a much larger list of quantizations.

The list of available quantizations depends on the model, but some common ones are: full-precision (`"fp32"`), half-precision (`"fp16"`), 8-bit (`"q8"`, `"int8"`, `"uint8"`), and 4-bit (`"q4"`, `"bnb4"`, `"q4f16"`).
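For example, where v2 code passed `quantized: true`, the v3 equivalent selects the quantization explicitly. A minimal sketch of the migration (the model name here is purely illustrative):

```js
import { pipeline } from "@huggingface/transformers";

// v2 (deprecated): `quantized: true` selected the q8 variant
// const extractor = await pipeline("feature-extraction", "Xenova/all-MiniLM-L6-v2", { quantized: true });

// v3: choose the quantization explicitly via `dtype`
const extractor = await pipeline(
  "feature-extraction",
  "Xenova/all-MiniLM-L6-v2",
  { dtype: "q8" },
);
```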

<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/transformersjs-v3/dtypes-dark.jpg" style="max-width: 100%;">
<source media="(prefers-color-scheme: light)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/transformersjs-v3/dtypes-light.jpg" style="max-width: 100%;">
<img alt="Available dtypes for mixedbread-ai/mxbai-embed-xsmall-v1" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/transformersjs-v3/dtypes-dark.jpg" style="max-width: 100%;">
</picture>
<a href="https://huggingface.co/mixedbread-ai/mxbai-embed-xsmall-v1/tree/main/onnx">(e.g., mixedbread-ai/mxbai-embed-xsmall-v1)</a>
</p>

## Basic usage

**Example:** Run Qwen2.5-0.5B-Instruct in 4-bit quantization ([demo](https://v2.scrimba.com/s0dlcpv0ci))

```js
import { pipeline } from "@huggingface/transformers";

// Create a text generation pipeline
const generator = await pipeline(
  "text-generation",
  "onnx-community/Qwen2.5-0.5B-Instruct",
  { dtype: "q4", device: "webgpu" },
);

// Define the list of messages
const messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Tell me a funny joke." },
];

// Generate a response
const output = await generator(messages, { max_new_tokens: 128 });
console.log(output[0].generated_text.at(-1).content);
```

## Per-module dtypes

Some encoder-decoder models, like Whisper or Florence-2, are extremely sensitive to quantization settings, especially for the encoder. For this reason, we added the ability to select per-module dtypes by providing a mapping from module name to dtype.

**Example:** Run Florence-2 on WebGPU ([demo](https://v2.scrimba.com/s0pdm485fo))

```js
import { Florence2ForConditionalGeneration } from "@huggingface/transformers";

const model = await Florence2ForConditionalGeneration.from_pretrained(
  "onnx-community/Florence-2-base-ft",
  {
    dtype: {
      embed_tokens: "fp16",
      vision_encoder: "fp16",
      encoder_model: "q4",
      decoder_model_merged: "q4",
    },
    device: "webgpu",
  },
);
```

<p align="middle">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/transformersjs-v3/florence-2-webgpu.gif" alt="Florence-2 running on WebGPU" />
</p>

<details>
<summary>
See full code example
</summary>

```js
import {
  Florence2ForConditionalGeneration,
  AutoProcessor,
  AutoTokenizer,
  RawImage,
} from "@huggingface/transformers";

// Load model, processor, and tokenizer
const model_id = "onnx-community/Florence-2-base-ft";
const model = await Florence2ForConditionalGeneration.from_pretrained(
  model_id,
  {
    dtype: {
      embed_tokens: "fp16",
      vision_encoder: "fp16",
      encoder_model: "q4",
      decoder_model_merged: "q4",
    },
    device: "webgpu",
  },
);
const processor = await AutoProcessor.from_pretrained(model_id);
const tokenizer = await AutoTokenizer.from_pretrained(model_id);

// Load image and prepare vision inputs
const url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
const image = await RawImage.fromURL(url);
const vision_inputs = await processor(image);

// Specify task and prepare text inputs
const task = "<MORE_DETAILED_CAPTION>";
const prompts = processor.construct_prompts(task);
const text_inputs = tokenizer(prompts);

// Generate text
const generated_ids = await model.generate({
  ...text_inputs,
  ...vision_inputs,
  max_new_tokens: 100,
});

// Decode generated text
const generated_text = tokenizer.batch_decode(generated_ids, {
  skip_special_tokens: false,
})[0];

// Post-process the generated text
const result = processor.post_process_generation(
  generated_text,
  task,
  image.size,
);
console.log(result);
// { '<MORE_DETAILED_CAPTION>': 'A green car is parked in front of a tan building. The building has a brown door and two brown windows. The car is a two door and the door is closed. The green car has black tires.' }
```

</details>
87 changes: 87 additions & 0 deletions docs/source/guides/webgpu.md
# Running models on WebGPU

WebGPU is a new web standard for accelerated graphics and compute. The [API](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API) enables web developers to use the underlying system's GPU to carry out high-performance computations directly in the browser. WebGPU is the successor to [WebGL](https://developer.mozilla.org/en-US/docs/Web/API/WebGL_API) and provides significantly better performance, because it allows for more direct interaction with modern GPUs. Lastly, it supports general-purpose GPU computations, which makes it just perfect for machine learning!

> [!WARNING]
> As of October 2024, global WebGPU support is around 70% (according to [caniuse.com](https://caniuse.com/webgpu)), meaning some users may not be able to use the API.
>
> If the following demos do not work in your browser, you may need to enable it using a feature flag:
>
> - Firefox: with the `dom.webgpu.enabled` flag (see [here](https://developer.mozilla.org/en-US/docs/Mozilla/Firefox/Experimental_features#:~:text=tested%20by%20Firefox.-,WebGPU%20API,-The%20WebGPU%20API)).
> - Safari: with the `WebGPU` feature flag (see [here](https://webkit.org/blog/14879/webgpu-now-available-for-testing-in-safari-technology-preview/)).
> - Older Chromium browsers (on Windows, macOS, Linux): with the `enable-unsafe-webgpu` flag (see [here](https://developer.chrome.com/docs/web-platform/webgpu/troubleshooting-tips)).
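
If you prefer to detect support at runtime rather than rely on feature flags, you can probe the standard `navigator.gpu` entry point before choosing a device. A minimal sketch, not part of the Transformers.js API; using `"wasm"` as the fallback device is an assumption based on the library's default CPU backend:

```js
// Probe for WebGPU support before selecting the device
async function hasWebGPU() {
  if (!navigator.gpu) return false; // API not exposed by this browser
  try {
    const adapter = await navigator.gpu.requestAdapter();
    return adapter !== null; // null means no suitable GPU adapter was found
  } catch {
    return false;
  }
}

const device = (await hasWebGPU()) ? "webgpu" : "wasm";
```
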
## Usage in Transformers.js v3

Thanks to our collaboration with [ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), enabling WebGPU acceleration is as simple as setting `device: 'webgpu'` when loading a model. Let's see some examples!

**Example:** Compute text embeddings on WebGPU ([demo](https://v2.scrimba.com/s06a2smeej))

```js
import { pipeline } from "@huggingface/transformers";

// Create a feature-extraction pipeline
const extractor = await pipeline(
  "feature-extraction",
  "mixedbread-ai/mxbai-embed-xsmall-v1",
  { device: "webgpu" },
);

// Compute embeddings
const texts = ["Hello world!", "This is an example sentence."];
const embeddings = await extractor(texts, { pooling: "mean", normalize: true });
console.log(embeddings.tolist());
// [
//   [-0.016986183822155, 0.03228696808218956, -0.0013630966423079371, ... ],
//   [0.09050482511520386, 0.07207386940717697, 0.05762749910354614, ... ],
// ]
```
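
Since the embeddings above are mean-pooled and normalized, the cosine similarity between two sentences reduces to a dot product. A short sketch using the result from the previous example:

```js
// With { normalize: true }, cosine similarity is just the dot product
const [a, b] = embeddings.tolist();
const similarity = a.reduce((sum, v, i) => sum + v * b[i], 0);
console.log(similarity); // a value in [-1, 1]; higher means more similar
```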

**Example:** Perform automatic speech recognition with OpenAI Whisper on WebGPU ([demo](https://v2.scrimba.com/s0oi76h82g))

```js
import { pipeline } from "@huggingface/transformers";

// Create automatic speech recognition pipeline
const transcriber = await pipeline(
  "automatic-speech-recognition",
  "onnx-community/whisper-tiny.en",
  { device: "webgpu" },
);

// Transcribe audio from a URL
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav";
const output = await transcriber(url);
console.log(output);
// { text: ' And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.' }
```
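
The pipeline also accepts processing options. For audio longer than Whisper's 30-second window, you can request chunked processing and timestamps; a sketch assuming the same `transcriber` (option names follow the Transformers.js ASR pipeline):

```js
// Transcribe long audio in 30-second chunks, returning timestamps
const result = await transcriber(url, {
  return_timestamps: true,
  chunk_length_s: 30,
  stride_length_s: 5,
});
console.log(result.chunks);
// e.g. [ { timestamp: [0, 11.2], text: ' And so my fellow Americans...' }, ... ]
```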

**Example:** Perform image classification with MobileNetV4 on WebGPU ([demo](https://v2.scrimba.com/s0fv2uab1t))

```js
import { pipeline } from "@huggingface/transformers";

// Create image classification pipeline
const classifier = await pipeline(
  "image-classification",
  "onnx-community/mobilenetv4_conv_small.e2400_r224_in1k",
  { device: "webgpu" },
);

// Classify an image from a URL
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg";
const output = await classifier(url);
console.log(output);
// [
//   { label: 'tiger, Panthera tigris', score: 0.6149784922599792 },
//   { label: 'tiger cat', score: 0.30281734466552734 },
//   { label: 'tabby, tabby cat', score: 0.0019135422771796584 },
//   { label: 'lynx, catamount', score: 0.0012161266058683395 },
//   { label: 'Egyptian cat', score: 0.0011465961579233408 }
// ]
```
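
As the output above shows, the classifier returns the five highest-scoring labels by default. If you only need the best match, you can limit the output with the pipeline's `top_k` option:

```js
// Return only the single most likely label
const top = await classifier(url, { top_k: 1 });
console.log(top);
// [ { label: 'tiger, Panthera tigris', score: 0.6149784922599792 } ]
```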

## Reporting bugs and providing feedback

Due to the experimental nature of the WebGPU API, especially in non-Chromium browsers, you may experience issues when trying to run a model (even when it can run in WASM). If you do, please [open an issue on GitHub](https://github.com/huggingface/transformers.js/issues/new/choose) and we'll do our best to address it. Thanks!
