diff --git a/README.md b/README.md index 6a1d9ab9..ae51f6fe 100644 --- a/README.md +++ b/README.md @@ -270,6 +270,7 @@ loss.backward() | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` | | JSD | `liger_kernel.transformers.LigerJSD` | | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` | +| TVD | `liger_kernel.transformers.LigerTVDLoss` | - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction. - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup. @@ -286,6 +287,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$ - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage. - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size. - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. +- **TVD**: [TVD](https://aclanthology.org/2023.acl-long.605.pdf) (Total variation distance), is implemented by computing both the loss and gradient in the forward pass. It achieves ~2X speed and ~15% memory reduction for 128k vocab size. - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. 
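For illustration only (not part of the patch): a minimal sketch of how the new `LigerTVDLoss` module exposed by this PR could be called. The sizes below are hypothetical and a CUDA device is assumed, since the loss is backed by a Triton kernel; inputs follow the same `(B*T, V)` softmax layout used by the benchmark and test scripts later in this diff.

```python
import torch

from liger_kernel.transformers import LigerTVDLoss

# Hypothetical sizes; the benchmark in this PR uses B=8, T=2048 and V from 4096 to 131072.
B, T, V = 2, 128, 32000
p = torch.randn(B * T, V, device="cuda", requires_grad=True).softmax(dim=-1)
q = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

tvd = LigerTVDLoss(reduction="batchmean")  # also accepts "none", "sum", "mean"
loss = tvd(p, q)  # 0.5 * sum(|p - q|), averaged over the B*T rows
loss.backward()
```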
diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index dfd31091..b2026bc3 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -505,6 +505,42 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859 fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,4096,0.47814399003982544,0.4774720072746277,0.4790079891681671,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,8192,0.906495988368988,0.905951976776123,0.9073920249938965,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,16384,1.8787360191345215,1.8778239488601685,1.8797119855880737,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,32768,3.5788800716400146,3.5772159099578857,3.58076810836792,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,liger,forward,speed,ms,V,vocab 
size,65536,7.008831977844238,7.007718086242676,7.010636806488037,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,liger,forward,speed,ms,V,vocab size,131072,13.88646411895752,13.88128662109375,13.890560150146484,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,4096,1.308608055114746,1.306502342224121,1.3104127645492554,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,8192,2.4735519886016846,2.472287893295288,2.4749441146850586,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,16384,4.828320026397705,4.826848030090332,4.830643177032471,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,32768,9.5206880569458,9.517024040222168,9.525145530700684,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,65536,19.01535987854004,19.011123657226562,19.01806640625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,torch,forward,speed,ms,V,vocab size,131072,38.022865295410156,38.01945877075195,38.02627182006836,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,4096,2.626512050628662,2.621260643005371,2.646751880645752,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,8192,4.661711692810059,4.657618999481201,4.662930965423584,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,16384,9.088272094726562,9.080741882324219,9.092268943786621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,32768,18.116064071655273,18.112728118896484,18.118234634399414,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,65536,35.85124969482422,35.849971771240234,35.85252380371094,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,liger,full,speed,ms,V,vocab size,131072,71.1648941040039,71.1648941040039,71.1648941040039,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,4096,4.361599922180176,4.360159873962402,4.3639678955078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,8192,8.11302375793457,8.11075210571289,8.114463806152344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,16384,15.841055870056152,15.837087631225586,15.841856002807617,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,32768,31.71219253540039,31.706951141357422,31.715898513793945,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,65536,63.17919921875,63.17919921875,63.17919921875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 +tvd,torch,full,speed,ms,V,vocab size,131072,126.0436782836914,126.0436782836914,126.0436782836914,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1 group_norm,liger,forward,speed,ms,C,num_channels,32,0.03481600061058998,0.03379200026392937,0.03993599861860275,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA 
A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1 group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05119999870657921,0.05222399905323982,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1 group_norm,liger,forward,speed,ms,C,num_channels,128,0.08499199897050858,0.08396799862384796,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1 @@ -619,3 +655,4 @@ layer_norm,huggingface,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160 layer_norm,huggingface,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1 layer_norm,huggingface,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1 layer_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1 + diff --git a/benchmark/scripts/benchmark_tvd.py b/benchmark/scripts/benchmark_tvd.py new file mode 100644 index 00000000..2e62fd6f --- /dev/null +++ b/benchmark/scripts/benchmark_tvd.py @@ -0,0 +1,136 @@ +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.transformers.tvd import LigerTVDLoss + + +class TorchTVDLoss(torch.nn.Module): + def __init__(self, reduction="batchmean"): + super(TorchTVDLoss, self).__init__() + self.reduction = reduction + + def forward(self, p, q): + tvd = torch.abs(p - q) / 2.0 + if self.reduction == "mean": + return torch.sum(tvd) / (p.size(0) * p.size(1)) + elif self.reduction == "sum": + return torch.sum(tvd) + elif self.reduction == "none": + return tvd + elif self.reduction == "batchmean": + return torch.sum(tvd) / p.size(0) + else: + raise ValueError("Invalid reduction type.") + + +S, E = 12, 18 + + +def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + reduction = "batchmean" + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + torch_tvd = TorchTVDLoss(reduction=reduction) + liger_tvd = LigerTVDLoss(reduction=reduction) + + _input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1) + target = torch.randn(B * T, V, device="cuda").softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_tvd(_input, target) + else: + return torch_tvd(_input, target) + + if input.kernel_operation_mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) + elif input.kernel_operation_mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[_input], + rep=100, + ) + elif input.kernel_operation_mode == "full": + + def full(): + y = fwd() + y.backward(retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, quantiles=QUANTILES, rep=100 + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + 
reduction = "batchmean" + torch_tvd = TorchTVDLoss(reduction=reduction) + liger_tvd = LigerTVDLoss(reduction=reduction) + + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + + _input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1) + target = torch.randn(B * T, V, device="cuda").softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_tvd(_input, target) + else: + return torch_tvd(_input, target) + + def full(): + y = fwd() + y.backward(retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + common_args = { + "kernel_name": "tvd", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(12, 18)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"B": 8, "T": 2048}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_memory_tvd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_args, + ) + + run_benchmarks( + bench_test_fn=bench_speed_tvd, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_args, + ) diff --git a/src/liger_kernel/ops/tvd.py b/src/liger_kernel/ops/tvd.py new file mode 100644 index 00000000..1099a3ec --- /dev/null +++ b/src/liger_kernel/ops/tvd.py @@ -0,0 +1,176 @@ +from typing import Literal + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import ensure_contiguous + +MAX_FUSED_SIZE = 65536 // 4 + +REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] + +_REDUCTION_MODE_NONE = tl.constexpr(0) +_REDUCTION_MODE_SUM = tl.constexpr(1) +_REDUCTION_MODE_MEAN = tl.constexpr(2) +_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3) + +_str_to_reduction_mode = { + "none": _REDUCTION_MODE_NONE.value, + "sum": _REDUCTION_MODE_SUM.value, + "mean": _REDUCTION_MODE_MEAN.value, + "batchmean": _REDUCTION_MODE_BATCHMEAN.value, +} + + +def get_num_warps(BLOCK_SIZE): + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + + return num_warps + + +@triton.jit +def _tv_distance_kernel( + p_ptr, + p_stride, + q_ptr, + q_stride, + loss_ptr, + loss_stride, + grads_ptr, + grads_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, + reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, +): + pid = tl.program_id(0).to(tl.int64) + p_ptr += pid * p_stride + q_ptr += pid * q_stride + loss_ptr += pid * loss_stride + grads_ptr += pid * grads_stride + + base_offsets = tl.arange(0, BLOCK_SIZE) + + loss_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + + p = tl.load(p_ptr + offsets, mask=mask, other=0.0) + q = tl.load(q_ptr + offsets, mask=mask, other=0.0) + + # TVD(P || Q) = 0.5 * |P - Q| + tv_loss = 0.5 * tl.abs(p - q) + + grad_res = tl.where(p > q, 0.5, -0.5) + + tl.store(grads_ptr + offsets, grad_res, mask=mask) + + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, tv_loss, mask=mask) + else: + loss_sum += tl.sum(tv_loss, axis=0) + + if reduction != _REDUCTION_MODE_NONE: + tl.store(loss_ptr, loss_sum) + + +def tv_distance_forward_triton(p, q, reduction): + BT, V = p.shape + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + num_warps = 
get_num_warps(BLOCK_SIZE) + + grid = (BT,) + + reduction = _str_to_reduction_mode[reduction] + + out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,) + output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32) + grads = torch.empty_like(p) + + _tv_distance_kernel[grid]( + p, + p.stride(0), + q, + q.stride(0), + output_tensor, + output_tensor.stride(0), + grads, + grads.stride(0), + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + reduction=reduction, + ) + + if reduction == _REDUCTION_MODE_BATCHMEAN.value: + return output_tensor.sum() / BT, grads / BT + elif reduction == _REDUCTION_MODE_SUM.value: + return output_tensor.sum(dim=0), grads + elif reduction == _REDUCTION_MODE_MEAN.value: + return output_tensor.sum() / (BT * V), grads / (BT * V) + else: + return output_tensor, grads + + +def tvd_backward_triton(grad_output, grads): + + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then. + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return grads + + return grads * grad_output + + +class LigerTVDLossFunction(torch.autograd.Function): + """ + Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton. + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + p: torch.Tensor, + q: torch.Tensor, + reduction: REDUCTION_LITERAL = "batchmean", + ) -> torch.Tensor: + """A forward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution. + q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution. + reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean". + + Returns: + torch.Tensor: The computed Total Variation Distance Loss. + """ + loss, grads = tv_distance_forward_triton(p, q, reduction) + ctx.save_for_backward(grads) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """A backward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + grad_output (torch.Tensor): The gradient of the loss with respect to the output. + + Returns: + tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the inputs. 
+ """ + (grads,) = ctx.saved_tensors + + grads = tvd_backward_triton(grad_output, grads) + + return grads, None, None diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index ffb8235c..7a8d4fee 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -8,6 +8,7 @@ from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401 from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401 from liger_kernel.transformers.jsd import LigerJSD # noqa: F401 +from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401 from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 from liger_kernel.transformers.monkey_patch import ( # noqa: F401 _apply_liger_kernel, diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 292c0dba..62736817 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -11,6 +11,7 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction from liger_kernel.ops.rope import LigerRopeFunction from liger_kernel.ops.swiglu import LigerSiLUMulFunction +from liger_kernel.ops.tvd import LigerTVDLossFunction liger_swiglu = LigerSiLUMulFunction.apply liger_cross_entropy = LigerCrossEntropyFunction.apply @@ -22,4 +23,5 @@ liger_kl_div = LigerKLDivLossFunction.apply liger_jsd = LigerJSDFunction.apply liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply +liger_tvd = LigerTVDLossFunction.apply liger_group_norm = LigerGroupNormFunction.apply diff --git a/src/liger_kernel/transformers/tvd.py b/src/liger_kernel/transformers/tvd.py new file mode 100644 index 00000000..f226ee26 --- /dev/null +++ b/src/liger_kernel/transformers/tvd.py @@ -0,0 +1,12 @@ +import torch.nn as nn + +from liger_kernel.ops.tvd import LigerTVDLossFunction + + +class LigerTVDLoss(nn.Module): + def __init__(self, reduction="batchmean"): + super(LigerTVDLoss, self).__init__() + self.reduction = reduction + + def forward(self, p, q): + return LigerTVDLossFunction.apply(p, q, self.reduction) diff --git a/test/transformers/test_tvd.py b/test/transformers/test_tvd.py new file mode 100644 index 00000000..23f4bf00 --- /dev/null +++ b/test/transformers/test_tvd.py @@ -0,0 +1,132 @@ +from test.utils import supports_bfloat16 + +import pytest +import torch + +from liger_kernel.transformers.tvd import LigerTVDLoss + + +class TorchTVDLoss(torch.nn.Module): + def __init__(self, reduction="batchmean"): + super(TorchTVDLoss, self).__init__() + self.reduction = reduction + + def forward(self, p, q): + + tvd = torch.abs(p - q) / 2.0 + + if self.reduction == "mean": + return torch.sum(tvd) / (p.size(0) * p.size(1)) + elif self.reduction == "sum": + return torch.sum(tvd) + elif self.reduction == "none": + return tvd + elif self.reduction == "batchmean": + return torch.sum(tvd) / p.size(0) + else: + raise ValueError("Invalid reduction type.") + + +_SHAPE_PARAMS = ( + "B, T, V", + [ + (1, 4096, 32000), + (32, 4096, 1024), + (41, 401, 1271), + pytest.param( + 1, + 4096, + 128256, + marks=pytest.mark.skipif( + torch.cuda.get_device_properties(0).total_memory + < 36 * 1000 * 1000 * 1000, + reason="This test requires a GPU with at least 36GB of memory", + ), + ), + (3, 423, 32000), + ], +) + +_DTYPE_PARAMS = ( + "dtype, atol, rtol", + [ + pytest.param( + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" 
+ ), + ), + (torch.float32, 1e-8, 1e-6), + (torch.float16, 1e-3, 1e-3), + ], +) + + +def _test_correctness_once( + target_tvd, + torch_tvd, + B, + T, + V, + dtype, + atol, + rtol, + reduction, + is_last_layer=True, + device="cuda", +): + torch.manual_seed(0) + input = torch.randn(B * T, V, device=device, dtype=dtype, requires_grad=True) + + x1 = input.detach().clone().requires_grad_(True) + x2 = input.detach().clone().requires_grad_(True) + + with torch.no_grad(): + target = torch.randn(B * T, V, device=device).softmax(dim=-1) + + output = target_tvd(x1, target) + output2 = torch_tvd(x2, target) + + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + if not is_last_layer: + output = output * 2.0 + output2 = output2 * 2.0 + + if reduction == "none": + return + + output.backward() + output2.backward() + assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +def test_correctness(B, T, V, reduction, dtype, atol, rtol): + liger_tvd = LigerTVDLoss(reduction=reduction) + torch_tvd = TorchTVDLoss(reduction=reduction) + _test_correctness_once(liger_tvd, torch_tvd, B, T, V, dtype, atol, rtol, reduction) + + +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +def test_correctness_not_last(B, T, V, reduction, dtype, atol, rtol): + liger_tvd = LigerTVDLoss(reduction=reduction) + torch_tvd = TorchTVDLoss(reduction=reduction) + _test_correctness_once( + liger_tvd, + torch_tvd, + B, + T, + V, + dtype, + atol, + rtol, + reduction, + is_last_layer=False, + )
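Illustrative closing note (not part of the patch): the Triton kernel above materializes the gradient during the forward pass as `tl.where(p > q, 0.5, -0.5)`, which `tv_distance_forward_triton` then rescales by `1/BT` for the `batchmean` reduction (or `1/(BT*V)` for `mean`). A small CPU-only sketch with hypothetical sizes, showing that this closed form agrees with autograd applied to the eager reference used in the tests:

```python
import torch


def tvd_reference(p: torch.Tensor, q: torch.Tensor, reduction: str = "batchmean") -> torch.Tensor:
    """Eager reference for total variation distance, mirroring TorchTVDLoss above."""
    tvd = 0.5 * torch.abs(p - q)
    if reduction == "batchmean":
        return tvd.sum() / p.size(0)
    if reduction == "sum":
        return tvd.sum()
    if reduction == "mean":
        return tvd.sum() / p.numel()
    return tvd  # "none"


BT, V = 8, 16  # hypothetical sizes
torch.manual_seed(0)
p = torch.rand(BT, V, requires_grad=True)
q = torch.rand(BT, V)

loss = tvd_reference(p, q)  # batchmean: 0.5 * sum(|p - q|) / BT
loss.backward()

# Closed-form gradient the kernel stores, scaled by 1/BT for "batchmean".
# (At exact ties the kernel stores -0.5 while autograd's abs yields 0; ties are
# measure-zero for continuous inputs, so the comparison holds here.)
expected = torch.sign(p.detach() - q) * 0.5 / BT
assert torch.allclose(p.grad, expected)
```

Because the exact per-element gradient is already produced in the forward pass, `tvd_backward_triton` only needs to multiply it by `grad_output`, and skips even that multiplication when `grad_output` is the scalar 1.0 (the common case where the loss is the last layer).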