Skip to content

Commit

Permalink
Tracing for the phi model (huggingface#936)
Browse files Browse the repository at this point in the history
* Add some tracing bits to mixformers.

* Add the missing file.

* Add the conv2d layer to with-tracing.

* Improve the tracing usage.
  • Loading branch information
LaurentMazare authored Sep 23, 2023
1 parent cda1786 commit b54acfa
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 100 deletions.
17 changes: 16 additions & 1 deletion candle-examples/examples/phi/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl TextGeneration {
}
let dt = start_gen.elapsed();
println!(
"{sample_len} tokens generated ({:.3} token/s)",
"\n{sample_len} tokens generated ({:.2} token/s)",
sample_len as f64 / dt.as_secs_f64(),
);
Ok(())
Expand All @@ -84,6 +84,10 @@ struct Args {
#[arg(long)]
cpu: bool,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,

#[arg(long)]
prompt: String,

Expand Down Expand Up @@ -114,8 +118,19 @@ struct Args {
}

fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

let args = Args::parse();

let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};

let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
Expand Down
40 changes: 27 additions & 13 deletions candle-transformers/src/models/mixformer.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::models::with_tracing::{linear, Embedding as E, Linear};
/// MixFormer model.
/// https://huggingface.co/microsoft/phi-1_5
/// https://arxiv.org/abs/2309.05463
Expand Down Expand Up @@ -58,12 +59,12 @@ impl Config {

#[derive(Debug)]
struct Embedding {
wte: candle_nn::Embedding,
wte: E,
}

impl Embedding {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let wte = candle_nn::embedding(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
let wte = E::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
Ok(Self { wte })
}
}
Expand Down Expand Up @@ -143,16 +144,16 @@ impl RotaryEmbedding {
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
fc1: candle_nn::Linear,
fc2: candle_nn::Linear,
fc1: Linear,
fc2: Linear,
act: Activation,
}

impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
let fc1 = candle_nn::linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
let fc2 = candle_nn::linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
Ok(Self {
fc1,
fc2,
Expand All @@ -170,13 +171,13 @@ impl Module for MLP {
#[derive(Debug)]
struct CausalLMHead {
ln: candle_nn::LayerNorm,
linear: candle_nn::Linear,
linear: Linear,
}

impl CausalLMHead {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
let linear = candle_nn::linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
Ok(Self { ln, linear })
}
}
Expand All @@ -192,20 +193,21 @@ impl Module for CausalLMHead {
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
struct MHA {
wqkv: candle_nn::Linear,
out_proj: candle_nn::Linear,
wqkv: Linear,
out_proj: Linear,
rotary_emb: RotaryEmbedding,
kv_cache: Option<(Tensor, Tensor)>,
head_dim: usize,
softmax_scale: f64,
span: tracing::Span,
}

impl MHA {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let head_dim = cfg.n_embd / cfg.n_head;
let op_size = cfg.n_embd;
let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
Ok(Self {
Expand All @@ -215,10 +217,12 @@ impl MHA {
kv_cache: None,
rotary_emb,
softmax_scale,
span: tracing::span!(tracing::Level::TRACE, "mha"),
})
}

fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_size, seq_len, _n_embd) = xs.dims3()?;
let qkv = self
.wqkv
Expand Down Expand Up @@ -267,17 +271,24 @@ struct ParallelBlock {
ln: candle_nn::LayerNorm,
mixer: MHA,
mlp: MLP,
span: tracing::Span,
}

impl ParallelBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
let mixer = MHA::new(cfg, vb.pp("mixer"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
Ok(Self { ln, mixer, mlp })
Ok(Self {
ln,
mixer,
mlp,
span: tracing::span!(tracing::Level::TRACE, "block"),
})
}

fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let residual = xs;
let xs = xs.apply(&self.ln)?;
let attn_outputs = self.mixer.forward(&xs)?;
Expand All @@ -291,6 +302,7 @@ pub struct MixFormerSequentialForCausalLM {
embedding: Embedding,
blocks: Vec<ParallelBlock>,
head: CausalLMHead,
span: tracing::Span,
}

impl MixFormerSequentialForCausalLM {
Expand All @@ -307,10 +319,12 @@ impl MixFormerSequentialForCausalLM {
embedding,
blocks,
head,
span: tracing::span!(tracing::Level::TRACE, "mixformer"),
})
}

pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_b_size, seq_len) = xs.dims2()?;
let mut xs = xs.apply(&self.embedding)?;
for block in self.blocks.iter_mut() {
Expand Down
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ pub mod segment_anything;
pub mod stable_diffusion;
pub mod t5;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;
2 changes: 1 addition & 1 deletion candle-transformers/src/models/stable_diffusion/resnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//!
//! Denoising Diffusion Implicit Models, K. He and al, 2015.
//! https://arxiv.org/abs/1512.03385
use super::utils::{conv2d, Conv2d};
use crate::models::with_tracing::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
Expand Down
2 changes: 1 addition & 1 deletion candle-transformers/src/models/stable_diffusion/unet_2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! timestep and return a denoised version of the input.
use super::embeddings::{TimestepEmbedding, Timesteps};
use super::unet_2d_blocks::*;
use super::utils::{conv2d, Conv2d};
use crate::models::with_tracing::{conv2d, Conv2d};
use candle::{Result, Tensor};
use candle_nn as nn;
use candle_nn::Module;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use super::utils::{conv2d, Conv2d};
use crate::models::with_tracing::{conv2d, Conv2d};
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;

Expand Down
27 changes: 0 additions & 27 deletions candle-transformers/src/models/stable_diffusion/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use candle::{Device, Result, Tensor};
use candle_nn::Module;

pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
if steps < 1 {
Expand All @@ -11,29 +10,3 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
.collect::<Vec<_>>();
Tensor::from_vec(vs, steps, &Device::Cpu)
}

// Wrap the conv2d op to provide some tracing.
#[derive(Debug)]
pub struct Conv2d {
inner: candle_nn::Conv2d,
span: tracing::Span,
}

impl Conv2d {
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}

pub fn conv2d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: candle_nn::Conv2dConfig,
vs: candle_nn::VarBuilder,
) -> Result<Conv2d> {
let span = tracing::span!(tracing::Level::TRACE, "conv2d");
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
Ok(Conv2d { inner, span })
}
71 changes: 15 additions & 56 deletions candle-transformers/src/models/t5.rs
Original file line number Diff line number Diff line change
@@ -1,57 +1,12 @@
// T5 Text Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

use crate::models::with_tracing::{linear_no_bias, Embedding, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;

#[derive(Debug)]
struct Embedding {
inner: candle_nn::Embedding,
span: tracing::Span,
}

impl Embedding {
fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
let inner = candle_nn::embedding(d1, d2, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "embedding");
Ok(Self { inner, span })
}

fn embeddings(&self) -> &Tensor {
self.inner.embeddings()
}
}

impl Module for Embedding {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}

#[derive(Debug)]
struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}

impl Linear {
fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Self { inner, span })
}
}

impl Module for Linear {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}

fn default_relative_attention_max_distance() -> usize {
128
}
Expand Down Expand Up @@ -205,8 +160,8 @@ struct T5DenseActDense {

impl T5DenseActDense {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
Ok(Self {
wi,
wo,
Expand Down Expand Up @@ -237,9 +192,9 @@ struct T5DenseGatedActDense {

impl T5DenseGatedActDense {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
Ok(Self {
wi_0,
wi_1,
Expand Down Expand Up @@ -334,10 +289,10 @@ impl T5Attention {
cfg: &Config,
) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv;
let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?;
let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?;
let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
let relative_attention_bias = if has_relative_attention_bias {
let emb = Embedding::new(
cfg.relative_attention_num_buckets,
Expand Down Expand Up @@ -772,7 +727,11 @@ impl T5ForConditionalGeneration {
let lm_head = if tie_word_embeddings {
None
} else {
Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
Some(linear_no_bias(
cfg.d_model,
cfg.vocab_size,
vb.pp("lm_head"),
)?)
};

Ok(Self {
Expand Down
Loading

0 comments on commit b54acfa

Please sign in to comment.