Cleanup fully fused CUDA kernel, update README.md
L0SG committed Jul 22, 2024
1 parent dec9431 commit 455f3f3
Showing 3 changed files with 30 additions and 29 deletions.
9 changes: 4 additions & 5 deletions README.md
@@ -10,9 +10,8 @@

## News
- **Jul 2024 (v2.3):**
- General refactor and code improvements for better readability
- Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling)
- Inference speed benchmark
- General refactor and code improvements for better readability.
- Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling) with an inference speed benchmark.

- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using Gradio.

@@ -149,7 +148,7 @@ generator = BigVGAN(h, use_cuda_kernel=True)

You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
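
For example (paths and the other arguments are illustrative; `--use_cuda_kernel` is the relevant flag):

python inference.py --checkpoint_file /path/to/generator.pt --use_cuda_kernel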

The first time the feature is used, the kernel is built with `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias/cuda/build` and the model loads it automatically. The codebase has been tested with CUDA `12.1`.
The first time the feature is used, the kernel is built with `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model loads it automatically. The codebase has been tested with CUDA `12.1`.
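
This auto-build flow is what PyTorch's JIT extension loader provides; a minimal sketch of such a call (file names and arguments are illustrative, not the repository's exact loader code):

from torch.utils.cpp_extension import load

# First call compiles with nvcc/ninja and caches the result in build_directory
# (assumed to exist already); later calls reuse the cached build.
anti_alias_activation_cuda = load(
    name="anti_alias_activation_cuda",
    sources=["anti_alias_activation.cpp", "anti_alias_activation_cuda.cu"],
    build_directory="alias_free_activation/cuda/build",
    verbose=True,
)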

Please make sure that both are installed on your system and that the `nvcc` version matches the CUDA version your PyTorch build uses.
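
A quick way to check that the versions match (assuming a standard PyTorch install):

import subprocess
import torch

print("PyTorch built with CUDA:", torch.version.cuda)  # e.g. 12.1
print(subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout)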

@@ -165,7 +164,7 @@ loading plain Pytorch BigVGAN
...
loading CUDA kernel BigVGAN with auto-build
Detected CUDA files, patching ldflags
Emitting ninja build file /path/to/your/BigVGAN/alias_free_cuda/build/build.ninja...
Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja...
Building extension module anti_alias_activation_cuda...
...
Loading extension module anti_alias_activation_cuda...
8 changes: 5 additions & 3 deletions alias_free_activation/cuda/activation1d.py
@@ -23,7 +23,7 @@ def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        activation_results = anti_alias_activation_cuda.forward(
            inputs, up_ftr, down_ftr, alpha, beta
        )

        return activation_results

    @staticmethod
@@ -70,6 +70,8 @@ def forward(self, x):
        ):  # Exp baked into cuda kernel, cancel it out with a log
            alpha = torch.log(alpha)
            beta = torch.log(beta)

        x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)

        x = FusedAntiAliasActivation.apply(
            x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
        )
        return x
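
The `torch.log` above only cancels the `expf` the CUDA kernel applies to `alpha`/`beta` internally; a tiny self-contained illustration of that cancellation (values arbitrary):

import torch

alpha = torch.tensor([0.5, 2.0])        # assumed positive, as required by log
alpha_log = torch.log(alpha)            # what the Python side passes to the kernel
alpha_in_kernel = torch.exp(alpha_log)  # the kernel's expf() recovers alpha
assert torch.allclose(alpha, alpha_in_kernel)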
42 changes: 21 additions & 21 deletions alias_free_activation/cuda/anti_alias_activation_cuda.cu
@@ -30,6 +30,16 @@

namespace
{
    // Hard-coded hyperparameters
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
    constexpr int BUFFER_SIZE = 32;
    constexpr int FILTER_SIZE = 12;
    constexpr int HALF_FILTER_SIZE = 6;
    constexpr int UPSAMPLE_REPLICATION_PAD = 5;         // 5 on each side, matching torch impl
    constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5;  // matching torch impl
    constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl

    template <typename input_t, typename output_t, typename acc_t>
    __global__ void anti_alias_activation_forward(
        output_t *dst,
@@ -42,23 +52,24 @@ namespace
        int channels,
        int seq_len)
    {
        // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
        constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
        constexpr int BUFFER_SIZE = 32;
        constexpr int FILTER_SIZE = 12;
        constexpr int HALF_FILTER_SIZE = 6;
        constexpr int UPSAMPLE_REPLICATION_PAD = 5;         // 5 on each side, matching torch impl
        constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5;  // matching torch impl
        constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
        // Up and downsample filters
        input_t up_filter[FILTER_SIZE];
        input_t down_filter[FILTER_SIZE];

        // Load data from global memory including extra indices reserved for replication paddings
        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};

        // Output stores downsampled output before writing to dst
        output_t output[BUFFER_SIZE];

        // blockDim/threadIdx = (128, 1, 1)
        // gridDim/blockIdx = (seq_blocks, channels, batches)
        int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
        int local_offset = threadIdx.x * BUFFER_SIZE;
        int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;

        int intermediate_seq_len = seq_len * 2; // intermediates have double the seq_len
        int intermediate_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
        // intermediates have double the seq_len
        int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
        int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;

Expand All @@ -77,17 +88,6 @@ namespace
        beta = beta + blockIdx.y;
        input_t beta_val = expf(beta[0]);

        // Load data from global memory including extra indices reserved for replication paddings
        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};

        // Output stores downsampled output before writing to dst
        output_t output[BUFFER_SIZE];

        // Up and downsample filters
        input_t up_filter[FILTER_SIZE];
        input_t down_filter[FILTER_SIZE];

        #pragma unroll
        for (int it = 0; it < FILTER_SIZE; it += 1)
        {
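
The offset arithmetic above maps each thread to a contiguous BUFFER_SIZE slice of one (batch, channel) row. A stripped-down sketch of just that mapping (kernel name and copy body are illustrative, not the repository's kernel):

constexpr int BUFFER_SIZE = 32;

// blockDim = (128, 1, 1); gridDim = (seq_blocks, channels, batches).
__global__ void index_demo(const float *src, float *dst, int seq_len)
{
    // Row base for this (batch, channel) pair plus this block's chunk of the row.
    int block_offset = blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z);
    int local_offset = threadIdx.x * BUFFER_SIZE;
    int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset; // position within the row
    for (int i = 0; i < BUFFER_SIZE; i += 1)
    {
        if (seq_offset + i < seq_len) // guard the ragged tail of the sequence
            dst[block_offset + local_offset + i] = src[block_offset + local_offset + i];
    }
}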
