This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Contextualized bias mitigation #5176

Merged: 58 commits merged on Jun 2, 2021
Changes from 57 commits

Commits (58)
79c6c33
added linear and hard debiasers
Apr 13, 2021
e23057c
worked on documentation
Apr 14, 2021
fcc3d34
committing changes before branch switch
Apr 14, 2021
7d00910
committing changes before switching branch
Apr 15, 2021
668a513
finished bias direction, linear and hard debiasers, need to write tests
Apr 15, 2021
91029ef
finished bias direction test
Apr 15, 2021
396b245
Committing changes before switching branch
Apr 16, 2021
a8c22a1
finished hard and linear debiasers
Apr 16, 2021
ef6a062
finished OSCaR
Apr 17, 2021
2c873cb
bias mitigators tests and bias metrics remaining
Apr 17, 2021
d97a526
added bias mitigator tests
Apr 18, 2021
8460281
added bias mitigator tests
Apr 18, 2021
5a76922
finished tests for bias mitigation methods
Apr 19, 2021
85cb107
Merge remote-tracking branch 'origin/main' into arjuns/post-processin…
Apr 19, 2021
8e55f28
fixed gpu issues
Apr 19, 2021
b42b73a
fixed gpu issues
Apr 19, 2021
37d8e33
fixed gpu issues
Apr 20, 2021
31b1d2c
resolve issue with count_nonzero not being differentiable
Apr 20, 2021
a1f4f2a
merged main into post-processing-debiasing
Apr 21, 2021
36cebe3
added more references
Apr 21, 2021
88c083b
Merge branch 'main' of https://github.com/allenai/allennlp into arjun…
Apr 28, 2021
86081ee
fairness during finetuning
Apr 29, 2021
ae592d8
finished bias mitigator wrapper
May 5, 2021
2501b8c
added reference
May 5, 2021
f664dfb
updated CHANGELOG and fixed minor docs issues
May 5, 2021
595449d
move id tensors to embedding device
May 5, 2021
dc4793f
Merge branch 'main' into arjuns/contextualized-bias-mitigation
ArjunSubramonian May 6, 2021
0cdcf89
fixed to use predetermined bias direction
May 6, 2021
f254128
fixed minor doc errors
May 6, 2021
1be00c8
snli reader registration issue
May 6, 2021
a6c9bf6
fixed _pretrained from params issue
May 6, 2021
6624680
fixed device issues
May 6, 2021
90a372e
evaluate bias mitigation initial commit
May 9, 2021
c6a2dbf
finished evaluate bias mitigation
May 10, 2021
7797659
handles multiline prediction files
May 10, 2021
bbfddd7
fixed minor bugs
May 11, 2021
f2f3fc3
fixed minor bugs
May 11, 2021
4e79de7
improved prediction diff JSON format
May 11, 2021
5dae69f
merged main
May 11, 2021
254676f
forgot to resolve a conflict
May 11, 2021
26d8dff
Merge branch 'main' of https://github.com/allenai/allennlp into arjun…
May 13, 2021
1ae5e99
Refactored evaluate bias mitigation to use NLI metric
May 13, 2021
e2cc38e
Added SNLIPredictionsDiff class
May 17, 2021
c34cf31
ensured dataloader is same for bias mitigated and baseline models
May 17, 2021
fdb9ea7
finished evaluate bias mitigation
May 18, 2021
3efffd2
Merge branch 'main' into arjuns/contextualized-bias-mitigation
AkshitaB May 18, 2021
c47de58
Update CHANGELOG.md
AkshitaB May 18, 2021
2b8cf09
Merge branch 'main' of https://github.com/allenai/allennlp into arjun…
May 20, 2021
33d6267
Replaced local data files with github raw content links
May 20, 2021
ec53a05
Update allennlp/fairness/bias_mitigator_applicator.py
ArjunSubramonian May 25, 2021
4afb7f2
deleted evaluate_bias_mitigation from git tracking
May 26, 2021
21bed9d
removed evaluate-bias-mitigation instances from rest of repo
May 26, 2021
fefcbad
Merge branch 'arjuns/contextualized-bias-mitigation' of https://githu…
May 26, 2021
972ea60
addressed Akshita's comments
May 26, 2021
b4011cb
moved bias mitigator applicator test to allennlp-models
Jun 2, 2021
4d7fffb
Merge branch 'main' into arjuns/contextualized-bias-mitigation
AkshitaB Jun 2, 2021
22a5964
removed unnecessary files
Jun 2, 2021
bd727dd
Merge branch 'main' into arjuns/contextualized-bias-mitigation
ArjunSubramonian Jun 2, 2021
7 changes: 4 additions & 3 deletions CHANGELOG.md
@@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module.
- Added `BiasMitigatorApplicator`, which wraps any Model and mitigates biases by finetuning
on a downstream task.
- Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files.
- Meta data defined by the class `allennlp.common.meta.Meta` is now saved in the serialization directory and archive file
when training models from the command line. This is also now part of the `Archive` named tuple that's returned from `load_archive()`.
@@ -52,7 +54,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
is still possible if used with `force_extract=True`.
- Fixed `wandb` callback to work in distributed training.


## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22

### Added
@@ -78,8 +79,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add new dimension to the `interpret` module: influence functions via the `InfluenceInterpreter` base class, along with a concrete implementation: `SimpleInfluence`.
- Added a `quiet` parameter to the `MultiProcessDataLoading` that disables `Tqdm` progress bars.
- The test for distributed metrics now takes a parameter specifying how often you want to run it.
-- Created the fairness module and added four fairness metrics: `Independence`, `Separation`, and `Sufficiency`.
-- Added three bias metrics to the fairness module: `WordEmbeddingAssociationTest`, `EmbeddingCoherenceTest`, `NaturalLanguageInference`, and `AssociationWithoutGroundTruth`.
+- Created the fairness module and added three fairness metrics: `Independence`, `Separation`, and `Sufficiency`.
+- Added four bias metrics to the fairness module: `WordEmbeddingAssociationTest`, `EmbeddingCoherenceTest`, `NaturalLanguageInference`, and `AssociationWithoutGroundTruth`.
- Added four bias direction methods (`PCABiasDirection`, `PairedPCABiasDirection`, `TwoMeansBiasDirection`, `ClassificationNormalBiasDirection`) and four bias mitigation methods (`LinearBiasMitigator`, `HardBiasMitigator`, `INLPBiasMitigator`, `OSCaRBiasMitigator`).

### Changed
17 changes: 16 additions & 1 deletion allennlp/fairness/__init__.py
@@ -3,7 +3,8 @@
1. measure the fairness of models according to multiple definitions of fairness
2. measure bias amplification
-3. debias embeddings during training time and post-processing
+3. mitigate bias in static and contextualized embeddings during training time and
+post-processing
"""

from allennlp.fairness.fairness_metrics import Independence, Separation, Sufficiency
@@ -25,3 +26,17 @@
    INLPBiasMitigator,
    OSCaRBiasMitigator,
)
from allennlp.fairness.bias_utils import load_words, load_word_pairs
from allennlp.fairness.bias_mitigator_applicator import BiasMitigatorApplicator
from allennlp.fairness.bias_mitigator_wrappers import (
    HardBiasMitigatorWrapper,
    LinearBiasMitigatorWrapper,
    INLPBiasMitigatorWrapper,
    OSCaRBiasMitigatorWrapper,
)
from allennlp.fairness.bias_direction_wrappers import (
    PCABiasDirectionWrapper,
    PairedPCABiasDirectionWrapper,
    TwoMeansBiasDirectionWrapper,
    ClassificationNormalBiasDirectionWrapper,
)
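The net effect of these `__init__.py` changes is that the new wrappers and utilities become importable from the package root. A minimal sketch of the expanded import surface (symbol names are taken directly from the imports above; nothing else is assumed):

```python
# Import the new fairness components from the package root, as exposed by the
# __init__.py changes in this PR.
from allennlp.fairness import (
    BiasMitigatorApplicator,
    HardBiasMitigatorWrapper,
    LinearBiasMitigatorWrapper,
    INLPBiasMitigatorWrapper,
    OSCaRBiasMitigatorWrapper,
    PCABiasDirectionWrapper,
    PairedPCABiasDirectionWrapper,
    TwoMeansBiasDirectionWrapper,
    ClassificationNormalBiasDirectionWrapper,
    load_words,
    load_word_pairs,
)
```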
269 changes: 269 additions & 0 deletions allennlp/fairness/bias_direction_wrappers.py
@@ -0,0 +1,269 @@
import torch
from typing import Union, Optional
from os import PathLike

from allennlp.fairness.bias_direction import (
    BiasDirection,
    PCABiasDirection,
    PairedPCABiasDirection,
    TwoMeansBiasDirection,
    ClassificationNormalBiasDirection,
)
from allennlp.fairness.bias_utils import load_word_pairs, load_words

from allennlp.common import Registrable
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data import Vocabulary


class BiasDirectionWrapper(Registrable):
    """
    Parent class for bias direction wrappers.
    """

    def __init__(self):
        self.direction: BiasDirection = None
        self.noise: float = None

    def __call__(self, module):
        raise NotImplementedError

    def train(self, mode: bool = True):
        """
        # Parameters
        mode : `bool`, optional (default=`True`)
            Sets `requires_grad` to value of `mode` for bias direction.
        """
        self.direction.requires_grad = mode

    def add_noise(self, t: torch.Tensor):
        """
        # Parameters
        t : `torch.Tensor`
            Tensor to which to add small amount of Gaussian noise.
        """
        return t + self.noise * torch.randn(t.size(), device=t.device)


@BiasDirectionWrapper.register("pca")
class PCABiasDirectionWrapper(BiasDirectionWrapper):
    """
    # Parameters
    seed_words_file : `Union[PathLike, str]`
        Path of file containing seed words.
    tokenizer : `Tokenizer`
        Tokenizer used to tokenize seed words.
    direction_vocab : `Vocabulary`, optional (default=`None`)
        Vocabulary of tokenizer. If `None`, assumes tokenizer is of
        type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
    namespace : `str`, optional (default=`"tokens"`)
        Namespace of direction_vocab to use when tokenizing.
        Disregarded when direction_vocab is `None`.
    requires_grad : `bool`, optional (default=`False`)
        Option to enable gradient calculation for bias direction.
    noise : `float`, optional (default=`1e-10`)
        To avoid numerical instability if embeddings are initialized uniformly.
    """

    def __init__(
        self,
        seed_words_file: Union[PathLike, str],
        tokenizer: Tokenizer,
        direction_vocab: Optional[Vocabulary] = None,
        namespace: str = "tokens",
        requires_grad: bool = False,
        noise: float = 1e-10,
    ):
        self.ids = load_words(seed_words_file, tokenizer, direction_vocab, namespace)
        self.direction = PCABiasDirection(requires_grad=requires_grad)
        self.noise = noise

    def __call__(self, module):
        # embed subword token IDs and mean pool to get
        # embedding of original word
        ids_embeddings = []
        for i in self.ids:
            i = i.to(module.weight.device)
            ids_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids_embeddings = torch.cat(ids_embeddings)

        # adding trivial amount of noise
        # to eliminate linear dependence amongst all embeddings
        # when training first starts
        ids_embeddings = self.add_noise(ids_embeddings)

        return self.direction(ids_embeddings)


@BiasDirectionWrapper.register("paired_pca")
class PairedPCABiasDirectionWrapper(BiasDirectionWrapper):
    """
    # Parameters
    seed_word_pairs_file : `Union[PathLike, str]`
        Path of file containing seed word pairs.
    tokenizer : `Tokenizer`
        Tokenizer used to tokenize seed words.
    direction_vocab : `Vocabulary`, optional (default=`None`)
        Vocabulary of tokenizer. If `None`, assumes tokenizer is of
        type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
    namespace : `str`, optional (default=`"tokens"`)
        Namespace of direction_vocab to use when tokenizing.
        Disregarded when direction_vocab is `None`.
    requires_grad : `bool`, optional (default=`False`)
        Option to enable gradient calculation for bias direction.
    noise : `float`, optional (default=`1e-10`)
        To avoid numerical instability if embeddings are initialized uniformly.
    """

    def __init__(
        self,
        seed_word_pairs_file: Union[PathLike, str],
        tokenizer: Tokenizer,
        direction_vocab: Optional[Vocabulary] = None,
        namespace: str = "tokens",
        requires_grad: bool = False,
        noise: float = 1e-10,
    ):
        self.ids1, self.ids2 = load_word_pairs(
            seed_word_pairs_file, tokenizer, direction_vocab, namespace
        )
        self.direction = PairedPCABiasDirection(requires_grad=requires_grad)
        self.noise = noise

    def __call__(self, module):
        # embed subword token IDs and mean pool to get
        # embedding of original word
        ids1_embeddings = []
        for i in self.ids1:
            i = i.to(module.weight.device)
            ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids2_embeddings = []
        for i in self.ids2:
            i = i.to(module.weight.device)
            ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids1_embeddings = torch.cat(ids1_embeddings)
        ids2_embeddings = torch.cat(ids2_embeddings)

        ids1_embeddings = self.add_noise(ids1_embeddings)
        ids2_embeddings = self.add_noise(ids2_embeddings)

        return self.direction(ids1_embeddings, ids2_embeddings)


@BiasDirectionWrapper.register("two_means")
class TwoMeansBiasDirectionWrapper(BiasDirectionWrapper):
    """
    # Parameters
    seed_word_pairs_file : `Union[PathLike, str]`
        Path of file containing seed word pairs.
    tokenizer : `Tokenizer`
        Tokenizer used to tokenize seed words.
    direction_vocab : `Vocabulary`, optional (default=`None`)
        Vocabulary of tokenizer. If `None`, assumes tokenizer is of
        type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
    namespace : `str`, optional (default=`"tokens"`)
        Namespace of direction_vocab to use when tokenizing.
        Disregarded when direction_vocab is `None`.
    requires_grad : `bool`, optional (default=`False`)
        Option to enable gradient calculation for bias direction.
    noise : `float`, optional (default=`1e-10`)
        To avoid numerical instability if embeddings are initialized uniformly.
    """

    def __init__(
        self,
        seed_word_pairs_file: Union[PathLike, str],
        tokenizer: Tokenizer,
        direction_vocab: Optional[Vocabulary] = None,
        namespace: str = "tokens",
        requires_grad: bool = False,
        noise: float = 1e-10,
    ):
        self.ids1, self.ids2 = load_word_pairs(
            seed_word_pairs_file, tokenizer, direction_vocab, namespace
        )
        self.direction = TwoMeansBiasDirection(requires_grad=requires_grad)
        self.noise = noise

    def __call__(self, module):
        # embed subword token IDs and mean pool to get
        # embedding of original word
        ids1_embeddings = []
        for i in self.ids1:
            i = i.to(module.weight.device)
            ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids2_embeddings = []
        for i in self.ids2:
            i = i.to(module.weight.device)
            ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids1_embeddings = torch.cat(ids1_embeddings)
        ids2_embeddings = torch.cat(ids2_embeddings)

        ids1_embeddings = self.add_noise(ids1_embeddings)
        ids2_embeddings = self.add_noise(ids2_embeddings)

        return self.direction(ids1_embeddings, ids2_embeddings)


@BiasDirectionWrapper.register("classification_normal")
class ClassificationNormalBiasDirectionWrapper(BiasDirectionWrapper):
    """
    # Parameters
    seed_word_pairs_file : `Union[PathLike, str]`
        Path of file containing seed word pairs.
    tokenizer : `Tokenizer`
        Tokenizer used to tokenize seed words.
    direction_vocab : `Vocabulary`, optional (default=`None`)
        Vocabulary of tokenizer. If `None`, assumes tokenizer is of
        type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
    namespace : `str`, optional (default=`"tokens"`)
        Namespace of direction_vocab to use when tokenizing.
        Disregarded when direction_vocab is `None`.
    noise : `float`, optional (default=`1e-10`)
        To avoid numerical instability if embeddings are initialized uniformly.
    """

    def __init__(
        self,
        seed_word_pairs_file: Union[PathLike, str],
        tokenizer: Tokenizer,
        direction_vocab: Optional[Vocabulary] = None,
        namespace: str = "tokens",
        noise: float = 1e-10,
    ):
        self.ids1, self.ids2 = load_word_pairs(
            seed_word_pairs_file, tokenizer, direction_vocab, namespace
        )
        self.direction = ClassificationNormalBiasDirection()
        self.noise = noise

    def __call__(self, module):
        # embed subword token IDs and mean pool to get
        # embedding of original word
        ids1_embeddings = []
        for i in self.ids1:
            i = i.to(module.weight.device)
            ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids2_embeddings = []
        for i in self.ids2:
            i = i.to(module.weight.device)
            ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
        ids1_embeddings = torch.cat(ids1_embeddings)
        ids2_embeddings = torch.cat(ids2_embeddings)

        ids1_embeddings = self.add_noise(ids1_embeddings)
        ids2_embeddings = self.add_noise(ids2_embeddings)

        return self.direction(ids1_embeddings, ids2_embeddings)
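For readers skimming the diff, here is a minimal usage sketch of one of the wrappers above. Only the constructor and `__call__` signatures come from this file; the seed-word file path and format, the tokenizer choice, and the vocabulary size are assumptions, and any module exposing a `weight` tensor and a `forward(ids)` method can stand in for the embedder.

```python
# Sketch: compute a PCA-based bias direction from a plain torch.nn.Embedding.
# "seed_words.txt" is a hypothetical seed-word file; the vocab size below
# assumes bert-base-uncased token IDs, to match the tokenizer.
import torch

from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.fairness import PCABiasDirectionWrapper

tokenizer = PretrainedTransformerTokenizer("bert-base-uncased")
direction_wrapper = PCABiasDirectionWrapper(
    seed_words_file="seed_words.txt",  # hypothetical path
    tokenizer=tokenizer,
)

embedder = torch.nn.Embedding(num_embeddings=30522, embedding_dim=768)

# __call__ embeds each seed word's subword IDs with `embedder`, mean-pools
# them, adds a trivial amount of noise, and fits the PCA bias direction.
bias_direction = direction_wrapper(embedder)

# Toggle gradient flow through the stored direction, e.g. during evaluation.
direction_wrapper.train(False)
```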
2 changes: 2 additions & 0 deletions allennlp/fairness/bias_metrics.py
@@ -258,6 +258,8 @@ class NaturalLanguageInference(Metric):
3. Threshold:tau (T:tau): A parameterized measure that reports the fraction
of examples whose probability of neutral is above tau.
# Parameters
neutral_label : `int`, optional (default=`2`)
The discrete integer label corresponding to a neutral entailment prediction.
taus : `List[float]`, optional (default=`[0.5, 0.7]`)
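To make the Threshold:tau definition above concrete, here is a small worked example computed directly with torch. It only illustrates the quantity the docstring describes, not the `NaturalLanguageInference` metric's API, and the probability values are made up.

```python
# Fraction of NLI examples whose predicted probability of "neutral" exceeds tau.
import torch

# Hypothetical label distributions for four premise/hypothesis pairs
# (columns: entailment, contradiction, neutral), with neutral_label = 2.
nli_probabilities = torch.tensor(
    [
        [0.10, 0.10, 0.80],
        [0.50, 0.30, 0.20],
        [0.20, 0.20, 0.60],
        [0.05, 0.25, 0.70],
    ]
)
neutral_label = 2

for tau in [0.5, 0.7]:  # the metric's default taus
    threshold_tau = (nli_probabilities[:, neutral_label] > tau).float().mean().item()
    print(f"T:{tau} = {threshold_tau:.2f}")
# Expected output: T:0.5 = 0.75 and T:0.7 = 0.25.
```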