Clamping t5 hidden states for f16 NaN explosion
EricLBuehler committed Sep 17, 2024
1 parent e326121 commit 265bc3b
Showing 3 changed files with 47 additions and 4 deletions.
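
Background (not part of the commit): f16 has a much smaller range than f32 (largest finite value about 65504), so T5 hidden states that are unremarkable in f32 can overflow to +inf in f16, after which operations such as inf - inf produce NaN and poison every later layer. Below is a minimal, self-contained sketch of that failure mode, using the half crate that this commit adds to candle-transformers; the numbers are made up for illustration.

use half::f16;

fn main() {
    let x = f16::from_f32(60_000.0);
    // 120_000 exceeds f16::MAX (about 65504), so the sum rounds to +inf.
    let overflowed = x + x;
    assert!(overflowed.is_infinite());

    // Once an infinity is present, inf - inf yields NaN.
    let nan = overflowed - overflowed;
    assert!(nan.is_nan());

    // Clamping a little below f16::MAX, as the diff below does for the hidden
    // states, pulls the value back to a finite number and avoids the NaN.
    let limit = f16::MAX.to_f32() - 1000.0;
    let clamped = f16::from_f32(overflowed.to_f32().clamp(-limit, limit));
    assert!(!(clamped - clamped).is_nan());
}

The recipe of clamping at the dtype maximum minus 1000 matches what the Hugging Face Transformers T5 implementation does for fp16, which is likely where the margin of 1000 comes from.
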
17 changes: 17 additions & 0 deletions candle-core/src/tensor.rs
@@ -2539,6 +2539,23 @@ impl Tensor {
     pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
         rhs.broadcast_mul(&self.log()?)?.exp()
     }
+
+    pub fn is_inf(&self) -> Result<Self> {
+        self.broadcast_eq(&Tensor::new(f64::INFINITY, self.device())?.to_dtype(self.dtype())?)
+    }
+
+    pub fn any(&self) -> Result<bool> {
+        let sum = self.sum_all()?;
+        match self.dtype() {
+            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
+            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
+            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
+            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
+            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
+            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
+            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
+        }
+    }
 }
 
 macro_rules! bin_trait {
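
Not part of the commit: a minimal sketch of how the two new helpers compose, assuming a build of candle-core that includes this commit and a CPU device; the tensor contents are invented for illustration. is_inf() produces a 0/1 mask tensor and any() reduces it to a single bool.

use candle_core::{DType, Device, Result, Tensor};

fn contains_inf(device: &Device) -> Result<bool> {
    // 1e9 overflows to +inf when the tensor is cast down to f16.
    let xs = Tensor::new(&[1.0f32, 2.0, 1e9], device)?.to_dtype(DType::F16)?;
    xs.is_inf()?.any()
}

fn main() -> Result<()> {
    assert!(contains_inf(&Device::Cpu)?);
    Ok(())
}
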
1 change: 1 addition & 0 deletions candle-transformers/Cargo.toml
@@ -24,6 +24,7 @@ serde = { workspace = true }
 serde_json = { workspace = true }
 serde_plain = { workspace = true }
 tracing = { workspace = true }
+half = { workspace = true }
 
 [features]
 default = []
33 changes: 29 additions & 4 deletions candle-transformers/src/models/t5.rs
@@ -577,6 +577,22 @@ impl T5LayerCrossAttention {
     }
 }
 
+fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
+    let mut max = match xs.dtype() {
+        DType::U8 => u8::MAX as f64 - 1000.,
+        DType::U32 => u32::MAX as f64 - 1000.,
+        DType::I64 => i64::MAX as f64 - 1000.,
+        DType::F16 => half::f16::MAX.to_f64_const() - 1000.,
+        DType::BF16 => half::bf16::MAX.to_f64_const() - 1000.,
+        DType::F32 => f32::MAX as f64 - 1000.,
+        DType::F64 => f64::MAX - 1000.,
+    };
+    if xs.is_inf()?.any()? {
+        max -= 1000.;
+    }
+    xs.clamp(-max, max)
+}
+
 #[derive(Debug, Clone)]
 struct T5Block {
     self_attn: T5LayerSelfAttention,
@@ -632,13 +648,22 @@ impl T5Block {
             false => None,
         };
         let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
-        // TODO: clamp for f16?
+        // Clamp for f16
+        if xs.dtype() == DType::F16 {
+            xs = clamp_for_f16(&xs)?;
+        }
         if let Some(cross_attn) = &mut self.cross_attn {
             (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
-            // TODO: clamp for f16?
+            // Clamp for f16
+            if xs.dtype() == DType::F16 {
+                xs = clamp_for_f16(&xs)?;
+            }
         }
-        let xs = self.ff.forward(&xs)?;
-        // TODO: clamp for f16?
+        let mut xs = self.ff.forward(&xs)?;
+        // Clamp for f16
+        if xs.dtype() == DType::F16 {
+            xs = clamp_for_f16(&xs)?;
+        }
         Ok((xs, position_bias))
     }
 
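
Not part of the diff: a short sketch of the clamping recipe applied end to end, re-using the new is_inf/any helpers. It assumes candle_core and the half crate as dependencies, the values are made up, and since clamp_for_f16 is private to t5.rs its body is inlined here.

use candle_core::{DType, Device, Result, Tensor};

fn clamp_demo() -> Result<()> {
    let device = Device::Cpu;
    // One element overflows to +inf on the cast to f16, the situation that
    // previously let NaNs propagate through later T5 blocks.
    let xs = Tensor::new(&[0.5f32, -3.0, 1e9], &device)?.to_dtype(DType::F16)?;

    // Same recipe as clamp_for_f16: stay 1000 below the dtype maximum, and
    // back off a further 1000 when the tensor already contains infinities.
    let mut max = half::f16::MAX.to_f64_const() - 1000.;
    if xs.is_inf()?.any()? {
        max -= 1000.;
    }
    let clamped = xs.clamp(-max, max)?;

    // After clamping, no element is infinite any more.
    assert!(!clamped.is_inf()?.any()?);
    Ok(())
}

fn main() -> Result<()> {
    clamp_demo()
}
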
