Add the causal mask in mixformer. (huggingface#937)
LaurentMazare authored Sep 23, 2023
1 parent b54acfa commit 7582937
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions candle-transformers/src/models/mixformer.rs
@@ -75,6 +75,20 @@ impl Module for Embedding {
}
}

+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
#[derive(Debug)]
struct RotaryEmbedding {
sin: Tensor,
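To see concretely what the two new helpers do, here is a small standalone sketch (not part of the commit, and assuming the crate is used as candle_core): get_mask builds a strictly upper-triangular u8 matrix with 1 wherever a key position lies in the future of a query position, and masked_fill uses where_cond to overwrite exactly those entries.

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let size = 4usize;
    // Same construction as get_mask: 1 where the key index j is ahead of the query index i.
    let mask: Vec<u8> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    let mask = Tensor::from_slice(&mask, (size, size), &device)?;
    // Stand-in for the raw attention scores of a 4-token sequence.
    let scores = Tensor::zeros((size, size), DType::F32, &device)?;
    // Same idea as masked_fill: keep the score where the mask is 0, write -inf where it is 1.
    let neg_inf = Tensor::new(f32::NEG_INFINITY, &device)?.broadcast_as((size, size))?;
    let masked = mask.where_cond(&neg_inf, &scores)?;
    println!("{mask}");
    println!("{masked}");
    Ok(())
}
```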
@@ -198,6 +212,7 @@ struct MHA {
rotary_emb: RotaryEmbedding,
kv_cache: Option<(Tensor, Tensor)>,
head_dim: usize,
+    n_head: usize,
softmax_scale: f64,
span: tracing::Span,
}
@@ -214,14 +229,15 @@ impl MHA {
wqkv,
out_proj,
head_dim,
+            n_head: cfg.n_head,
kv_cache: None,
rotary_emb,
softmax_scale,
span: tracing::span!(tracing::Level::TRACE, "mha"),
})
}

-    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_size, seq_len, _n_embd) = xs.dims3()?;
let qkv = self
@@ -249,9 +265,16 @@ impl MHA {
let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s

-        // TODO: Add the causal mask.
// causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
// scores = scores + causal_mask.to(dtype=scores.dtype)
+        let attn_weights = match mask {
+            None => attn_weights,
+            Some(mask) => masked_fill(
+                &attn_weights,
+                &mask.broadcast_left(b_size * self.n_head)?,
+                f32::NEG_INFINITY,
+            )?,
+        };
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;

// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
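In the attention step above, broadcast_left(b_size * self.n_head) expands the (seq_len, seq_len) mask so it lines up with the flattened (b*h, t, s) score tensor, and writing f32::NEG_INFINITY before the softmax drives the attention weight of every masked (future) position to zero. A minimal check of that last step, not part of the commit and assuming candle_core plus candle_nn as dependencies:

```rust
use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // One query row whose last (future) position has been masked to -inf.
    let scores = Tensor::new(&[[1.0f32, 2.0, f32::NEG_INFINITY]], &device)?;
    let probs = candle_nn::ops::softmax(&scores, D::Minus1)?;
    // Approximately [[0.2689, 0.7311, 0.0]]: the masked slot receives no attention.
    println!("{probs}");
    Ok(())
}
```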
@@ -287,11 +310,11 @@ impl ParallelBlock {
})
}

-    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let residual = xs;
let xs = xs.apply(&self.ln)?;
-        let attn_outputs = self.mixer.forward(&xs)?;
+        let attn_outputs = self.mixer.forward(&xs, mask)?;
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
attn_outputs + feed_forward_hidden_states + residual
}
@@ -327,8 +350,13 @@ impl MixFormerSequentialForCausalLM {
let _enter = self.span.enter();
let (_b_size, seq_len) = xs.dims2()?;
let mut xs = xs.apply(&self.embedding)?;
+        let mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(get_mask(seq_len, xs.device())?)
+        };
for block in self.blocks.iter_mut() {
-            xs = block.forward(&xs)?
+            xs = block.forward(&xs, mask.as_ref())?
}
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
}
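Note that the mask is only built for the multi-token prompt pass: once generation runs off the kv cache, each call feeds a single token, and a 1x1 upper-triangular mask would be all zeros anyway, so None is passed instead. A hedged sketch of that decision using a toy token tensor (the helper name causal_mask_for is made up for illustration):

```rust
use candle_core::{Device, Result, Tensor};

// Mirrors the seq_len check in MixFormerSequentialForCausalLM::forward.
fn causal_mask_for(xs: &Tensor) -> Result<Option<Tensor>> {
    let (_b_size, seq_len) = xs.dims2()?;
    if seq_len <= 1 {
        // Single-token decode step: the lone query may attend to every cached key.
        Ok(None)
    } else {
        // Prompt pass: build the (seq_len, seq_len) upper-triangular mask as get_mask does.
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
            .collect();
        Ok(Some(Tensor::from_slice(&mask, (seq_len, seq_len), xs.device())?))
    }
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let prompt = Tensor::new(&[[5u32, 12, 7, 3]], &device)?;
    let step = Tensor::new(&[[42u32]], &device)?;
    assert!(causal_mask_for(&prompt)?.is_some());
    assert!(causal_mask_for(&step)?.is_none());
    Ok(())
}
```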
