Skip to content

Commit

Permalink
remove duplicate seanet, add py.typed
Browse files Browse the repository at this point in the history
  • Loading branch information
Tuan Tran committed Feb 21, 2024
1 parent 9ea40cc commit ba9064b
Show file tree
Hide file tree
Showing 7 changed files with 6 additions and 38 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ print(message) # message is a binary vector of 16 bits
# To detect the messages in the low-level.
result, message = detector(watermarked_audio)

# result is a tensor of size batch x 2 x frames, indicating the probability (positive and negative) of watermarking for each frame
# result is a tensor of size batch x 2 x frames, indicating the probablity (positive and negative) of watermarking for each frame
# A watermarked audio should have result[:, 1, :] > 0.5
print(result[:, 1 , :])

Expand Down
Empty file added generator.pth
Empty file.
Binary file added generator.pth.1
Binary file not shown.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ dependencies = [
"pre-commit",
]

[tool.setuptools.package-data]
"audioseal" = ["py.typed", "cards/*.yaml"]

[tool.flake8]
extend_ignore = ["E", "Y"] # Black
per-file-ignores = [
Expand Down
4 changes: 1 addition & 3 deletions src/audioseal/libs/audiocraft/modules/seanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,7 @@ def forward(self, x):
x = self.model(x)
x = self.reverse_convolution(x)
# make sure dim didn't change
x = x[:, :, :orig_nframes]
return x

return x[:, :, :orig_nframes]

class SEANetDecoder(nn.Module):
"""SEANet decoder.
Expand Down
35 changes: 1 addition & 34 deletions src/audioseal/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional, Tuple

import torch

from audioseal.libs.audiocraft.modules.seanet import SEANetEncoder
from audioseal.libs.audiocraft.modules.seanet import SEANetEncoderKeepDimension


class MsgProcessor(torch.nn.Module):
Expand Down Expand Up @@ -105,38 +104,6 @@ def forward(
return x + alpha * wm


class SEANetEncoderKeepDimension(SEANetEncoder):
"""
similar architecture to the audiocraft.SEANet encoder but with an extra step that
projects the output dimension to the same input dimension by repeating
the sequential
Args:
SEANetEncoder (_type_): _description_
output_dim (int): Output dimension
"""

def __init__(self, *args, output_dim=8, **kwargs):

self.output_dim = output_dim
super().__init__(*args, **kwargs)
# Adding a reverse convolution layer
self.reverse_convolution = torch.nn.ConvTranspose1d(
in_channels=self.dimension,
out_channels=self.output_dim,
kernel_size=math.prod(self.ratios),
stride=math.prod(self.ratios),
padding=0,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_nframes = x.shape[-1]
x = self.model(x)
x = self.reverse_convolution(x)
# make sure dim didn't change
return x[:, :, :orig_nframes]


class AudioSealDetector(torch.nn.Module):
"""
Detect the watermarking from an audio signal
Expand Down
Empty file added src/audioseal/py.typed
Empty file.

0 comments on commit ba9064b

Please sign in to comment.