Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync patch to upstream repo #31

Merged
merged 9 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint_and_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ jobs:
run: |
sudo apt-get install libsndfile1
python -m pip install --upgrade pip
pip install torch==1.13.0 torchaudio==0.13.0 func_argparse soundfile pytest omegaconf numpy julius
pip install torch==1.13.0 torchaudio==0.13.0 func_argparse soundfile pytest omegaconf numpy julius huggingface_hub
pip install --no-deps -e .
- name: pytest_unit
run: pytest -s -v tests/test_models.py
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,8 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# Visual Studio Code
.vscode/

# local training outputs
outputs/*
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ repos:
- id: end-of-file-fixer

- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.4.2
hooks:
- id: black
language_version: python3.8

- repo: https://github.com/pycqa/isort
rev: 5.12.0
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to AudioSeal are documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.1.3] - 2024-04-30
- Fix bug in getting the watermark with a non-empty message created on CPU while the model is loaded on CUDA
- Fix bug in building the model card programmatically (not via .YAML file using OmegaConf)
- Add support for HuggingFace Hub: the model can now be loaded from HF. Unit tests are updated accordingly


## [0.1.2] - 2024-02-29
- Add py.typed to make audioseal mypy-friendly
- Add the option to resample the input audio's sample rate to the expected sample rate of the model (https://github.com/facebookresearch/audioseal/pull/18)
Expand Down
6 changes: 3 additions & 3 deletions examples/attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def echo(
# Define a few reflections with decreasing amplitude
impulse_response[0] = 1.0 # Direct sound

impulse_response[
int(sample_rate * duration) - 1
] = volume # First reflection after 100ms
impulse_response[int(sample_rate * duration) - 1] = (
volume # First reflection after 100ms
)

# Add batch and channel dimensions to the impulse response
impulse_response = impulse_response.unsqueeze(0).unsqueeze(0)
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ black
isort
flake8
pre-commit
huggingface_hub
2 changes: 1 addition & 1 deletion src/audioseal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

"""

__version__ = "0.1.2"
__version__ = "0.1.3"


from audioseal import builder
Expand Down
24 changes: 18 additions & 6 deletions src/audioseal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any, Dict, List, Mapping, Optional

from omegaconf import DictConfig, OmegaConf
from torch import device, dtype
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -71,6 +72,17 @@ class AudioSealDetectorConfig:
detector: DetectorConfig


def as_dict(obj: Any) -> Dict[str, Any]:
    """Convert a supported config object into a plain dictionary.

    Args:
        obj: The config to convert. Supported inputs are a ``dict``
            (returned as-is, not copied), a dataclass *instance*
            (converted with ``dataclasses.asdict``) and an OmegaConf
            ``DictConfig`` (converted with ``OmegaConf.to_container``).

    Returns:
        The dictionary form of ``obj``.

    Raises:
        NotImplementedError: If ``obj`` is none of the supported types,
            including a dataclass class passed instead of an instance.
    """
    if isinstance(obj, dict):
        return obj
    # is_dataclass() is also True for dataclass *classes*, on which asdict()
    # raises TypeError; only instances can be converted.
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)
    if isinstance(obj, DictConfig):
        return OmegaConf.to_container(obj)  # type: ignore
    raise NotImplementedError(f"Unsupported type for config: {type(obj)}")


def create_generator(
config: AudioSealWMConfig,
*,
Expand All @@ -81,11 +93,11 @@ def create_generator(

# Currently the encoder hparams are the same as
# SEANet, but this can be changed in the future.
encoder = audiocraft.modules.SEANetEncoder(**config.seanet) # type: ignore[arg-type]
encoder = audiocraft.modules.SEANetEncoder(**as_dict(config.seanet))
encoder = encoder.to(device=device, dtype=dtype)

decoder_config = {**config.seanet, **config.decoder} # type: ignore
decoder = audiocraft.modules.SEANetDecoder(**decoder_config) # type: ignore[arg-type]
decoder_config = {**as_dict(config.seanet), **as_dict(config.decoder)}
decoder = audiocraft.modules.SEANetDecoder(**as_dict(decoder_config))
decoder = decoder.to(device=device, dtype=dtype)

msgprocessor = MsgProcessor(nbits=config.nbits, hidden_size=config.seanet.dimension)
Expand All @@ -100,7 +112,7 @@ def create_detector(
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> AudioSealDetector:
detector_config = {**config.seanet, **config.detector} # type: ignore
detector_config = {**as_dict(config.seanet), **as_dict(config.detector)}
detector = AudioSealDetector(nbits=config.nbits, **detector_config)
detector = detector.to(device=device, dtype=dtype)
return detector
2 changes: 1 addition & 1 deletion src/audioseal/libs/audiocraft/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
#
# Vendor from https://github.com/facebookresearch/audiocraft

import math
Expand Down
1 change: 1 addition & 0 deletions src/audioseal/libs/audiocraft/modules/seanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def forward(self, x):
# make sure dim didn't change
return x[:, :, :orig_nframes]


class SEANetDecoder(nn.Module):
"""SEANet decoder.

Expand Down
112 changes: 68 additions & 44 deletions src/audioseal/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torch
from omegaconf import DictConfig, OmegaConf

import audioseal
from audioseal.builder import (
AudioSealDetectorConfig,
AudioSealWMConfig,
Expand Down Expand Up @@ -80,11 +81,29 @@ def load_model_checkpoint(
parts = urlparse(str(model_path))
if parts.scheme == "https":

# TODO: Add HF Hub
hash_ = sha1(parts.path.encode()).hexdigest()[:24]
return torch.hub.load_state_dict_from_url(
str(model_path), model_dir=cache_dir, map_location=device, file_name=hash_
)
elif str(model_path).startswith("facebook/audioseal/"):
hf_filename = str(model_path)[len("facebook/audioseal/") :]

try:
from huggingface_hub import hf_hub_download
except ModuleNotFoundError:
print(
f"The model path {model_path} seems to be a direct HF path, "
"but you do not install Huggingface_hub. Install with for example "
"`pip install huggingface_hub` to use this feature."
)
file = hf_hub_download(
repo_id="facebook/audioseal",
filename=hf_filename,
cache_dir=cache_dir,
library_name="audioseal",
library_version=audioseal.__version__,
)
return torch.load(file, map_location=device)
else:
raise ModelLoadError(f"Path or uri {model_path} is unknown or does not exist")

Expand All @@ -100,7 +119,7 @@ def load_local_model_config(model_card: str) -> Optional[DictConfig]:
class AudioSeal:

@staticmethod
def _parse_model(
def parse_model(
model_card_or_path: str,
model_type: Type[AudioSealT],
nbits: Optional[int] = None,
Expand All @@ -126,64 +145,67 @@ def _parse_model(
config_dict = {}
checkpoint = load_model_checkpoint(model_card_or_path)

# If the checkpoint has config in its, take this but uses the info
# in the mode as precedence
assert isinstance(
checkpoint, dict
), f"Expect loaded checkpoint to be a dictionary, get {type(checkpoint)}"
assert isinstance(
config_dict, dict
), f"Except loaded config to be a dictionary, get {type(config_dict)}"
if "xp.cfg" in checkpoint:
config = {**checkpoint["xp.cfg"], **config_dict} # type: ignore
assert config is not None
assert (
"seanet" in config
), f"missing seanet backbone config in {model_card_or_path}"

# Patch 1: Resolve the variables in the checkpoint
config = OmegaConf.create(config)
OmegaConf.resolve(config)
config = OmegaConf.to_container(config) # type: ignore

# Patch 2: Put decoder, encoder and detector outside seanet
seanet_config = config["seanet"]
for key_to_patch in ["encoder", "decoder", "detector"]:
if key_to_patch in seanet_config:
config_to_patch = config.get(key_to_patch) or {}
config[key_to_patch] = {
**config_to_patch,
**seanet_config.pop(key_to_patch),
}

config["seanet"] = seanet_config

# Patch 3: Put nbits into config if specified
if nbits and "nbits" not in config:
config["nbits"] = nbits
config_dict = {**checkpoint["xp.cfg"], **config_dict} # type: ignore

model_config = AudioSeal.parse_config(config_dict, config_type=model_type, nbits=nbits) # type: ignore

if "model" in checkpoint:
checkpoint = checkpoint["model"]

return checkpoint, model_config

@staticmethod
def parse_config(
config: Dict[str, Any],
config_type: Type[AudioSealT],
nbits: Optional[int] = None,
) -> AudioSealT:

assert "seanet" in config, f"missing seanet backbone config in {config}"

# Patch 1: Resolve the variables in the checkpoint
config = OmegaConf.create(config) # type: ignore
OmegaConf.resolve(config) # type: ignore
config = OmegaConf.to_container(config) # type: ignore

# Patch 2: Put decoder, encoder and detector outside seanet
seanet_config = config["seanet"]
for key_to_patch in ["encoder", "decoder", "detector"]:
if key_to_patch in seanet_config:
config_to_patch = config.get(key_to_patch) or {}
config[key_to_patch] = {
**config_to_patch,
**seanet_config.pop(key_to_patch),
}

config["seanet"] = seanet_config

# Patch 3: Put nbits into config if specified
if nbits and "nbits" not in config:
config["nbits"] = nbits

# remove attributes not related to the model_type
result_config = {}
assert config, f"Empty config in {model_card_or_path}"
for field in fields(model_type):
assert config, f"Empty config"
for field in fields(config_type):
if field.name in config:
result_config[field.name] = config[field.name]

schema = OmegaConf.structured(model_type)
schema = OmegaConf.structured(config_type)
schema.merge_with(result_config)
return checkpoint, schema
return schema

@staticmethod
def load_generator(
model_card_or_path: str,
nbits: Optional[int] = None,
) -> AudioSealWM:
"""Load the AudioSeal generator from the model card"""
checkpoint, config = AudioSeal._parse_model(
model_card_or_path, AudioSealWMConfig, nbits=nbits,
checkpoint, config = AudioSeal.parse_model(
model_card_or_path,
AudioSealWMConfig,
nbits=nbits,
)

model = create_generator(config)
Expand All @@ -195,8 +217,10 @@ def load_detector(
model_card_or_path: str,
nbits: Optional[int] = None,
) -> AudioSealDetector:
checkpoint, config = AudioSeal._parse_model(
model_card_or_path, AudioSealDetectorConfig, nbits=nbits,
checkpoint, config = AudioSeal.parse_model(
model_card_or_path,
AudioSealDetectorConfig,
nbits=nbits,
)
model = create_detector(config)
model.load_state_dict(checkpoint)
Expand Down
33 changes: 18 additions & 15 deletions src/audioseal/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,22 +109,24 @@ def get_watermark(
hidden = self.encoder(x)

if self.msg_processor is not None:
if message is None:
self.message = self.message or torch.randint(
0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device
)
message = self.message
if self.message is None:
self.message = torch.randint(0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device)
else:
self.message = self.message.to(device=x.device)


message = self.message

hidden = self.msg_processor(hidden, message)

watermark = self.decoder(hidden)

if sample_rate != 16000:
watermark = julius.resample_frac(watermark, old_sr=16000, new_sr=sample_rate)
watermark = julius.resample_frac(
watermark, old_sr=16000, new_sr=sample_rate
)

return watermark[
..., : length
] # trim output cf encodec codebase
return watermark[..., :length] # trim output cf encodec codebase

def forward(
self,
Expand Down Expand Up @@ -164,7 +166,7 @@ def detect_watermark(
self,
x: torch.Tensor,
sample_rate: Optional[int] = None,
message_threshold: float = 0.5
message_threshold: float = 0.5,
) -> Tuple[float, torch.Tensor]:
"""
A convenience function that returns a probability of an audio being watermarked,
Expand All @@ -174,13 +176,15 @@ def detect_watermark(
x: Audio signal, size: batch x frames
sample_rate: The sample rate of the input audio
message_threshold: threshold used to convert the watermark output (probability
of each bits being 0 or 1) into the binary n-bit message.
of each bits being 0 or 1) into the binary n-bit message.
"""
if sample_rate is None:
logger.warning(COMPATIBLE_WARNING)
sample_rate = 16_000
result, message = self.forward(x, sample_rate=sample_rate) # b x 2+nbits
detected = torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1]
detected = (
torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1]
)
detect_prob = detected.cpu().item() # type: ignore
message = torch.gt(message, message_threshold).int()
return detect_prob, message
Expand All @@ -193,9 +197,8 @@ def decode_message(self, result: torch.Tensor) -> torch.Tensor:
Returns:
The message of size batch x nbits, indicating probability of 1 for each bit
"""
assert (
(result.dim() > 2 and result.shape[1] == self.nbits) or
(self.dim() == 2 and result.shape[0] == self.nbits)
assert (result.dim() > 2 and result.shape[1] == self.nbits) or (
self.dim() == 2 and result.shape[0] == self.nbits
), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})"
decoded_message = result.mean(dim=-1)
return torch.sigmoid(decoded_message)
Expand Down
Loading
Loading