From 455f3f3eb09f376d68ee70fd96820e67c4ddea3a Mon Sep 17 00:00:00 2001
From: L0SG
Date: Mon, 22 Jul 2024 05:36:02 -0700
Subject: [PATCH] Cleanup fully fused CUDA kernel, update README.md

---
 README.md                                  |  9 ++--
 alias_free_activation/cuda/activation1d.py |  8 ++--
 .../cuda/anti_alias_activation_cuda.cu     | 42 +++++++++----------
 3 files changed, 30 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index 0e3c729..c52342a 100644
--- a/README.md
+++ b/README.md
@@ -10,9 +10,8 @@
 ## News
 - **Jul 2024 (v2.3):**
-  - General refactor and code improvements for improved readability
-  - Fully fused CUDA kernel of anti-alised activation (upsampling + activation + downsampling)
-  - Inference speed benchmark
+  - General refactor and code improvements for improved readability.
+  - Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling) with inference speed benchmark.
 
 - **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
 
@@ -149,7 +148,7 @@ generator = BigVGAN(h, use_cuda_kernel=True)
 
 You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
 
-When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
+When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
 
 Please make sure that both are installed on your system and that the `nvcc` version matches the one your PyTorch build is using.
 
@@ -165,7 +164,7 @@ loading plain Pytorch BigVGAN
 ...
 loading CUDA kernel BigVGAN with auto-build
 Detected CUDA files, patching ldflags
-Emitting ninja build file /path/to/your/BigVGAN/alias_free_cuda/build/build.ninja...
+Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja...
 Building extension module anti_alias_activation_cuda...
 ...
 Loading extension module anti_alias_activation_cuda...
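For context on the README section touched above: `use_cuda_kernel` is the only user-facing switch for the fused kernel. A minimal usage sketch of that inference path follows; the `bigvgan` import and the `h`/`mel` placeholders are illustrative assumptions, not verbatim repository code.

```python
import torch
from bigvgan import BigVGAN  # module name assumed from the repository layout

# `h` stands in for the hyperparameter config loaded from the model's config.json.
# With use_cuda_kernel=True, the fused kernel is JIT-built with nvcc/ninja on
# first use and cached under alias_free_activation/cuda/build.
generator = BigVGAN(h, use_cuda_kernel=True).eval().to("cuda")

with torch.inference_mode():
    # `mel` stands in for a log-mel spectrogram batch of shape [B, num_mels, frames].
    wav = generator(mel.to("cuda"))  # waveform of shape [B, 1, samples]
```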
diff --git a/alias_free_activation/cuda/activation1d.py b/alias_free_activation/cuda/activation1d.py
index b205887..fbc0fd8 100644
--- a/alias_free_activation/cuda/activation1d.py
+++ b/alias_free_activation/cuda/activation1d.py
@@ -23,7 +23,7 @@ def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
         activation_results = anti_alias_activation_cuda.forward(
             inputs, up_ftr, down_ftr, alpha, beta
         )
-        
+
         return activation_results
 
     @staticmethod
@@ -70,6 +70,8 @@ def forward(self, x):
         ):  # Exp baked into cuda kernel, cancel it out with a log
             alpha = torch.log(alpha)
             beta = torch.log(beta)
-        
-        x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
+
+        x = FusedAntiAliasActivation.apply(
+            x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
+        )
         return x
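The Python-side edits above are formatting-only, but they mark the integration point: the JIT-built extension is wrapped in a `torch.autograd.Function` and invoked through `.apply(...)`. Below is a forward-only sketch of that wrapper pattern; the `load.load()` helper is inferred from this patch's file layout, and the backward stub reflects that the fused kernel targets inference.

```python
import torch
from alias_free_activation.cuda import load  # path as in this patch; helper assumed

anti_alias_activation_cuda = load.load()  # builds the extension once, then reuses the cache


class FusedAntiAliasActivationSketch(torch.autograd.Function):
    # Thin wrapper so the fused op composes with regular PyTorch modules.
    @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        # A single kernel launch covers upsample -> snake activation -> downsample.
        return anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)

    @staticmethod
    def backward(ctx, grad_output):
        # Inference-only sketch: no gradient is defined for the fused op here.
        raise NotImplementedError("fused anti-alias activation is inference-only")
```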
diff --git a/alias_free_activation/cuda/anti_alias_activation_cuda.cu b/alias_free_activation/cuda/anti_alias_activation_cuda.cu
index 2db1049..8c44233 100644
--- a/alias_free_activation/cuda/anti_alias_activation_cuda.cu
+++ b/alias_free_activation/cuda/anti_alias_activation_cuda.cu
@@ -30,6 +30,16 @@
 namespace
 {
+    // Hard-coded hyperparameters
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
+    constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
+    constexpr int BUFFER_SIZE = 32;
+    constexpr int FILTER_SIZE = 12;
+    constexpr int HALF_FILTER_SIZE = 6;
+    constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
+    constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5;  // matching torch impl
+    constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
+
     template <typename input_t, typename output_t, typename acc_t>
     __global__ void anti_alias_activation_forward(
         output_t *dst,
@@ -42,14 +52,16 @@ namespace
         int channels,
         int seq_len)
     {
-        // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
-        constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
-        constexpr int BUFFER_SIZE = 32;
-        constexpr int FILTER_SIZE = 12;
-        constexpr int HALF_FILTER_SIZE = 6;
-        constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
-        constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5;  // matching torch impl
-        constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
+        // Up and downsample filters
+        input_t up_filter[FILTER_SIZE];
+        input_t down_filter[FILTER_SIZE];
+
+        // Load data from global memory including extra indices reserved for replication paddings
+        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
+        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
+
+        // Output stores downsampled output before writing to dst
+        output_t output[BUFFER_SIZE];
 
         // blockDim/threadIdx = (128, 1, 1)
         // gridDim/blockIdx = (seq_blocks, channels, batches)
@@ -57,8 +69,7 @@ namespace
         int local_offset = threadIdx.x * BUFFER_SIZE;
         int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
 
-        int intermediate_seq_len = seq_len * 2; // intermediate have double the seq_len
-        int intermediate_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
+        // intermediates have double the seq_len
         int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
         int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
 
@@ -77,17 +88,6 @@ namespace
         beta = beta + blockIdx.y;
         input_t beta_val = expf(beta[0]);
 
-        // Load data from global memory including extra indices reserved for replication paddings
-        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
-        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
-
-        // Output stores downsampled output before writing to dst
-        output_t output[BUFFER_SIZE];
-
-        // Up and downsample filters
-        input_t up_filter[FILTER_SIZE];
-        input_t down_filter[FILTER_SIZE];
-
         #pragma unroll
         for (int it = 0; it < FILTER_SIZE; it += 1)
         {
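To make the kernel's bookkeeping easier to follow, here is an unfused PyTorch rendering of the computation it fuses, using the constants hard-coded above (12-tap filters, replication pads of 5/5 for upsampling and 5/6 for downsampling). This is a reading aid under simplifying assumptions (edge alignment and filter gain are approximated), not the repository's reference implementation.

```python
import torch
import torch.nn.functional as F


def anti_alias_snake_reference(x, up_filter, down_filter, alpha, beta):
    """Unfused sketch: 2x upsample -> snake activation -> 2x downsample.

    x: [B, C, T]; up_filter, down_filter: [12] taps (FILTER_SIZE above);
    alpha, beta: [C], already positive (the kernel receives logs and expf's them).
    """
    B, C, T = x.shape
    up_k = up_filter.view(1, 1, -1).repeat(C, 1, 1)    # depthwise 12-tap filter
    down_k = down_filter.view(1, 1, -1).repeat(C, 1, 1)

    # Upsample by 2: replication-pad 5 per side (UPSAMPLE_REPLICATION_PAD),
    # zero-stuff, then lowpass; the gain of 2 compensates for the inserted zeros.
    x = F.pad(x, (5, 5), mode="replicate")
    stuffed = torch.zeros(B, C, 2 * x.shape[-1], dtype=x.dtype, device=x.device)
    stuffed[..., ::2] = x
    x = 2.0 * F.conv1d(stuffed, up_k, groups=C)
    x = x[..., : 2 * T]  # trim to the nominal 2x length (alignment simplified)

    # Snake activation, per channel: x + (1/beta) * sin^2(alpha * x).
    a, b = alpha.view(1, C, 1), beta.view(1, C, 1)
    x = x + (1.0 / b) * torch.sin(a * x).pow(2)

    # Downsample by 2: asymmetric replication pad (left 5, right 6, matching
    # DOWNSAMPLE_REPLICATION_PAD_LEFT/RIGHT), then lowpass with stride 2.
    x = F.pad(x, (5, 6), mode="replicate")
    x = F.conv1d(x, down_k, stride=2, groups=C)
    return x[..., :T]
```

The point of the fusion is visible by contrast: in the kernel, the 2x-rate intermediate signal lives in the thread-local `elements`/`intermediates`/`output` arrays rather than round-tripping through global memory between the three stages.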