Input-wise masks for mask gradients (#4)
Adds support for computing input-wise mask gradients. Useful, e.g., for anomaly detection using edge attribution scores.
oliveradk authored Jul 19, 2024
1 parent e3ddca7 commit 3ad51b5
Showing 4 changed files with 133 additions and 3 deletions.
24 changes: 24 additions & 0 deletions auto_circuit/utils/graph_utils.py
@@ -398,6 +398,30 @@ def mask_fn_mode(model: PatchableModel, mask_fn: MaskFn, dropout_p: float = 0.0)
wrapper.dropout_layer.p = 0.0 # type: ignore


@contextmanager
def set_mask_batch_size(model: PatchableModel, batch_size: int | None):
"""
Context manager to set the batch size of the patch masks in the model.
Args:
model: The patchable model to alter.
batch_size: The batch size to set the patch masks to. If `None`, the batch size
is not modified.
Warning:
This function breaks other functions of the library while the context is active
and should be considered an experimental feature.
This function modifies the state of the model! This is a likely source of bugs.
"""
for wrapper in model.dest_wrappers:
wrapper.set_mask_batch_size(batch_size)
try:
yield
finally:
for wrapper in model.dest_wrappers:
wrapper.set_mask_batch_size(None)


def edge_counts_util(
edges: Set[Edge],
test_counts: Optional[TestEdges] = None, # None means default
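A minimal usage sketch of the new context manager, condensed from the test added in this commit (`tests/utils/test_instance_grads.py`). It assumes `model` is a `PatchableModel` and `train_loader` yields the same prompt-pair batches the tests use; everything else comes from the library as shown in the diffs below.

import torch as t

from auto_circuit.types import AblationType, PruneScores
from auto_circuit.utils.ablation_activations import src_ablations
from auto_circuit.utils.graph_utils import (
    patch_mode,
    set_all_masks,
    set_mask_batch_size,
    train_mask_mode,
)
from auto_circuit.utils.tensor_ops import batch_avg_answer_diff

batch_size = 2  # assumed to match the dataloader's batch size
src_out = src_ablations(
    model, next(iter(train_loader)).clean, ablation_type=AblationType.ZERO
)

per_input_scores: PruneScores = {}
with set_mask_batch_size(model, batch_size), train_mask_mode(model):
    set_all_masks(model, val=0.0)
    for batch in train_loader:
        with patch_mode(model, src_out.clone().detach()):
            logits = model(batch.clean)[model.out_slice]
            loss = -batch_avg_answer_diff(logits, batch)
            loss.backward(t.ones_like(loss))
        for wrapper in model.dest_wrappers:
            # The expanded patch_mask has a leading batch dim, so its gradient
            # holds one edge attribution score per input in the batch.
            per_input_scores[wrapper.module_name] = wrapper.patch_mask.grad.detach().clone()
        model.zero_grad()

On exiting the `set_mask_batch_size` block, every patch mask collapses back to its un-batched shape.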
37 changes: 36 additions & 1 deletion auto_circuit/utils/patch_wrapper.py
@@ -85,11 +85,43 @@ def __init__(
self.mask_fn: MaskFn = None
self.dropout_layer: t.nn.Module = t.nn.Dropout(p=0.0)
self.patch_mode = False
self.batch_size = None

assert head_dim is None or seq_dim is None or head_dim > seq_dim
dims = range(1, max(head_dim if head_dim else 2, seq_dim if seq_dim else 2))
self.dims = " ".join(["seq" if i == seq_dim else f"d{i}" for i in dims])

def set_mask_batch_size(self, batch_size: int | None):
"""
Set the batch size of the patch mask. Should only be used via the context manager
[`set_mask_batch_size`][auto_circuit.utils.graph_utils.set_mask_batch_size].
The current primary use case is to collect gradients on the patch mask for
each input in the batch.
Warning:
This is an experimental feature that breaks some parts of the library and
should be used with caution.
Args:
batch_size: The batch size of the patch mask.
"""
if batch_size is None and self.batch_size is None:
return
if batch_size is None: # removing batch dim
self.patch_mask = t.nn.Parameter(self.patch_mask[0].clone())
elif self.batch_size is None: # adding batch_dim
self.patch_mask = t.nn.Parameter(
self.patch_mask.repeat(batch_size, *((1,) * self.patch_mask.ndim))
)
elif self.batch_size != batch_size: # modifying batch dim
self.patch_mask = t.nn.Parameter(
self.patch_mask[0]
.clone()
.repeat(batch_size, *((1,) * (self.patch_mask.ndim - 1)))
)
self.batch_size = batch_size

def forward(self, *args: Any, **kwargs: Any) -> Any:
arg_0: t.Tensor = args[0].clone()

@@ -100,12 +132,15 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
head_str = "" if self.head_dim is None else "dest" # Patch heads separately
seq_str = "" if self.seq_dim is None else "seq" # Patch tokens separately
if self.mask_fn == "hard_concrete":
mask = sample_hard_concrete(self.patch_mask, arg_0.size(0))
mask = sample_hard_concrete(
self.patch_mask, arg_0.size(0), self.batch_size is not None
)
batch_str = "batch" # Sample distribution for each batch element
elif self.mask_fn == "sigmoid":
mask = t.sigmoid(self.patch_mask)
else:
assert self.mask_fn is None
batch_str = "batch" if self.batch_size is not None else ""
mask = self.patch_mask
mask = self.dropout_layer(mask)
ein_pre = f"{batch_str} {seq_str} {head_str} src, src batch {self.dims} ..."
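Shape-wise, the new wrapper method behaves roughly as below. This is an illustrative sketch only: the `set_mask_batch_size` context manager in `graph_utils` is the intended entry point, and the starting mask shape is just an example.

wrapper = next(iter(model.dest_wrappers))
orig_shape = wrapper.patch_mask.shape   # e.g. (n_srcs,) for a non-token circuit
wrapper.set_mask_batch_size(4)          # values repeated along a new leading batch dim
assert wrapper.patch_mask.shape == (4, *orig_shape)
wrapper.set_mask_batch_size(None)       # collapsed back to row 0 of the batched mask
assert wrapper.patch_mask.shape == orig_shape

In `forward`, the einsum pattern then gains a `batch` dimension for the mask, so each input is gated by, and back-propagates into, its own copy of the mask.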
10 changes: 8 additions & 2 deletions auto_circuit/utils/tensor_ops.py
@@ -9,20 +9,26 @@
left, right, temp = -0.1, 1.1, 2 / 3


def sample_hard_concrete(mask: t.Tensor, batch_size: int) -> t.Tensor:
def sample_hard_concrete(
mask: t.Tensor, batch_size: int, mask_expanded: bool
) -> t.Tensor:
"""
Sample from the hard concrete distribution
([Louizos et al., 2017](https://arxiv.org/abs/1712.01312)).
Args:
mask: The mask whose values parameterize the distribution.
batch_size: The number of samples to draw.
mask_expanded: Whether the mask has a batch dimension at the start.
Returns:
A sample for each element in the mask for each batch element. If `mask_expanded`
is `False`, the returned tensor has shape `(batch_size, *mask.shape)`; otherwise
the mask already includes the batch dimension and the output shape is `mask.shape`.
"""
mask = mask.repeat(batch_size, *([1] * mask.ndim))
if not mask_expanded:
mask = mask.repeat(batch_size, *([1] * mask.ndim))
else:
assert mask.size(0) == batch_size
u = t.zeros_like(mask).uniform_().clamp(0.0001, 0.9999)
s = t.sigmoid((u.log() - (1 - u).log() + mask) / temp)
s_bar = s * (right - left) + left
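A shape-level sketch of the new `mask_expanded` flag (the zero tensors are placeholders for mask logits; the shapes are arbitrary):

import torch as t

from auto_circuit.utils.tensor_ops import sample_hard_concrete

mask = t.zeros(3, 7)                       # mask logits with no batch dim
s = sample_hard_concrete(mask, 4, mask_expanded=False)
assert s.shape == (4, 3, 7)                # the same logits repeated for every input

mask_b = t.zeros(4, 3, 7)                  # logits already expanded per input
s_b = sample_hard_concrete(mask_b, 4, mask_expanded=True)
assert s_b.shape == (4, 3, 7)              # each input sampled from its own logits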
65 changes: 65 additions & 0 deletions tests/utils/test_instance_grads.py
@@ -0,0 +1,65 @@
#%%

from collections import defaultdict
from typing import Dict, List

import torch as t
from transformer_lens import HookedTransformer

from auto_circuit.tasks import Task
from auto_circuit.types import AblationType, PruneScores
from auto_circuit.utils.ablation_activations import src_ablations
from auto_circuit.utils.graph_utils import (
patch_mode,
set_all_masks,
set_mask_batch_size,
train_mask_mode,
)
from auto_circuit.utils.tensor_ops import batch_avg_answer_diff


def test_instance_grads(mini_tl_transformer: HookedTransformer):

# create task
batch_size = 2
batch_count = 1
task = Task(
key="test_eap",
name="test_eap",
batch_size=batch_size,
batch_count=batch_count,
token_circuit=False,
_model_def=mini_tl_transformer,
_dataset_name="mini_prompts"
)
model = task.model
train_loader = task.train_loader

# compute src patch out
src_patch_out = src_ablations(
model, next(iter(train_loader)).clean, ablation_type=AblationType.ZERO
)

# collect prune score batches for each module, to be concatenated afterwards
prune_scores_batch: PruneScores = {}
with set_mask_batch_size(model, batch_size), train_mask_mode(model):
set_all_masks(model, val=0.0)
for batch in train_loader:
with patch_mode(model, src_patch_out.clone().detach()):
# combine clean and corrupt to get different values for testing
logits = model(t.cat([batch.clean[0:1], batch.corrupt[0:1]]))[model.out_slice]
loss = -batch_avg_answer_diff(logits, batch)
loss.backward(t.ones_like(loss))
for dest_wrapper in model.dest_wrappers:
assert dest_wrapper.patch_mask.size(0) == batch_size
grad = dest_wrapper.patch_mask.grad.detach().clone()
prune_scores_batch[dest_wrapper.module_name] = grad
model.zero_grad()

ex_prune_score = next(iter(prune_scores_batch.values()))
# check expanded batch size
assert ex_prune_score.size(0) == batch_size
# check gradients are not the same
assert not t.allclose(ex_prune_score[0], ex_prune_score[1])
# check masks collapsed on exit
assert next(iter(model.dest_wrappers)).patch_mask.ndim == ex_prune_score.ndim - 1
