Refactor custom ops into proper file locations (quic#55)
* Refactor custom ops into proper file locations

Signed-off-by: Ilango Rajagopal <[email protected]>

* Remove the old native RMSNorm code

Signed-off-by: Ilango Rajagopal <[email protected]>

---------

Signed-off-by: Ilango Rajagopal <[email protected]>
irajagop authored Jul 10, 2024
1 parent a262cd3 commit 8265edf
Showing 8 changed files with 123 additions and 379 deletions.
25 changes: 0 additions & 25 deletions QEfficient/customop/CMakeLists.txt

This file was deleted.

49 changes: 0 additions & 49 deletions QEfficient/customop/CustomRMSNorm/src/customrmsnorm_aic.cpp

This file was deleted.

85 changes: 0 additions & 85 deletions QEfficient/customop/CustomRMSNorm/src/customrmsnorm_functions.cpp

This file was deleted.

20 changes: 0 additions & 20 deletions QEfficient/customop/README.md

This file was deleted.

119 changes: 3 additions & 116 deletions QEfficient/customop/__init__.py
@@ -5,120 +5,7 @@
#
# -----------------------------------------------------------------------------

"""
RMS Norm CustomOp Node in com.qti.aisw.onnx Domain for Cloud AI 100
This is to handle the FP16 Overflow seen in RMS Norm for LLMs
from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxScatterFunc
from QEfficient.customop.rms_norm import CustomRMSNormAIC

"""

import onnxscript
import torch
from onnxscript.onnx_opset import opset13 as ops
from torch import nn

opset_version = 13
custom_opset = onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1)


# Version 1
@onnxscript.script(custom_opset)
def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float):
weight = ops.Cast(weight, to=1)
variance = ops.ReduceMean(ops.Pow(hidden_states, 2), axes=[-1], keepdims=1)
epsilon = ops.Expand(epsilon, ops.Shape(variance))
hidden_states = hidden_states * ops.Reciprocal(ops.Sqrt(variance + epsilon))
return weight * hidden_states


class CustomRMSNormOp(torch.autograd.Function):
@staticmethod
def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float):
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + epsilon)
return weight * hidden_states

@staticmethod
def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(
g: torch.onnx._internal.jit_utils.GraphContext,
hidden_states: torch.Value,
weight: torch.Value,
epsilon: torch.Value,
) -> torch.Value:
return g.onnxscript_op(CustomRMSNorm, hidden_states, weight, epsilon_f=epsilon).setTypeAs(hidden_states)


class CustomRMSNormAIC(nn.Module):
def __init__(self, hidden_size, eps=1e-05):
super(CustomRMSNormAIC, self).__init__()
self.variance_epsilon = eps
self.weight = torch.nn.Parameter(torch.ones(hidden_size))

def forward(self, hidden_states):
output = CustomRMSNormOp.apply(hidden_states, self.weight, self.variance_epsilon)
return output


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT:
# Find dims
batch_size = ops.Gather(ops.Shape(data), [0])
num_heads = ops.Gather(ops.Shape(data), [1])
seq_len = ops.Gather(ops.Shape(position_ids), [1])

# Expanded shape to create indices
zero = ops.Constant(value_ints=[0])
one = ops.Constant(value_ints=[1])
exp_shape = ops.Concat(batch_size, num_heads, seq_len, one, axis=0)

# Create indices
batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2, 3]), exp_shape)
head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape)
ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape)
indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3)

return ops.ScatterND(data, indices, updates)


class CtxScatterFunc(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1)
head_idx = torch.arange(data.shape[1]).view(1, -1, 1)
ctx_idx = position_ids.unsqueeze(1)
data[batch_idx, head_idx, ctx_idx] = updates
return data

@staticmethod
def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
return ops.GatherND(data, ctx_indices, batch_dims=2)


class CtxGatherFunc(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]

@staticmethod
def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
__all__ = ["CtxGatherFunc", "CtxScatterFunc", "CustomRMSNormAIC"]
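
After this refactor the public surface of QEfficient.customop is unchanged: the ops live in dedicated modules and are re-exported here. As a minimal usage sketch (the tensor sizes below are illustrative assumptions, not part of this commit), the RMSNorm replacement is a drop-in nn.Module; per the module docstring it exists to handle the FP16 overflow seen in RMS Norm for LLMs, where mean(x^2) over large activations can exceed the FP16 maximum of roughly 65504:

import torch
from QEfficient.customop import CustomRMSNormAIC

# Illustrative sizes only.
norm = CustomRMSNormAIC(hidden_size=64, eps=1e-5)
hidden_states = torch.randn(2, 8, 64)
# forward computes: weight * hidden_states * rsqrt(mean(hidden_states**2, dim=-1) + eps)
output = norm(hidden_states)
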
81 changes: 81 additions & 0 deletions QEfficient/customop/ctx_scatter_gather.py
@@ -0,0 +1,81 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import onnxscript
import torch

ops = onnxscript.opset13


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT:
# Find dims
batch_size = ops.Gather(ops.Shape(data), [0])
num_heads = ops.Gather(ops.Shape(data), [1])
seq_len = ops.Gather(ops.Shape(position_ids), [1])

# Expanded shape to create indices
zero = ops.Constant(value_ints=[0])
one = ops.Constant(value_ints=[1])
exp_shape = ops.Concat(batch_size, num_heads, seq_len, one, axis=0)

# Create indices
batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2, 3]), exp_shape)
head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape)
ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape)
indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3)

return ops.ScatterND(data, indices, updates)


class CtxScatterFunc(torch.autograd.Function):
"""
Function to scatter the current key values into KV-cache.
"""

@staticmethod
def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1)
head_idx = torch.arange(data.shape[1]).view(1, -1, 1)
ctx_idx = position_ids.unsqueeze(1)
data[batch_idx, head_idx, ctx_idx] = updates
return data

@staticmethod
def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
return ops.GatherND(data, ctx_indices, batch_dims=2)


class CtxGatherFunc(torch.autograd.Function):
"""
Function to gather only the valid key values from KV-cache.
"""

@staticmethod
def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]

@staticmethod
def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
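
To make the contract of the two functions concrete, here is a minimal sketch of the intended KV-cache round trip (shapes and positions below are illustrative assumptions, not from this commit). The cache layout implied by the indexing is [batch, heads, ctx_len, head_dim]; position_ids gives per-batch write positions and ctx_indices the positions to read back:

import torch
from QEfficient.customop import CtxGatherFunc, CtxScatterFunc

# Illustrative KV-cache: [batch=2, heads=4, ctx_len=16, head_dim=8]
cache = torch.zeros(2, 4, 16, 8)
new_kv = torch.randn(2, 4, 3, 8)                      # current step's key/value states
position_ids = torch.tensor([[5, 6, 7], [0, 1, 2]])   # per-batch write positions

# Scatter the new states into the cache at position_ids.
cache = CtxScatterFunc.apply(cache, position_ids, new_kv)

# Gather the first 8 context positions back out: result is [2, 4, 8, 8].
ctx_indices = torch.arange(8).view(1, 1, 8)
valid = CtxGatherFunc.apply(cache, ctx_indices)

At export time, the symbolic methods swap in the onnxscript CtxScatter/CtxGather definitions (ScatterND/GatherND in the com.qualcomm.cloud domain), so the traced graph mirrors this eager behavior.
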
