Skip to content

Commit

Permalink
Merge pull request #367 from adefossez/fix_commitment_loss
Browse files Browse the repository at this point in the history
Fix commitment loss
  • Loading branch information
JadeCopet authored Dec 12, 2023
2 parents 8aa48dd + f91888b commit 5c7ea98
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

Adding stereo models.

Fixed the commitment loss, which was until now only applied to the first RVQ layer.

Removed compression model state from the LM checkpoints, for consistency, it
should always be loaded from the original `compression_model_checkpoint`.

Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can ru
```shell
# Best to make sure you have torch installed first, in particular before installing xformers.
# Don't run this if you already have PyTorch installed.
python -m pip install 'torch>=2.0'
python -m pip install 'torch==2.1.0'
# You might need the following before trying to install the packages
python -m pip install setuptools wheel
# Then proceed to one of the following
python -m pip install -U audiocraft # stable release
python -m pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
Expand Down
5 changes: 5 additions & 0 deletions audiocraft/quantization/core_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,16 @@ def forward(self, x, n_q: tp.Optional[int] = None):

for i, layer in enumerate(self.layers[:n_q]):
quantized, indices, loss = layer(residual)
quantized = quantized.detach()
residual = residual - quantized
quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)

if self.training:
# Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
quantized_out = x + (quantized_out - x).detach()

out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses

Expand Down
4 changes: 3 additions & 1 deletion tests/quantization/test_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
class TestResidualVectorQuantizer:

def test_rvq(self):
x = torch.randn(1, 16, 2048)
x = torch.randn(1, 16, 2048, requires_grad=True)
vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
res = vq(x, 1.)
assert res.x.shape == torch.Size([1, 16, 2048])
res.x.sum().backward()
assert torch.allclose(x.grad.data, torch.ones(1))

0 comments on commit 5c7ea98

Please sign in to comment.