fix: Fixed Mixup for binary & one-hot targets (#225)
* feat: Added support for binary and one-hot target

* fix: Fixed mixup in training

* test: Updated unittests
frgfm authored Jul 16, 2022
1 parent 273f01f commit 49125c2
Showing 3 changed files with 25 additions and 2 deletions.
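With this change, the Mixup collate accepts three target layouts: class indices of shape (N,), binary labels of shape (N,) with `num_classes=1`, and targets already expanded to (N, C). A minimal usage sketch follows; it is not part of the commit and assumes holocron is installed with `Mixup` importable from `holocron.utils.data`, as the test file suggests.

```python
import torch

from holocron.utils.data import Mixup

batch_size, num_classes = 8, 5
images = torch.rand(batch_size, 3, 32, 32)

# Class-index targets (N,): one-hot encoded internally when num_classes > 1
mix = Mixup(num_classes, alpha=0.2)
_, soft_targets = mix(images, torch.randint(0, num_classes, (batch_size,)))
print(soft_targets.shape)  # torch.Size([8, 5])

# Binary targets (N,) with num_classes=1: unsqueezed to (N, 1)
mix_bin = Mixup(1, alpha=0.5)
_, bin_targets = mix_bin(images, torch.randint(0, 2, (batch_size,)).float())
print(bin_targets.shape)  # torch.Size([8, 1])

# Targets already one-hot / soft (N, C): only cast to the input dtype
_, oh_targets = mix(images, torch.rand(batch_size, num_classes))
print(oh_targets.shape)  # torch.Size([8, 5])
```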
8 changes: 7 additions & 1 deletion holocron/utils/data/collate.py
@@ -38,7 +38,13 @@ def __init__(self, num_classes: int, alpha: float = 0.2) -> None:
     def forward(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:

         # Convert target to one-hot
-        targets = one_hot(targets, num_classes=self.num_classes).to(dtype=inputs.dtype)
+        if targets.ndim == 1:
+            # (N,) --> (N, C)
+            if self.num_classes > 1:
+                targets = one_hot(targets, num_classes=self.num_classes)
+            elif self.num_classes == 1:
+                targets = targets.unsqueeze(1)
+        targets = targets.to(dtype=inputs.dtype)

         # Sample lambda
         if self.alpha == 0:
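For readers outside the diff context, here is a self-contained sketch of the mixup recipe this branch implements. The target-conversion branch mirrors the patch above; the lambda sampling and blending sit below this hunk and are not shown, so the second half is the standard mixup formulation rather than the project's exact code.

```python
import torch
from torch.nn.functional import one_hot


def mixup(inputs: torch.Tensor, targets: torch.Tensor, num_classes: int, alpha: float = 0.2):
    # Convert targets to a dense (N, C) float tensor, mirroring the patch:
    if targets.ndim == 1:
        if num_classes > 1:
            targets = one_hot(targets, num_classes=num_classes)  # (N,) -> (N, C)
        elif num_classes == 1:
            targets = targets.unsqueeze(1)  # binary labels: (N,) -> (N, 1)
    targets = targets.to(dtype=inputs.dtype)

    # Standard mixup: blend each sample with a randomly permuted partner
    lam = torch.distributions.Beta(alpha, alpha).sample().item() if alpha > 0 else 1.0
    perm = torch.randperm(inputs.shape[0])
    mixed_inputs = lam * inputs + (1 - lam) * inputs[perm]
    mixed_targets = lam * targets + (1 - lam) * targets[perm]
    return mixed_inputs, mixed_targets
```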
2 changes: 1 addition & 1 deletion references/classification/train.py
@@ -114,7 +114,7 @@ def main(args):
     num_classes = len(train_set.classes)
     collate_fn = default_collate
     if args.mixup_alpha > 0:
-        mix = Mixup(len(train_set.classes), alpha=0.2)
+        mix = Mixup(len(train_set.classes), alpha=args.mixup_alpha)
         collate_fn = lambda batch: mix(*default_collate(batch))  # noqa: E731
     train_loader = torch.utils.data.DataLoader(
         train_set,
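The training-script fix means the user-supplied mixup strength now actually reaches the collate function instead of a hard-coded 0.2. A hypothetical wiring sketch of the same pattern follows; `FakeData` stands in for the reference script's dataset, `mixup_alpha` stands in for `args.mixup_alpha`, and the torchvision dependency is an assumption for the sake of a runnable example.

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

from holocron.utils.data import Mixup

mixup_alpha = 0.2  # stands in for args.mixup_alpha in the reference script
num_classes = 10
train_set = FakeData(size=64, image_size=(3, 32, 32), num_classes=num_classes, transform=ToTensor())

# Mixing happens at batch-assembly time: the Mixup module wraps default_collate
collate_fn = default_collate
if mixup_alpha > 0:
    mix = Mixup(num_classes, alpha=mixup_alpha)
    collate_fn = lambda batch: mix(*default_collate(batch))  # noqa: E731

train_loader = DataLoader(train_set, batch_size=8, collate_fn=collate_fn)
images, targets = next(iter(train_loader))
print(images.shape, targets.shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8, 10])
```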
17 changes: 17 additions & 0 deletions tests/test_utils.py
@@ -45,6 +45,23 @@ def test_mixup():
     assert torch.all(mix_target.sum(dim=1) == 1.0)
     assert torch.all((mix_target > 0).sum(dim=1) == 1.0)

+    # Binary target
+    mix = utils.data.Mixup(1, alpha=0.5)
+    img = torch.rand((batch_size, *shape))
+    target = torch.concat((torch.zeros(batch_size // 2), torch.ones(batch_size - batch_size // 2)))
+    mix_img, mix_target = mix(img.clone(), target.clone())
+    assert img.shape == (batch_size, *shape)
+    assert not torch.equal(img, mix_img)
+    assert mix_target.dtype == torch.float32 and mix_target.shape == (batch_size, 1)
+
+    # Already in one-hot
+    mix = utils.data.Mixup(num_classes, alpha=0.2)
+    img, target = torch.rand((batch_size, *shape)), torch.rand((batch_size, num_classes))
+    mix_img, mix_target = mix(img.clone(), target.clone())
+    assert img.shape == (batch_size, *shape)
+    assert not torch.equal(img, mix_img)
+    assert mix_target.dtype == torch.float32 and mix_target.shape == (batch_size, num_classes)
+

 def _train_one_batch(model, x, target, optimizer, criterion, device):
     """Mock batch training function"""
