Skip to content

Commit

Permalink
feat: added unet patching
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Jul 16, 2022
1 parent 20b0100 commit 24d9b99
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
47 changes: 41 additions & 6 deletions audio_diffusion_pytorch/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many
from einops_exts.torch import EinopsToAndFrom
from torch import Tensor, einsum
Expand Down Expand Up @@ -664,6 +665,7 @@ def __init__(
use_skip_scale: bool,
use_attention_bottleneck: bool,
out_channels: Optional[int] = None,
patch_size: int = 1,
):
super().__init__()

Expand All @@ -677,11 +679,14 @@ def __init__(
and len(dilations) == num_layers
)

self.to_in = CrossEmbed1d(
in_channels=in_channels,
out_channels=channels,
kernel_sizes=kernel_sizes_init,
stride=1,
self.to_in = nn.Sequential(
Rearrange("b c (l p) -> b (c p) l", p=patch_size),
CrossEmbed1d(
in_channels=in_channels * patch_size,
out_channels=channels,
kernel_sizes=kernel_sizes_init,
stride=1,
),
)

self.to_time = nn.Sequential(
Expand Down Expand Up @@ -754,7 +759,12 @@ def __init__(
num_groups=resnet_groups,
time_context_features=time_context_features,
),
Conv1d(in_channels=channels, out_channels=out_channels, kernel_size=1),
Conv1d(
in_channels=channels,
out_channels=out_channels * patch_size,
kernel_size=1,
),
Rearrange("b (c p) l -> b c (l p)", p=patch_size),
)

def forward(self, x: Tensor, t: Tensor):
Expand Down Expand Up @@ -804,5 +814,30 @@ def __init__(self, *args, **kwargs):
use_skip_scale=True,
use_attention_bottleneck=True,
use_learned_time_embedding=True,
patch_size=1,
)
super().__init__(*args, **{**default_kwargs, **kwargs})


class UNet1dBravo(UNet1d):
def __init__(self, *args, **kwargs):
default_kwargs = dict(
in_channels=1,
patch_size=4,
channels=128,
multipliers=[1, 2, 4, 4, 4, 4, 4],
factors=[4, 4, 4, 4, 2, 2],
attentions=[False, False, False, False, True, True],
attention_heads=8,
attention_features=64,
attention_multiplier=2,
dilations=[[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]],
resnet_groups=8,
kernel_multiplier_downsample=2,
kernel_sizes_init=[1, 3, 7],
use_nearest_upsample=False,
use_skip_scale=True,
use_attention_bottleneck=True,
use_learned_time_embedding=True,
)
super().__init__(*args, **{**default_kwargs, **kwargs})
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.4",
version="0.0.5",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 24d9b99

Please sign in to comment.