Skip to content

Commit

Permalink
Refactor custom ops into proper file locations
Browse files Browse the repository at this point in the history
  • Loading branch information
irajagop committed Jun 27, 2024
1 parent 1c8d0a3 commit 1081625
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 166 deletions.
119 changes: 3 additions & 116 deletions QEfficient/customop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,120 +5,7 @@
#
# -----------------------------------------------------------------------------

"""
RMS Norm CustomOp Node in com.qti.aisw.onnx Domain for Cloud AI 100
This is to handle the FP16 Overflow seen in RMS Norm for LLMs
from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxScatterFunc
from QEfficient.customop.rms_norm import CustomRMSNormAIC

"""

import onnxscript
import torch
from onnxscript.onnx_opset import opset13 as ops
from torch import nn

# ONNX opset version constant; it is not referenced by the other definitions visible here.
opset_version = 13
# Custom-op domain ("com.qti.aisw.onnx", version 1) under which CustomRMSNorm is scripted.
custom_opset = onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1)


# Version 1
# ONNX Script definition of the RMS-norm custom op, scripted under `custom_opset`.
# Computes weight * hidden_states / sqrt(mean(hidden_states ** 2 over the last axis) + epsilon).
@onnxscript.script(custom_opset)
def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float):
    # NOTE(review): to=1 is presumably TensorProto.FLOAT (float32) -- confirm against the ONNX enum.
    weight = ops.Cast(weight, to=1)
    # Mean of squares over the last axis; keepdims=1 keeps that axis so it broadcasts below.
    variance = ops.ReduceMean(ops.Pow(hidden_states, 2), axes=[-1], keepdims=1)
    # Expand epsilon to the shape of `variance` before it is added.
    epsilon = ops.Expand(epsilon, ops.Shape(variance))
    hidden_states = hidden_states * ops.Reciprocal(ops.Sqrt(variance + epsilon))
    return weight * hidden_states


class CustomRMSNormOp(torch.autograd.Function):
    """RMS normalisation with an ONNX Script symbolic.

    ``forward`` is the eager PyTorch computation; ``symbolic`` emits the
    ``CustomRMSNorm`` ONNX Script op in its place. No ``backward`` is defined
    in this class.
    """

    @staticmethod
    def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float):
        """Return ``weight * hidden_states * rsqrt(mean(hidden_states**2, -1) + epsilon)``."""
        # Mean of squares over the last axis, kept as a size-1 axis for broadcasting.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + epsilon)
        return weight * hidden_states

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        """Save nothing on ``ctx``; present so ``forward`` can omit the ``ctx`` argument."""
        pass

    @staticmethod
    def symbolic(
        g: torch.Graph,
        hidden_states: torch.Value,
        weight: torch.Value,
        epsilon: torch.Value,
    ) -> torch.Value:
        """Emit the ``CustomRMSNorm`` op and give its output the type of ``hidden_states``."""
        # `g` is annotated with the public torch.Graph (as the other custom-op
        # functions here do) rather than a private torch.onnx._internal path,
        # which would be resolved when the class body is executed.
        return g.onnxscript_op(CustomRMSNorm, hidden_states, weight, epsilon_f=epsilon).setTypeAs(hidden_states)


class CustomRMSNormAIC(nn.Module):
    """RMSNorm layer whose forward pass is routed through ``CustomRMSNormOp``."""

    def __init__(self, hidden_size, eps=1e-05):
        super().__init__()
        # Learnable scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Apply ``CustomRMSNormOp`` with this layer's weight and epsilon."""
        return CustomRMSNormOp.apply(hidden_states, self.weight, self.variance_epsilon)


# ONNX Script op (domain "com.qualcomm.cloud", version 1) that scatters `updates`
# into `data` at index triples (batch, head, position) built from `position_ids`.
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT:
    # Find dims: sizes of data's axes 0 and 1, and of position_ids' axis 1.
    batch_size = ops.Gather(ops.Shape(data), [0])
    num_heads = ops.Gather(ops.Shape(data), [1])
    seq_len = ops.Gather(ops.Shape(position_ids), [1])

    # Expanded shape to create indices: (batch_size, num_heads, seq_len, 1).
    zero = ops.Constant(value_ints=[0])
    one = ops.Constant(value_ints=[1])
    exp_shape = ops.Concat(batch_size, num_heads, seq_len, one, axis=0)

    # Create indices: each component is expanded to exp_shape, then the three are
    # concatenated on axis 3 so the last axis holds (batch, head, position).
    # NOTE(review): the Unsqueeze axes assume position_ids is 2-D -- confirm with callers.
    batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2, 3]), exp_shape)
    head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape)
    ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape)
    indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3)

    return ops.ScatterND(data, indices, updates)


class CtxScatterFunc(torch.autograd.Function):
    """Write ``updates`` into ``data`` at the positions given by ``position_ids``.

    ``forward`` is the eager PyTorch implementation; ``symbolic`` emits the
    ``CtxScatter`` ONNX Script op in its place.
    """

    @staticmethod
    def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
        n_batch, n_heads = data.shape[0], data.shape[1]
        # Index tensors shaped (B, 1, 1), (1, H, 1) and position_ids with a new
        # axis 1, so the three broadcast together.
        rows = torch.arange(n_batch)[:, None, None]
        heads = torch.arange(n_heads)[None, :, None]
        slots = position_ids[:, None]
        # In-place write: the caller's tensor is modified and returned.
        data[rows, heads, slots] = updates
        return data

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        # Nothing is saved on ctx.
        pass

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
        node = g.onnxscript_op(CtxScatter, data, position_ids, updates)
        # The op's output takes the type of `data`.
        return node.setTypeAs(data)


# ONNX Script op (domain "com.qualcomm.cloud", version 1) that gathers entries of
# `data` selected by `ctx_indices`.
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
    # Expand the indices to the first three dimensions of data's shape.
    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
    # Add a trailing axis so each index is a length-1 tuple for GatherND.
    ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
    # batch_dims=2: the first two axes are treated as batch axes by GatherND.
    return ops.GatherND(data, ctx_indices, batch_dims=2)


class CtxGatherFunc(torch.autograd.Function):
    """Select entries of ``data`` along axis 2 using ``ctx_indices``.

    ``forward`` is the eager PyTorch implementation; ``symbolic`` emits the
    ``CtxGather`` ONNX Script op in its place.
    """

    @staticmethod
    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
        n_batch, n_heads = data.shape[0], data.shape[1]
        # (B, 1, 1) and (1, H, 1) index tensors that broadcast against ctx_indices.
        rows = torch.arange(n_batch)[:, None, None]
        heads = torch.arange(n_heads)[None, :, None]
        return data[rows, heads, ctx_indices]

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        # Nothing is saved on ctx.
        pass

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
        node = g.onnxscript_op(CtxGather, data, ctx_indices)
        # The op's output takes the type of `data`.
        return node.setTypeAs(data)
__all__ = ["CtxGatherFunc", "CtxScatterFunc", "CustomRMSNormAIC"]
81 changes: 81 additions & 0 deletions QEfficient/customop/ctx_scatter_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import onnxscript
import torch

ops = onnxscript.opset13


# ONNX Script op (domain "com.qualcomm.cloud", version 1) that scatters `updates`
# into `data` at index triples (batch, head, position) built from `position_ids`.
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT:
    # Find dims: sizes of data's axes 0 and 1, and of position_ids' axis 1.
    batch_size = ops.Gather(ops.Shape(data), [0])
    num_heads = ops.Gather(ops.Shape(data), [1])
    seq_len = ops.Gather(ops.Shape(position_ids), [1])

    # Expanded shape to create indices: (batch_size, num_heads, seq_len, 1).
    zero = ops.Constant(value_ints=[0])
    one = ops.Constant(value_ints=[1])
    exp_shape = ops.Concat(batch_size, num_heads, seq_len, one, axis=0)

    # Create indices: each component is expanded to exp_shape, then the three are
    # concatenated on axis 3 so the last axis holds (batch, head, position).
    # NOTE(review): the Unsqueeze axes assume position_ids is 2-D -- confirm with callers.
    batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2, 3]), exp_shape)
    head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape)
    ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape)
    indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3)

    return ops.ScatterND(data, indices, updates)


class CtxScatterFunc(torch.autograd.Function):
    """
    Scatter the current key/value ``updates`` into the KV-cache ``data``.

    ``forward`` is the eager PyTorch implementation; ``symbolic`` emits the
    ``CtxScatter`` ONNX Script op in its place.
    """

    @staticmethod
    def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
        n_batch, n_heads = data.shape[0], data.shape[1]
        # Index tensors shaped (B, 1, 1), (1, H, 1) and position_ids with a new
        # axis 1, so the three broadcast together.
        rows = torch.arange(n_batch)[:, None, None]
        heads = torch.arange(n_heads)[None, :, None]
        slots = position_ids[:, None]
        # In-place write: the caller's tensor is modified and returned.
        data[rows, heads, slots] = updates
        return data

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        # Nothing is saved on ctx.
        pass

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
        node = g.onnxscript_op(CtxScatter, data, position_ids, updates)
        # The op's output takes the type of `data`.
        return node.setTypeAs(data)


# ONNX Script op (domain "com.qualcomm.cloud", version 1) that gathers entries of
# `data` selected by `ctx_indices`.
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
    # Expand the indices to the first three dimensions of data's shape.
    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
    # Add a trailing axis so each index is a length-1 tuple for GatherND.
    ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
    # batch_dims=2: the first two axes are treated as batch axes by GatherND.
    return ops.GatherND(data, ctx_indices, batch_dims=2)


class CtxGatherFunc(torch.autograd.Function):
    """
    Gather the entries of the KV-cache ``data`` selected by ``ctx_indices``.

    ``forward`` is the eager PyTorch implementation; ``symbolic`` emits the
    ``CtxGather`` ONNX Script op in its place.
    """

    @staticmethod
    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
        n_batch, n_heads = data.shape[0], data.shape[1]
        # (B, 1, 1) and (1, H, 1) index tensors that broadcast against ctx_indices.
        rows = torch.arange(n_batch)[:, None, None]
        heads = torch.arange(n_heads)[None, :, None]
        return data[rows, heads, ctx_indices]

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        # Nothing is saved on ctx.
        pass

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
        node = g.onnxscript_op(CtxGather, data, ctx_indices)
        # The op's output takes the type of `data`.
        return node.setTypeAs(data)
88 changes: 39 additions & 49 deletions QEfficient/customop/rms_norm.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,51 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

"""
RMS Norm CustomOp Node in QAic Domain for Cloud AI 100
This is to handle the FP16 Overflow seen in RMS Norm for LLMs
"""

import onnxscript
import torch
from torch.onnx.symbolic_helper import parse_args

# C++ source for the custom op. It defines custom_rms_norm and registers it as
# "QEffCustomRMSNorm" in the "QAic" torch library.
op_source = """
#include <torch/script.h>
torch::Tensor custom_rms_norm(torch::Tensor hidden_states, torch::Tensor weight, double eps) {
torch::Tensor output;
torch::Tensor variance;
bool keepdim;
// double eps = 1e-5;
variance = hidden_states.pow(2).mean(-1, keepdim=true);
output = hidden_states * torch::rsqrt(variance + eps);
output = output * weight;
return output;
}
TORCH_LIBRARY(QAic, m) {
m.def("QEffCustomRMSNorm", &custom_rms_norm);
}
"""

# Compile and load the custom op. This runs at module import time.
# NOTE(review): only `import torch` is visible above; confirm that
# torch.utils.cpp_extension is actually loaded before this attribute access.
torch.utils.cpp_extension.load_inline(
    name="custom_rms_norm",
    cpp_sources=op_source,
    is_python_module=False,
    verbose=True,
)


# Wrapper module for the custom RMSNorm C++ op
class QEffCustomRMSNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
self.eps = eps
from torch import nn

ops = onnxscript.opset13


# ONNX Script definition of the RMS-norm custom op (domain "com.qti.aisw.onnx", version 1).
# Computes weight * hidden_states / sqrt(mean(hidden_states ** 2 over the last axis) + epsilon).
@onnxscript.script(onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1))
def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float):
    # NOTE(review): to=1 is presumably TensorProto.FLOAT (float32) -- confirm against the ONNX enum.
    weight = ops.Cast(weight, to=1)
    # Mean of squares over the last axis; keepdims=1 keeps that axis so it broadcasts below.
    variance = ops.ReduceMean(ops.Pow(hidden_states, 2), axes=[-1], keepdims=1)
    # Expand epsilon to the shape of `variance` before it is added.
    epsilon = ops.Expand(epsilon, ops.Shape(variance))
    hidden_states = hidden_states * ops.Reciprocal(ops.Sqrt(variance + epsilon))
    return weight * hidden_states

def forward(self, hidden_states):
return torch.ops.QAic.QEffCustomRMSNorm(hidden_states, self.weight, self.eps)

class CustomRMSNormFunc(torch.autograd.Function):
@staticmethod
def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float):
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + epsilon)
return weight * hidden_states

# ONNX export symbolic helper
@parse_args("v", "v", "f")
def custom_rms_norm(g, hidden_states, weight, eps):
return g.op("QAic::QEffCustomRMSNorm", hidden_states, weight, eps_f=eps).setTypeAs(hidden_states)
@staticmethod
def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, hidden_states: torch.Value, weight: torch.Value, epsilon: torch.Value) -> torch.Value:
return g.onnxscript_op(CustomRMSNorm, hidden_states, weight, epsilon_f=epsilon).setTypeAs(hidden_states)

torch.onnx.register_custom_op_symbolic("QAic::QEffCustomRMSNorm", custom_rms_norm, 1)

class CustomRMSNormAIC(nn.Module):
    """
    RMSNorm layer whose forward pass is routed through ``CustomRMSNormFunc``.
    """

    def __init__(self, hidden_size, eps=1e-05):
        super().__init__()
        # Learnable scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Apply ``CustomRMSNormFunc`` with this layer's weight and epsilon."""
        normed = CustomRMSNormFunc.apply(hidden_states, self.weight, self.variance_epsilon)
        return normed
61 changes: 61 additions & 0 deletions QEfficient/customop/rms_norm_native.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

"""
RMS Norm CustomOp Node in QAic Domain for Cloud AI 100
This is to handle the FP16 Overflow seen in RMS Norm for LLMs
"""

import torch
import torch.utils.cpp_extension
from torch.onnx.symbolic_helper import parse_args

# C++ source for the custom op. It defines custom_rms_norm and registers it as
# "QEffCustomRMSNorm" in the "QAic" torch library.
op_source = """
#include <torch/script.h>
torch::Tensor custom_rms_norm(torch::Tensor hidden_states, torch::Tensor weight, double eps) {
torch::Tensor output;
torch::Tensor variance;
bool keepdim;
// double eps = 1e-5;
variance = hidden_states.pow(2).mean(-1, keepdim=true);
output = hidden_states * torch::rsqrt(variance + eps);
output = output * weight;
return output;
}
TORCH_LIBRARY(QAic, m) {
m.def("QEffCustomRMSNorm", &custom_rms_norm);
}
"""

# Compile and load the custom op. This runs at module import time.
# It relies on the explicit `import torch.utils.cpp_extension` at the top of the
# file: a bare `import torch` is not guaranteed to load that submodule.
torch.utils.cpp_extension.load_inline(
    name="custom_rms_norm",
    cpp_sources=op_source,
    is_python_module=False,
    verbose=True,
)


# Wrapper module for the custom RMSNorm C++ op
class QEffCustomRMSNorm(torch.nn.Module):
    """Module wrapper that calls the ``torch.ops.QAic.QEffCustomRMSNorm`` op.

    Extra keyword arguments passed to the constructor are accepted and ignored.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.eps = eps
        # torch.empty does not initialise the values.
        # NOTE(review): presumably real weights are loaded afterwards -- confirm.
        buffer = torch.empty(normalized_shape, device=device, dtype=dtype)
        self.weight = torch.nn.Parameter(buffer)

    def forward(self, hidden_states):
        rms_norm_op = torch.ops.QAic.QEffCustomRMSNorm
        return rms_norm_op(hidden_states, self.weight, self.eps)


# ONNX export symbolic helper
@parse_args("v", "v", "f")
def custom_rms_norm(g, hidden_states, weight, eps):
    """Emit a ``QAic::QEffCustomRMSNorm`` node for ONNX export.

    ``hidden_states`` and ``weight`` are passed as inputs and ``eps`` as the
    float attribute ``eps``; the output is given the type of ``hidden_states``.
    """
    node = g.op("QAic::QEffCustomRMSNorm", hidden_states, weight, eps_f=eps)
    return node.setTypeAs(hidden_states)


# Register custom_rms_norm as the ONNX symbolic for "QAic::QEffCustomRMSNorm".
# NOTE(review): the trailing 1 is presumably the opset version -- confirm against the torch.onnx docs.
torch.onnx.register_custom_op_symbolic("QAic::QEffCustomRMSNorm", custom_rms_norm, 1)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -----------------------------------------------------------------------------
#
#Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
#Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
#SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
Expand Down

0 comments on commit 1081625

Please sign in to comment.