From 40a7c4d6a1fb69fc5ccf1c6df0345f4cdb1959af Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Wed, 14 Aug 2024 16:13:41 +0300 Subject: [PATCH] Add FastViT model. --- candle-examples/examples/fastvit/README.md | 20 + candle-examples/examples/fastvit/main.rs | 102 ++++ candle-transformers/src/models/fastvit.rs | 512 +++++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + 4 files changed, 635 insertions(+) create mode 100644 candle-examples/examples/fastvit/README.md create mode 100644 candle-examples/examples/fastvit/main.rs create mode 100644 candle-transformers/src/models/fastvit.rs diff --git a/candle-examples/examples/fastvit/README.md b/candle-examples/examples/fastvit/README.md new file mode 100644 index 000000000..499685bd3 --- /dev/null +++ b/candle-examples/examples/fastvit/README.md @@ -0,0 +1,20 @@ +# candle-fastvit + +[FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189). +This candle implementation uses a pre-trained FastViT network for inference. The +classification head has been trained on the ImageNet dataset and returns the +probabilities for the top-5 classes. + +## Running an example + +``` +$ cargo run --example fastvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which sa12 + +loaded image Tensor[dims 3, 256, 256; f32] +model built +mountain bike, all-terrain bike, off-roader: 43.45% +bicycle-built-for-two, tandem bicycle, tandem: 14.16% +unicycle, monocycle : 4.12% +crash helmet : 2.26% +alp : 1.40% +``` diff --git a/candle-examples/examples/fastvit/main.rs b/candle-examples/examples/fastvit/main.rs new file mode 100644 index 000000000..520fd0aed --- /dev/null +++ b/candle-examples/examples/fastvit/main.rs @@ -0,0 +1,102 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::fastvit; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + T8, + T12, + S12, + SA12, + SA24, + SA36, + MA36, +} + +impl Which { + fn model_filename(&self) -> String { + let name = match self { + Self::T8 => "t8", + Self::T12 => "t12", + Self::S12 => "s12", + Self::SA12 => "sa12", + Self::SA24 => "sa24", + Self::SA36 => "sa36", + Self::MA36 => "ma36", + }; + format!("timm/fastvit_{}.apple_in1k", name) + } + + fn config(&self) -> fastvit::Config { + match self { + Self::T8 => fastvit::Config::t8(), + Self::T12 => fastvit::Config::t12(), + Self::S12 => fastvit::Config::s12(), + Self::SA12 => fastvit::Config::sa12(), + Self::SA24 => fastvit::Config::sa24(), + Self::SA36 => fastvit::Config::sa36(), + Self::MA36 => fastvit::Config::ma36(), + } + } +} + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(value_enum, long, default_value_t=Which::S12)] + which: Which, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image(args.image, 256)?.to_device(&device)?; + println!("loaded image {image:?}"); + + let model_file = match args.model { + None => { + let model_name = args.which.model_filename(); + let api = hf_hub::api::sync::Api::new()?; + let api = api.model(model_name); + api.get("model.safetensors")? + } + Some(model) => model.into(), + }; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = fastvit::fastvit(&args.which.config(), 1000, vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::()?; + let mut prs = prs.iter().enumerate().collect::>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!( + "{:24}: {:.2}%", + candle_examples::imagenet::CLASSES[category_idx], + 100. * pr + ); + } + Ok(()) +} diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs new file mode 100644 index 000000000..a0b3cc3e5 --- /dev/null +++ b/candle-transformers/src/models/fastvit.rs @@ -0,0 +1,512 @@ +//! FastViT inference implementation based on timm +//! +//! See "FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization" +//! https://arxiv.org/pdf/2303.14189 +//! +//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py + +use candle::{DType, Result, Tensor, D}; +use candle_nn::{ + batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax, + BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder, +}; + +#[derive(Clone, Debug)] +pub struct Config { + exp_ratio: usize, + in_channels: usize, + blocks: [usize; 4], + attn: bool, + lkc_use_act: bool, +} + +impl Config { + pub fn t8() -> Self { + Self { + exp_ratio: 3, + in_channels: 48, + blocks: [2, 2, 4, 2], + attn: false, + lkc_use_act: false, + } + } + + pub fn t12() -> Self { + Self { + exp_ratio: 3, + in_channels: 64, + blocks: [2, 2, 6, 2], + attn: false, + lkc_use_act: false, + } + } + pub fn s12() -> Self { + Self { + exp_ratio: 4, + in_channels: 64, + blocks: [2, 2, 6, 2], + attn: false, + lkc_use_act: false, + } + } + pub fn sa12() -> Self { + Self { + exp_ratio: 4, + in_channels: 64, + blocks: [2, 2, 6, 2], + attn: true, + lkc_use_act: false, + } + } + pub fn sa24() -> Self { + Self { + exp_ratio: 4, + in_channels: 64, + blocks: [4, 4, 12, 4], + attn: true, + lkc_use_act: false, + } + } + pub fn sa36() -> Self { + Self { + exp_ratio: 4, + in_channels: 64, + blocks: [6, 6, 18, 6], + attn: true, + lkc_use_act: false, + } + } + pub fn ma36() -> Self { + Self { + exp_ratio: 4, + in_channels: 76, + blocks: [6, 6, 18, 6], + attn: true, + lkc_use_act: false, + } + } + + // configs used by MobileCLIP's image encoder + pub fn mci0() -> Self { + Self { + exp_ratio: 3, + in_channels: 64, + blocks: [2, 6, 10, 2], + attn: true, + lkc_use_act: true, + } + } + pub fn mci1() -> Self { + Self { + exp_ratio: 3, + in_channels: 64, + blocks: [4, 12, 20, 4], + attn: true, + lkc_use_act: true, + } + } + pub fn mci2() -> Self { + Self { + exp_ratio: 3, + in_channels: 80, + blocks: [4, 12, 24, 4], + attn: true, + lkc_use_act: true, + } + } +} + +fn conv_norm( + in_channels: usize, + out_channels: usize, + kernel: usize, + stride: usize, + vb: VarBuilder, +) -> Result> { + let conv2d_cfg = Conv2dConfig { + stride, + padding: kernel / 2, + groups: in_channels, + ..Default::default() + }; + + let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?; + let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?; + let conv = conv.absorb_bn(&bn)?; + Ok(Func::new(move |xs| { + let xs = xs.apply(&conv)?; + Ok(xs) + })) +} + +fn conv_mlp(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result> { + let conv2d_cfg = Conv2dConfig { + ..Default::default() + }; + + let conv = conv_norm(dim, dim, 7, 1, vb.pp("conv"))?; + let fc1 = conv2d(dim, dim * exp_ratio, 1, conv2d_cfg, vb.pp("fc1"))?; + let fc2 = conv2d(dim * exp_ratio, dim, 1, conv2d_cfg, vb.pp("fc2"))?; + + Ok(Func::new(move |xs| { + let xs = xs.apply(&conv)?.apply(&fc1)?.gelu_erf()?.apply(&fc2)?; + Ok(xs) + })) +} + +fn squeeze_and_excitation( + in_channels: usize, + squeeze_channels: usize, + vb: VarBuilder, +) -> Result> { + let conv2d_cfg = Conv2dConfig { + ..Default::default() + }; + let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?; + let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?; + + Ok(Func::new(move |xs| { + let residual = xs; + let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?; + let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?; + + residual.broadcast_mul(&xs) + })) +} + +// fuses a convolutional kernel and a batchnorm layer into a convolutional layer +// based on the _fuse_bn_tensor method in timm +// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602 +fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> { + let (gamma, beta) = bn.weight_and_bias().unwrap(); + let mu = bn.running_mean(); + let sigma = (bn.running_var() + bn.eps())?.sqrt(); + let gps = (gamma / sigma)?; + let bias = (beta - mu * &gps)?; + let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?; + + Ok((weights, bias)) +} + +fn mobileone_block( + in_channels: usize, + out_channels: usize, + kernel: usize, + stride: usize, + group_size: usize, + use_act: bool, + vb: VarBuilder, +) -> Result> { + let groups = if group_size == 0 { + 1 + } else { + in_channels / group_size + }; + + let padding = kernel / 2; + let conv2d_cfg = Conv2dConfig { + stride, + groups, + padding, + ..Default::default() + }; + + let mut w = Tensor::zeros( + (out_channels, in_channels / groups, kernel, kernel), + DType::F32, + vb.device(), + )?; + let dim = out_channels; + + let mut b = Tensor::zeros(dim, DType::F32, vb.device())?; + + let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.0.bn")); + let conv_kxk = conv2d_no_bias( + in_channels, + out_channels, + kernel, + conv2d_cfg, + vb.pp("conv_kxk.0.conv"), + ); + + if let (Ok(conv), Ok(bn)) = (conv_kxk, conv_kxk_bn) { + let (wk, bk) = fuse_conv_bn(conv.weight(), bn)?; + w = (w + wk)?; + b = (b + bk)?; + }; + + let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn")); + let conv_scale = conv2d_no_bias( + in_channels, + out_channels, + 1, + conv2d_cfg, + vb.pp("conv_scale.conv"), + ); + + if let (Ok(conv), Ok(bn)) = (conv_scale, conv_scale_bn) { + let (ws, bs) = fuse_conv_bn(conv.weight(), bn)?; + // pad to 3x3 + let ws = ws + .pad_with_zeros(D::Minus1, 1, 1)? + .pad_with_zeros(D::Minus2, 1, 1)?; + + w = (w + ws)?; + b = (b + bs)?; + }; + + let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("se")); + + // read and reparameterize the identity bn into wi and bi + let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity")); + + if let Ok(id_bn) = identity_bn { + let mut weights: Vec = vec![0.0; w.elem_count()]; + let id = in_channels / groups; + // See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809 + for i in 0..in_channels { + if kernel > 1 { + weights[i * kernel * kernel + 4] = 1.0; + } else { + weights[i * (id + 1)] = 1.0; + } + } + + let weights = &Tensor::from_vec(weights, w.shape(), w.device())?; + let (wi, bi) = fuse_conv_bn(weights, id_bn)?; + + w = (w + wi)?; + b = (b + bi)?; + }; + let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg); + + Ok(Func::new(move |xs| { + let mut xs = xs.apply(&reparam_conv)?; + if let Ok(f) = &se { + xs = xs.apply(f)?; + } + if use_act { + xs = xs.gelu_erf()?; + }; + Ok(xs) + })) +} + +fn repmixer(dim: usize, kernel: usize, vb: VarBuilder) -> Result> { + let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?; + let norm = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("norm"))?; + let mixer = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("mixer"))?; + + Ok(Func::new(move |xs| { + let residual = xs.clone(); + let xs = (xs.apply(&mixer)? - xs.apply(&norm)?)?; + let xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?; + let xs = (xs + residual)?; + Ok(xs) + })) +} + +fn repmixer_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result> { + let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?; + let token_mixer = repmixer(dim, 3, vb.pp("token_mixer"))?; + let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?; + + Ok(Func::new(move |xs| { + let residual = xs.apply(&token_mixer)?; + let mut xs = residual.apply(&mlp)?; + xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?; + let xs = (xs + residual)?; + Ok(xs) + })) +} + +fn positional_encoding(dim: usize, vb: VarBuilder) -> Result> { + let conv2d_cfg = Conv2dConfig { + stride: 1, + padding: 3, + groups: dim, + ..Default::default() + }; + + let conv = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("pos_enc"))?; + + Ok(Func::new(move |xs| { + let xs = (xs + xs.apply(&conv)?)?; + Ok(xs) + })) +} + +fn attention(dim: usize, vb: VarBuilder) -> Result> { + let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?; + let proj = linear(dim, dim, vb.pp("proj"))?; + let num_heads = 32; + let head_dim = dim / num_heads; + let scale = (head_dim as f64).powf(-0.5); + + Ok(Func::new(move |xs| { + let xs = xs.clone(); + let (b, c, h, w) = xs.dims4()?; + let n = h * w; + let xs = xs.flatten_from(2)?.transpose(D::Minus1, D::Minus2)?; + let qkv = xs + .apply(&qkv)? + .reshape((b, n, 3, num_heads, head_dim))? + .permute((2, 0, 3, 1, 4))?; + + let q = qkv.get(0)?; + let k = qkv.get(1)?; + let v = qkv.get(2)?; + + let q = (q * scale)?; + + let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?; + let att = softmax(&att, D::Minus1)?; + let xs = att.matmul(&v)?; + + let xs = xs.transpose(1, 2)?.reshape((b, n, c))?; + let xs = xs.apply(&proj)?; + let xs = xs.transpose(D::Minus1, D::Minus2)?.reshape((b, c, h, w))?; + + Ok(xs) + })) +} + +fn attention_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result> { + let gamma1 = vb.get((dim, 1, 1), "layer_scale_1.gamma")?; + let gamma2 = vb.get((dim, 1, 1), "layer_scale_2.gamma")?; + let norm = batch_norm(dim, 1e-5, vb.pp("norm"))?; + let token_mixer = attention(dim, vb.pp("token_mixer"))?; + let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?; + + Ok(Func::new(move |xs| { + let xs = xs.clone(); + let xs = (&xs + + &xs + .apply_t(&norm, false)? + .apply(&token_mixer)? + .broadcast_mul(&gamma1.reshape((1, (), 1, 1))?)?)?; + + let xs = (&xs + + &xs + .apply(&mlp)? + .broadcast_mul(&gamma2.reshape((1, (), 1, 1))?)?)?; + + Ok(xs) + })) +} + +fn fastvit_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result> { + let nblocks = cfg.blocks[idx]; + let mut blocks = Vec::with_capacity(nblocks); + + let dim = cfg.in_channels << idx; + let downsample = fastvit_patch_embed(dim / 2, dim, cfg.lkc_use_act, vb.pp("downsample")); + for block_idx in 0..nblocks { + let block = if cfg.attn && idx == 3 { + attention_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))? + } else { + repmixer_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))? + }; + blocks.push(block); + } + let pos_emb = positional_encoding(dim, vb.pp("pos_emb")); + + Ok(Func::new(move |xs| { + let mut xs = xs.clone(); + if let Ok(ds) = &downsample { + xs = xs.apply(ds)?; + } + if let Ok(pos) = &pos_emb { + xs = xs.apply(pos)?; + } + for block in blocks.iter() { + xs = xs.apply(block)?; + } + Ok(xs) + })) +} + +fn fastvit_patch_embed( + in_channels: usize, + out_channels: usize, + use_act: bool, + vb: VarBuilder, +) -> Result> { + let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?; + let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?; + let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("proj.0.se")); + let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?; + + Ok(Func::new(move |xs| { + let mut xs = (xs.apply(&lk)? + xs.apply(&sk)?)?; + if let Ok(f) = &se { + xs = xs.apply(f)?; + } + if use_act { + xs = xs.gelu_erf()?; + }; + let xs = xs.apply(&mb)?; + Ok(xs) + })) +} + +fn fastvit_stem(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result> { + let mb0 = mobileone_block(in_channels, out_channels, 3, 2, 0, true, vb.pp(0))?; + let mb1 = mobileone_block(out_channels, out_channels, 3, 2, 1, true, vb.pp(1))?; + let mb2 = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp(2))?; + Ok(Func::new(move |xs| { + let xs = xs.apply(&mb0)?.apply(&mb1)?.apply(&mb2)?; + Ok(xs) + })) +} + +// Build a fastvit model for a given configuration. +fn fastvit_model(cfg: &Config, nclasses: Option, vb: VarBuilder) -> Result> { + let cls = match nclasses { + None => None, + Some(nclasses) => { + let linear = linear(cfg.in_channels * 16, nclasses, vb.pp("head.fc"))?; + Some(linear) + } + }; + + let stem = fastvit_stem(3, cfg.in_channels, vb.pp("stem"))?; + let final_conv = mobileone_block( + cfg.in_channels * 8, + cfg.in_channels * 16, + 3, + 1, + 1, + true, + vb.pp("final_conv"), + )?; + + let vb = vb.pp("stages"); + let stage1 = fastvit_stage(cfg, 0, vb.pp(0))?; + let stage2 = fastvit_stage(cfg, 1, vb.pp(1))?; + let stage3 = fastvit_stage(cfg, 2, vb.pp(2))?; + let stage4 = fastvit_stage(cfg, 3, vb.pp(3))?; + + Ok(Func::new(move |xs| { + let xs = xs + .apply(&stem)? + .apply(&stage1)? + .apply(&stage2)? + .apply(&stage3)? + .apply(&stage4)? + .apply(&final_conv)?; + + match &cls { + None => Ok(xs), + Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls), + } + })) +} + +pub fn fastvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result> { + fastvit_model(cfg, Some(nclasses), vb) +} + +pub fn fastvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result> { + fastvit_model(cfg, None, vb) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index cc83cf7bd..a234b8bb9 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -19,6 +19,7 @@ pub mod efficientvit; pub mod encodec; pub mod eva2; pub mod falcon; +pub mod fastvit; pub mod flux; pub mod gemma; pub mod gemma2;