Add option for QKV clipping (#489)
epwalsh authored Mar 7, 2024
1 parent 31d8528 commit c499632
Showing 3 changed files with 27 additions and 6 deletions.
CHANGELOG.md: 3 changes (1 addition, 2 deletions)
@@ -20,8 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added MMLU downstream evaluation tasks, with prompt variations.
- Added support for PyTorch v2.2.
- Added ability to show logs from all ranks
-
-
+- Added option for QKV clipping.

### Changed

olmo/config.py: 5 changes (5 additions, 0 deletions)
@@ -243,6 +243,11 @@ class ModelConfig(BaseConfig):
    The number of self-attention heads.
    """

+    clip_qkv: Optional[float] = None
+    """
+    Clip QKV to this value when set.
+    """
+
    n_layers: int = 12
    """
    The number of layers/blocks.
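For reference, a minimal usage sketch of the new option (illustrative only, not part of this commit; it assumes ModelConfig can be instantiated directly with keyword overrides and that unspecified fields keep their defaults):

    from olmo.config import ModelConfig

    # Hypothetical example: clip the fused QKV activations to [-8.0, 8.0].
    # Leaving clip_qkv as None (the default) disables clipping entirely.
    cfg = ModelConfig(clip_qkv=8.0)

    # Mirrors the sanity check added to olmo/model.py: the clip value must be positive.
    assert cfg.clip_qkv is None or cfg.clip_qkv > 0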
olmo/model.py: 25 changes (21 additions, 4 deletions)
@@ -452,6 +452,10 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
            )
            self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)

+        # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
+        if config.clip_qkv is not None:
+            assert config.clip_qkv > 0
+
        # Activation function.
        self.act = Activation.build(config)
        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
@@ -680,11 +684,14 @@ def forward(
        #  - for multi-query attn q: (batch_size, seq_len, d_model)
        #                      k, v: (batch_size, seq_len, d_model // n_heads)
        if self._activation_checkpoint_fn is not None:
-            q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
-                self.fused_dims, dim=-1
-            )
+            qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x))
        else:
-            q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
+            qkv = self.att_proj(self.attn_norm(x))

+        if self.config.clip_qkv is not None:
+            qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+        q, k, v = qkv.split(self.fused_dims, dim=-1)
+
        # Get attention scores.
        if self._activation_checkpoint_fn is not None:
@@ -780,6 +787,11 @@ def forward(
        else:
            q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)

+        if self.config.clip_qkv is not None:
+            q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
        # Get attention scores.
        # shape: (B, T, C)
        if self._activation_checkpoint_fn is not None:
@@ -896,6 +908,11 @@ def forward(
        k = self.k_proj(x_normed)
        v = self.v_proj(x_normed)

+        if self.config.clip_qkv is not None:
+            q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
        # Get attention scores.
        att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

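The clipping added in the three forward paths above is the same elementwise operation everywhere; the fused path simply clamps once before splitting. A standalone sketch of that equivalence (illustrative only; the tensor shapes and the fused_dims split sizes are made up for the example):

    import torch

    clip_qkv = 8.0
    fused_dims = (16, 16, 16)  # hypothetical q/k/v widths for a regular-attention layer
    qkv = torch.randn(2, 4, sum(fused_dims)) * 20.0  # some values fall outside [-clip_qkv, clip_qkv]

    # Path 1: clamp the fused projection in place, then split (as in the fused att_proj path).
    q1, k1, v1 = qkv.clone().clamp_(min=-clip_qkv, max=clip_qkv).split(fused_dims, dim=-1)

    # Path 2: split first, then clamp q, k, v individually (as in the separate-projection paths).
    q2, k2, v2 = qkv.clone().split(fused_dims, dim=-1)
    for t in (q2, k2, v2):
        t.clamp_(min=-clip_qkv, max=clip_qkv)

    # Elementwise clamping commutes with splitting, so both paths agree.
    assert torch.equal(q1, q2) and torch.equal(k1, k2) and torch.equal(v1, v2)
    assert q1.abs().max().item() <= clip_qkv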
