From 108162595d55454fabb702c1fb4f1cdbd2544ede Mon Sep 17 00:00:00 2001 From: Ilango Rajagopal Date: Thu, 27 Jun 2024 17:11:40 +0530 Subject: [PATCH] Refactor custom ops into proper file locations --- QEfficient/customop/__init__.py | 119 +----------------- QEfficient/customop/ctx_scatter_gather.py | 81 ++++++++++++ QEfficient/customop/rms_norm.py | 88 ++++++------- QEfficient/customop/rms_norm_native.py | 61 +++++++++ ...ms_op_config.yaml => rms_norm_native.yaml} | 2 +- 5 files changed, 185 insertions(+), 166 deletions(-) create mode 100644 QEfficient/customop/ctx_scatter_gather.py create mode 100644 QEfficient/customop/rms_norm_native.py rename QEfficient/customop/{custom_rms_op_config.yaml => rms_norm_native.yaml} (90%) diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py index d13c8d90..9bfd0899 100644 --- a/QEfficient/customop/__init__.py +++ b/QEfficient/customop/__init__.py @@ -5,120 +5,7 @@ # # ----------------------------------------------------------------------------- -""" -RMS Norm CustomOp Node in com.qti.aisw.onnx Domain for Cloud AI 100 -This is to handle the FP16 Overflow seen in RMS Norm for LLMs +from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxScatterFunc +from QEfficient.customop.rms_norm import CustomRMSNormAIC -""" - -import onnxscript -import torch -from onnxscript.onnx_opset import opset13 as ops -from torch import nn - -opset_version = 13 -custom_opset = onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1) - - -# Version 1 -@onnxscript.script(custom_opset) -def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float): - weight = ops.Cast(weight, to=1) - variance = ops.ReduceMean(ops.Pow(hidden_states, 2), axes=[-1], keepdims=1) - epsilon = ops.Expand(epsilon, ops.Shape(variance)) - hidden_states = hidden_states * ops.Reciprocal(ops.Sqrt(variance + epsilon)) - return weight * hidden_states - - -class CustomRMSNormOp(torch.autograd.Function): - @staticmethod - def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float): - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + epsilon) - return weight * hidden_states - - @staticmethod - def setup_context(ctx, inputs, outputs): - pass - - @staticmethod - def symbolic( - g: torch.onnx._internal.jit_utils.GraphContext, - hidden_states: torch.Value, - weight: torch.Value, - epsilon: torch.Value, - ) -> torch.Value: - return g.onnxscript_op(CustomRMSNorm, hidden_states, weight, epsilon_f=epsilon).setTypeAs(hidden_states) - - -class CustomRMSNormAIC(nn.Module): - def __init__(self, hidden_size, eps=1e-05): - super(CustomRMSNormAIC, self).__init__() - self.variance_epsilon = eps - self.weight = torch.nn.Parameter(torch.ones(hidden_size)) - - def forward(self, hidden_states): - output = CustomRMSNormOp.apply(hidden_states, self.weight, self.variance_epsilon) - return output - - -@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) -def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: - # Find dims - batch_size = ops.Gather(ops.Shape(data), [0]) - num_heads = ops.Gather(ops.Shape(data), [1]) - seq_len = ops.Gather(ops.Shape(position_ids), [1]) - - # Expanded shape to create indices - zero = ops.Constant(value_ints=[0]) - one = ops.Constant(value_ints=[1]) - exp_shape = ops.Concat(batch_size, num_heads, seq_len, one, axis=0) - - # Create indices - batch_idx = 
ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2, 3]), exp_shape) - head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape) - ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape) - indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3) - - return ops.ScatterND(data, indices, updates) - - -class CtxScatterFunc(torch.autograd.Function): - @staticmethod - def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): - batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1) - head_idx = torch.arange(data.shape[1]).view(1, -1, 1) - ctx_idx = position_ids.unsqueeze(1) - data[batch_idx, head_idx, ctx_idx] = updates - return data - - @staticmethod - def setup_context(ctx, inputs, outputs): - pass - - @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data) - - -@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) -def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: - ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0])) - ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) - return ops.GatherND(data, ctx_indices, batch_dims=2) - - -class CtxGatherFunc(torch.autograd.Function): - @staticmethod - def forward(data: torch.Tensor, ctx_indices: torch.Tensor): - batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) - head_indices = torch.arange(data.shape[1]).view(1, -1, 1) - return data[batch_indices, head_indices, ctx_indices] - - @staticmethod - def setup_context(ctx, inputs, outputs): - pass - - @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) +__all__ = ["CtxGatherFunc", "CtxScatterFunc", "CustomRMSNormAIC"] diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py new file mode 100644 index 00000000..fa615b46 --- /dev/null +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -0,0 +1,81 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import onnxscript +import torch + +ops = onnxscript.opset13 + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: + # Find dims + batch_size = ops.Gather(ops.Shape(data), [0]) + num_heads = ops.Gather(ops.Shape(data), [1]) + seq_len = ops.Gather(ops.Shape(position_ids), [1]) + + # Expanded shape to create indices + zero = ops.Constant(value_ints=[0]) + one = ops.Constant(value_ints=[1]) + exp_shape = ops.Concat(batch_size, num_heads, seq_len, one, axis=0) + + # Create indices + batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2, 3]), exp_shape) + head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape) + ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape) + indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3) + + return ops.ScatterND(data, indices, updates) + + +class CtxScatterFunc(torch.autograd.Function): + """ + Function to scatter the current key values into KV-cache. + """ + + @staticmethod + def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1) + head_idx = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_idx = position_ids.unsqueeze(1) + data[batch_idx, head_idx, ctx_idx] = updates + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data) + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: + ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0])) + ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) + return ops.GatherND(data, ctx_indices, batch_dims=2) + + +class CtxGatherFunc(torch.autograd.Function): + """ + Function to gather only the valid key values from KV-cache. + """ + + @staticmethod + def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) + head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + return data[batch_indices, head_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) diff --git a/QEfficient/customop/rms_norm.py b/QEfficient/customop/rms_norm.py index d7f4fd2d..210cca68 100644 --- a/QEfficient/customop/rms_norm.py +++ b/QEfficient/customop/rms_norm.py @@ -1,61 +1,51 @@ # ----------------------------------------------------------------------------- # -# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. +# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. 
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
 
-"""
-RMS Norm CustomOp Node in QAic Domain for Cloud AI 100
-This is to handle the FP16 Overflow seen in RMS Norm for LLMs
-"""
-
+import onnxscript
 import torch
-from torch.onnx.symbolic_helper import parse_args
-
-op_source = """
-#include <torch/extension.h>
-
-torch::Tensor custom_rms_norm(torch::Tensor hidden_states, torch::Tensor weight, double eps) {
-  torch::Tensor output;
-  torch::Tensor variance;
-  bool keepdim;
-  // double eps = 1e-5;
-  variance = hidden_states.pow(2).mean(-1, keepdim=true);
-  output = hidden_states * torch::rsqrt(variance + eps);
-  output = output * weight;
-  return output;
-}
-
-TORCH_LIBRARY(QAic, m) {
-  m.def("QEffCustomRMSNorm", &custom_rms_norm);
-}
-"""
-
-# Compile and load the custom op
-torch.utils.cpp_extension.load_inline(
-    name="custom_rms_norm",
-    cpp_sources=op_source,
-    is_python_module=False,
-    verbose=True,
-)
-
-
-# Wrapper module for custom relu C++ op
-class QEffCustomRMSNorm(torch.nn.Module):
-    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
-        super().__init__()
-        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
-        self.eps = eps
+from torch import nn
+
+ops = onnxscript.opset13
+
+
+@onnxscript.script(onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1))
+def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float):
+    weight = ops.Cast(weight, to=1)
+    variance = ops.ReduceMean(ops.Pow(hidden_states, 2), axes=[-1], keepdims=1)
+    epsilon = ops.Expand(epsilon, ops.Shape(variance))
+    hidden_states = hidden_states * ops.Reciprocal(ops.Sqrt(variance + epsilon))
+    return weight * hidden_states
 
-    def forward(self, hidden_states):
-        return torch.ops.QAic.QEffCustomRMSNorm(hidden_states, self.weight, self.eps)
 
+class CustomRMSNormFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float):
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + epsilon)
+        return weight * hidden_states
 
-# ONNX export symbolic helper
-@parse_args("v", "v", "f")
-def custom_rms_norm(g, hidden_states, weight, eps):
-    return g.op("QAic::QEffCustomRMSNorm", hidden_states, weight, eps_f=eps).setTypeAs(hidden_states)
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        pass
 
+    @staticmethod
+    def symbolic(g: torch.Graph, hidden_states: torch.Value, weight: torch.Value, epsilon: torch.Value) -> torch.Value:
+        return g.onnxscript_op(CustomRMSNorm, hidden_states, weight, epsilon_f=epsilon).setTypeAs(hidden_states)
 
-torch.onnx.register_custom_op_symbolic("QAic::QEffCustomRMSNorm", custom_rms_norm, 1)
+
+class CustomRMSNormAIC(nn.Module):
+    """
+    RMSNorm module that replaces the original module with a compiler-known custom op.
+    """
+
+    def __init__(self, hidden_size, eps=1e-05):
+        super(CustomRMSNormAIC, self).__init__()
+        self.variance_epsilon = eps
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+
+    def forward(self, hidden_states):
+        return CustomRMSNormFunc.apply(hidden_states, self.weight, self.variance_epsilon)
diff --git a/QEfficient/customop/rms_norm_native.py b/QEfficient/customop/rms_norm_native.py
new file mode 100644
index 00000000..105359d7
--- /dev/null
+++ b/QEfficient/customop/rms_norm_native.py
@@ -0,0 +1,61 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""
+RMS Norm CustomOp Node in QAic Domain for Cloud AI 100
+This is to handle the FP16 Overflow seen in RMS Norm for LLMs
+"""
+
+import torch.utils.cpp_extension  # load_inline lives here; plain "import torch" does not pull it in
+from torch.onnx.symbolic_helper import parse_args
+
+op_source = """
+#include <torch/extension.h>
+
+torch::Tensor custom_rms_norm(torch::Tensor hidden_states, torch::Tensor weight, double eps) {
+  torch::Tensor output;
+  torch::Tensor variance;
+  // keep the reduced dim so variance broadcasts against hidden_states;
+  // eps guards the rsqrt against division by zero
+  variance = hidden_states.pow(2).mean(-1, /*keepdim=*/true);
+  output = hidden_states * torch::rsqrt(variance + eps);
+  output = output * weight;
+  return output;
+}
+
+TORCH_LIBRARY(QAic, m) {
+  m.def("QEffCustomRMSNorm", &custom_rms_norm);
+}
+"""
+
+# Compile and load the custom op
+torch.utils.cpp_extension.load_inline(
+    name="custom_rms_norm",
+    cpp_sources=op_source,
+    is_python_module=False,
+    verbose=True,
+)
+
+
+# Wrapper module for the custom RMS norm C++ op
+class QEffCustomRMSNorm(torch.nn.Module):
+    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
+        self.eps = eps
+
+    def forward(self, hidden_states):
+        return torch.ops.QAic.QEffCustomRMSNorm(hidden_states, self.weight, self.eps)
+
+
+# ONNX export symbolic helper
+@parse_args("v", "v", "f")
+def custom_rms_norm(g, hidden_states, weight, eps):
+    return g.op("QAic::QEffCustomRMSNorm", hidden_states, weight, eps_f=eps).setTypeAs(hidden_states)
+
+
+torch.onnx.register_custom_op_symbolic("QAic::QEffCustomRMSNorm", custom_rms_norm, 1)
diff --git a/QEfficient/customop/custom_rms_op_config.yaml b/QEfficient/customop/rms_norm_native.yaml
similarity index 90%
rename from QEfficient/customop/custom_rms_op_config.yaml
rename to QEfficient/customop/rms_norm_native.yaml
index 446921b2..4c2fa450 100644
--- a/QEfficient/customop/custom_rms_op_config.yaml
+++ b/QEfficient/customop/rms_norm_native.yaml
@@ -1,6 +1,6 @@
 # -----------------------------------------------------------------------------
 #
-#Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
+#Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. 
 #SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
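
Note on usage: below is a minimal sketch of the API re-exported by the new
QEfficient/customop/__init__.py. The "Block" wrapper, all tensor shapes, and
the "block.onnx" path are illustrative only (not part of this patch), and the
export call assumes a torch 2.x build whose ONNX exporter supports the
onnxscript_op symbolic helper used in these classes.

    import torch

    from QEfficient.customop import CtxGatherFunc, CtxScatterFunc, CustomRMSNormAIC

    # CustomRMSNormAIC is a drop-in RMSNorm: eager mode runs the PyTorch math,
    # ONNX export emits a CustomRMSNorm node in the com.qti.aisw.onnx domain.
    class Block(torch.nn.Module):
        def __init__(self, hidden_size: int):
            super().__init__()
            self.norm = CustomRMSNormAIC(hidden_size, eps=1e-5)

        def forward(self, x):
            return self.norm(x)

    block = Block(hidden_size=64)
    x = torch.randn(1, 8, 64)
    assert block(x).shape == x.shape

    # KV-cache helpers: scatter the current keys/values into the cache at
    # position_ids, then gather the valid entries back out.
    cache = torch.zeros(1, 4, 16, 32)         # (batch, heads, ctx_len, head_dim)
    position_ids = torch.tensor([[0, 1, 2]])  # (batch, seq_len)
    updates = torch.randn(1, 4, 3, 32)        # current keys/values
    cache = CtxScatterFunc.apply(cache, position_ids, updates)
    gathered = CtxGatherFunc.apply(cache, position_ids)
    assert gathered.shape == updates.shape

    # Illustrative export; the custom nodes come from the symbolic() hooks.
    torch.onnx.export(block, (x,), "block.onnx", opset_version=13)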