Remove double type dispatch in quantize_ops (#1965)
Summary:
Pull Request resolved: #1965

Remove dispatch on the `double` type, since `double` is not supported in
the quantize ops. This diff introduces `FBGEMM_DISPATCH_FLOAT_AND_HALF`
for dispatching over float/half and `FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16`
for dispatching over float/half/bfloat16, and replaces the
`AT_DISPATCH_FLOATING_TYPES*` macros (which also cover double) with these
custom dispatchers.
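
As a rough sketch of the behavioral difference (illustration only, not part
of this diff; the wrapper functions and the "example" op name are
hypothetical): the stock ATen macro also generates a double case, while the
new dispatcher covers only the types the quantize kernels implement.

#include <ATen/ATen.h>
#include "fbgemm_gpu/dispatch_macros.h"

// Before: AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda for
// Double, Float, and Half, so a double tensor selects a code path the
// quantize kernels never supported.
void old_style(const at::Tensor& t) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      t.scalar_type(), "example", [&] { /* scalar_t: double, float, or at::Half */ });
}

// After: only Float and Half cases are generated; a double tensor now fails
// with the dispatcher's "not implemented" error instead.
void new_style(const at::Tensor& t) {
  FBGEMM_DISPATCH_FLOAT_AND_HALF(
      t.scalar_type(), "example", [&] { /* scalar_t: float or at::Half */ });
}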

Reviewed By: q10

Differential Revision: D48628880

fbshipit-source-id: 0d04b831adab3c12de8e671b4cbcc758b8aa186c
sryap authored and facebook-github-bot committed Aug 24, 2023
1 parent 4920770 commit 97ffa8d
Showing 7 changed files with 40 additions and 38 deletions.
16 changes: 16 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h
@@ -185,3 +185,19 @@
#NAME, " not implemented for grad_t '", toString(_grad_t), "'"); \
} \
}()

#define FBGEMM_DISPATCH_FLOAT_AND_HALF_CASE(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)

#define FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16_CASE(...) \
FBGEMM_DISPATCH_FLOAT_AND_HALF_CASE(__VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define FBGEMM_DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, FBGEMM_DISPATCH_FLOAT_AND_HALF_CASE(__VA_ARGS__))

#define FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16_CASE(__VA_ARGS__))
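
For reference, a minimal host-side usage sketch of the new bfloat16-capable
dispatcher (hypothetical `times_two` helper, assuming a contiguous CPU
tensor; the real call sites are in the .cu/.cpp files below):

#include <ATen/ATen.h>
#include "fbgemm_gpu/dispatch_macros.h"

// Doubles every element of a float/half/bfloat16 tensor. Any other dtype,
// including double, hits the dispatcher's "not implemented" error.
at::Tensor times_two(const at::Tensor& input) {
  auto output = at::empty_like(input);
  FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
      input.scalar_type(), "times_two", [&] {
        const scalar_t* src = input.data_ptr<scalar_t>();
        scalar_t* dst = output.data_ptr<scalar_t>();
        for (int64_t i = 0; i < input.numel(); ++i) {
          dst[i] = static_cast<scalar_t>(2.0f * static_cast<float>(src[i]));
        }
      });
  return output;
}
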
1 change: 1 addition & 0 deletions fbgemm_gpu/src/quantize_ops/common.cuh
@@ -18,6 +18,7 @@
#include <ATen/core/TensorAccessor.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/ops_utils.h"
32 changes: 8 additions & 24 deletions fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
@@ -253,12 +253,8 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
// think unsigned as we use 0, 255

if (nrows <= 20) {
- AT_DISPATCH_FLOATING_TYPES_AND2(
-     at::ScalarType::Half,
-     at::ScalarType::BFloat16,
-     input.scalar_type(),
-     "_float_to_FP8rowwise_cuda_kernel",
-     [&] {
+ FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
+     input.scalar_type(), "_float_to_FP8rowwise_cuda_kernel", [&] {
_float_to_FP8rowwise_cuda_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
@@ -297,12 +293,8 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
const auto num_blocks_warp =
cuda_calc_xblock_count(nrows, rows_per_block);
- AT_DISPATCH_FLOATING_TYPES_AND2(
-     at::ScalarType::Half,
-     at::ScalarType::BFloat16,
-     input.scalar_type(),
-     "_get_FP8_qparam_cuda_kernel",
-     [&] {
+ FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
+     input.scalar_type(), "_get_FP8_qparam_cuda_kernel", [&] {
_get_FP8_qparam_cuda_kernel<scalar_t>
<<<num_blocks_warp,
dim3(blockDim_x, rows_per_block),
@@ -325,12 +317,8 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
dim3 gridDim(gridDim_x, gridDim_y);
- AT_DISPATCH_FLOATING_TYPES_AND2(
-     at::ScalarType::Half,
-     at::ScalarType::BFloat16,
-     input.scalar_type(),
-     "_compute_FP8_quantize_cuda_kernel",
-     [&] {
+ FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
+     input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
_compute_FP8_quantize_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<scalar_t>(),
@@ -415,12 +403,8 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
const dim3 gridDim(gridDim_x, gridDim_y);
- AT_DISPATCH_FLOATING_TYPES_AND2(
-     at::ScalarType::Half,
-     at::ScalarType::BFloat16,
-     output.scalar_type(),
-     "FP8rowwise_to_float_cuda_kernel",
-     [&] {
+ FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
+     output.scalar_type(), "FP8rowwise_to_float_cuda_kernel", [&] {
_FP8rowwise_to_float_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<std::uint8_t>(),
12 changes: 6 additions & 6 deletions fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu
@@ -254,7 +254,7 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
// think unsigned as we use 0, 255

if (nrows <= 20) {
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
input.scalar_type(), "_float_to_fused8bitrowwise_cuda_kernel", [&] {
_float_to_fused8bitrowwise_cuda_kernel<scalar_t>
<<<num_blocks,
@@ -292,7 +292,7 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
const auto num_blocks_warp =
cuda_calc_xblock_count(nrows, rows_per_block);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
input.scalar_type(), "_get_8bit_qparam_cuda_kernel", [&] {
_get_8bit_qparam_cuda_kernel<scalar_t>
<<<num_blocks_warp,
@@ -315,7 +315,7 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
dim3 gridDim(gridDim_x, gridDim_y);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
input.scalar_type(), "_compute_8bit_quantize_cuda_kernel", [&] {
_compute_8bit_quantize_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -344,7 +344,7 @@ DLL_PUBLIC Tensor _half_to_fused8bitrowwise_gpu(const Tensor& input) {
///@ingroup quantize-data-cuda
DLL_PUBLIC Tensor _float_or_half_to_fused8bitrowwise_gpu(const Tensor& input) {
Tensor output;
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
input.scalar_type(),
"float_or_half_to_fused8bitrowwise_cuda_kernel",
[&] { output = _float_to_fused8bitrowwise_gpu_t<scalar_t>(input); });
@@ -398,7 +398,7 @@ Tensor _fused8bitrowwise_to_float_gpu_t(const Tensor& input) {
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
const dim3 gridDim(gridDim_x, gridDim_y);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
output.scalar_type(), "fused8bitrowwise_to_float_cuda_kernel", [&] {
_fused8bitrowwise_to_float_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -482,7 +482,7 @@ DLL_PUBLIC at::Tensor _fused8bitrowwise_to_float_mixed_dim_gpu(
const dim3 blockDim(kWarpSize, threads_per_block / kWarpSize);
const dim3 gridDim(
cuda_calc_xblock_count(num_tables * batch_size, blockDim.y));
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
output.scalar_type(),
"_fused8bitrowwise_to_float_mixed_dim_cuda_kernel",
[&] {
6 changes: 3 additions & 3 deletions fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
@@ -149,7 +149,7 @@ Tensor _float_to_fusednbitrowwise_gpu_t(
const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block);
// think unsigned as we use 0, 255

- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
input.scalar_type(), "_float_to_fusednbitrowwise_cuda_kernel", [&] {
_float_to_fusednbitrowwise_cuda_kernel<scalar_t>
<<<num_blocks,
@@ -185,7 +185,7 @@ DLL_PUBLIC Tensor _float_or_half_to_fusednbitrowwise_gpu(
const Tensor& input,
const int64_t bit_rate) {
Tensor output;
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
input.scalar_type(),
"float_or_half_to_fusednbitrowwise_cuda_kernel",
[&] {
@@ -239,7 +239,7 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
const dim3 gridDim(gridDim_x, gridDim_y);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
output.scalar_type(), "fusednbitrowwise_to_float_cuda_kernel", [&] {
_fusednbitrowwise_to_float_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
5 changes: 3 additions & 2 deletions fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -12,6 +12,7 @@
#include <fbgemm_gpu/sparse_ops_utils.h>
#include <torch/library.h>
#include "fbgemm/QuantUtils.h"
#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/quantize_ops_utils.h"

@@ -189,7 +190,7 @@ Tensor float_or_half_to_fused8bitrowwise_cpu(const Tensor& input) {
auto output = at::empty(
{0},
input.options().dtype(at::kByte)); // at::kBytes for uint8_t
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
input.scalar_type(), "float_or_half_to_fused8bitrowwise_cpu", [&] {
if (std::is_same<scalar_t, float>::value) {
_float_to_fused8bitrowwise_cpu_out(output, input);
@@ -301,7 +302,7 @@ Tensor float_or_half_to_fusednbitrowwise_cpu(
const Tensor& input,
const int64_t bit_rate) {
Tensor output;
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
input.scalar_type(), "float_or_half_to_fusednbitrowwise_cpu", [&] {
if (std::is_same<scalar_t, float>::value) {
output = _float_to_fusednbitrowwise_cpu<float>(input, bit_rate);
6 changes: 3 additions & 3 deletions fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu
@@ -241,7 +241,7 @@ Tensor _float_to_paddedFP8rowwise_gpu_t(
const auto num_blocks = cuda_calc_xblock_count(
nrows == 1 ? (ncols + row_dim - 1) / row_dim : nrows, threads_per_block);

- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
input.scalar_type(), "_float_to_FP8rowwise_cuda_kernel", [&] {
_float_to_paddedFP8rowwise_cuda_kernel<scalar_t>
<<<num_blocks,
@@ -357,7 +357,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
constexpr int kMaxThreads = 1024;
const auto threads_per_block =
kMaxThreads < row_dim ? kMaxThreads : row_dim;
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
output.scalar_type(), "PaddedFP8rowwise_to_float_1d_cuda_kernel", [&] {
_PaddedFP8rowwise_to_float_1d_cuda_kernel<scalar_t>
<<<num_rows,
@@ -375,7 +375,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ FBGEMM_DISPATCH_FLOAT_AND_HALF(
output.scalar_type(), "PaddedFP8rowwise_to_float_2d_cuda_kernel", [&] {
_PaddedFP8rowwise_to_float_2d_cuda_kernel<scalar_t>
<<<num_blocks,
