-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor custom ops into proper file locations
- Loading branch information
Showing
5 changed files
with
185 additions
and
166 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ----------------------------------------------------------------------------- | ||
|
||
import onnxscript | ||
import torch | ||
|
||
ops = onnxscript.opset13 | ||
|
||
|
||
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) | ||
def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: | ||
# Find dims | ||
batch_size = ops.Gather(ops.Shape(data), [0]) | ||
num_heads = ops.Gather(ops.Shape(data), [1]) | ||
seq_len = ops.Gather(ops.Shape(position_ids), [1]) | ||
|
||
# Expanded shape to create indices | ||
zero = ops.Constant(value_ints=[0]) | ||
one = ops.Constant(value_ints=[1]) | ||
exp_shape = ops.Concat(batch_size, num_heads, seq_len, one, axis=0) | ||
|
||
# Create indices | ||
batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2, 3]), exp_shape) | ||
head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape) | ||
ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [1, 3]), exp_shape) | ||
indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3) | ||
|
||
return ops.ScatterND(data, indices, updates) | ||
|
||
|
||
class CtxScatterFunc(torch.autograd.Function): | ||
""" | ||
Function to scatter the current key values into KV-cache. | ||
""" | ||
|
||
@staticmethod | ||
def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): | ||
batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1) | ||
head_idx = torch.arange(data.shape[1]).view(1, -1, 1) | ||
ctx_idx = position_ids.unsqueeze(1) | ||
data[batch_idx, head_idx, ctx_idx] = updates | ||
return data | ||
|
||
@staticmethod | ||
def setup_context(ctx, inputs, outputs): | ||
pass | ||
|
||
@staticmethod | ||
def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: | ||
return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data) | ||
|
||
|
||
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) | ||
def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: | ||
ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0])) | ||
ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) | ||
return ops.GatherND(data, ctx_indices, batch_dims=2) | ||
|
||
|
||
class CtxGatherFunc(torch.autograd.Function): | ||
""" | ||
Function to gather only the valid key values from KV-cache. | ||
""" | ||
|
||
@staticmethod | ||
def forward(data: torch.Tensor, ctx_indices: torch.Tensor): | ||
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) | ||
head_indices = torch.arange(data.shape[1]).view(1, -1, 1) | ||
return data[batch_indices, head_indices, ctx_indices] | ||
|
||
@staticmethod | ||
def setup_context(ctx, inputs, outputs): | ||
pass | ||
|
||
@staticmethod | ||
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: | ||
return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,51 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. | ||
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ----------------------------------------------------------------------------- | ||
|
||
""" | ||
RMS Norm CustomOp Node in QAic Domain for Cloud AI 100 | ||
This is to handle the FP16 Overflow seen in RMS Norm for LLMs | ||
""" | ||
|
||
import onnxscript | ||
import torch | ||
from torch.onnx.symbolic_helper import parse_args | ||
|
||
op_source = """ | ||
#include <torch/script.h> | ||
torch::Tensor custom_rms_norm(torch::Tensor hidden_states, torch::Tensor weight, double eps) { | ||
torch::Tensor output; | ||
torch::Tensor variance; | ||
bool keepdim; | ||
// double eps = 1e-5; | ||
variance = hidden_states.pow(2).mean(-1, keepdim=true); | ||
output = hidden_states * torch::rsqrt(variance + eps); | ||
output = output * weight; | ||
return output; | ||
} | ||
TORCH_LIBRARY(QAic, m) { | ||
m.def("QEffCustomRMSNorm", &custom_rms_norm); | ||
} | ||
""" | ||
|
||
# Compile and load the custom op | ||
torch.utils.cpp_extension.load_inline( | ||
name="custom_rms_norm", | ||
cpp_sources=op_source, | ||
is_python_module=False, | ||
verbose=True, | ||
) | ||
|
||
|
||
# Wrapper module for custom relu C++ op | ||
class QEffCustomRMSNorm(torch.nn.Module): | ||
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): | ||
super().__init__() | ||
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) | ||
self.eps = eps | ||
from torch import nn | ||
|
||
ops = onnxscript.opset13 | ||
|
||
|
||
@onnxscript.script(onnxscript.values.Opset(domain="com.qti.aisw.onnx", version=1)) | ||
def CustomRMSNorm(hidden_states: onnxscript.FLOAT, weight: onnxscript.FLOAT, epsilon: float): | ||
weight = ops.Cast(weight, to=1) | ||
variance = ops.ReduceMean(ops.Pow(hidden_states, 2), axes=[-1], keepdims=1) | ||
epsilon = ops.Expand(epsilon, ops.Shape(variance)) | ||
hidden_states = hidden_states * ops.Reciprocal(ops.Sqrt(variance + epsilon)) | ||
return weight * hidden_states | ||
|
||
def forward(self, hidden_states): | ||
return torch.ops.QAic.QEffCustomRMSNorm(hidden_states, self.weight, self.eps) | ||
|
||
class CustomRMSNormFunc(torch.autograd.Function): | ||
@staticmethod | ||
def forward(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float): | ||
variance = hidden_states.pow(2).mean(-1, keepdim=True) | ||
hidden_states = hidden_states * torch.rsqrt(variance + epsilon) | ||
return weight * hidden_states | ||
|
||
# ONNX export symbolic helper | ||
@parse_args("v", "v", "f") | ||
def custom_rms_norm(g, hidden_states, weight, eps): | ||
return g.op("QAic::QEffCustomRMSNorm", hidden_states, weight, eps_f=eps).setTypeAs(hidden_states) | ||
@staticmethod | ||
def setup_context(ctx, inputs, outputs): | ||
pass | ||
|
||
@staticmethod | ||
def symbolic(g: torch.Graph, hidden_states: torch.Value, weight: torch.Value, epsilon: torch.Value) -> torch.Value: | ||
return g.onnxscript_op(CustomRMSNorm, hidden_states, weight, epsilon_f=epsilon).setTypeAs(hidden_states) | ||
|
||
torch.onnx.register_custom_op_symbolic("QAic::QEffCustomRMSNorm", custom_rms_norm, 1) | ||
|
||
class CustomRMSNormAIC(nn.Module): | ||
""" | ||
RMSNorm module that works by replacing the current module with compiler known custom-op. | ||
""" | ||
|
||
def __init__(self, hidden_size, eps=1e-05): | ||
super(CustomRMSNormAIC, self).__init__() | ||
self.variance_epsilon = eps | ||
self.weight = torch.nn.Parameter(torch.ones(hidden_size)) | ||
|
||
def forward(self, hidden_states): | ||
return CustomRMSNormFunc.apply(hidden_states, self.weight, self.variance_epsilon) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) 2023-2024 Qualcomm Innovation Center, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ----------------------------------------------------------------------------- | ||
|
||
""" | ||
RMS Norm CustomOp Node in QAic Domain for Cloud AI 100 | ||
This is to handle the FP16 Overflow seen in RMS Norm for LLMs | ||
""" | ||
|
||
import torch | ||
from torch.onnx.symbolic_helper import parse_args | ||
|
||
op_source = """ | ||
#include <torch/script.h> | ||
torch::Tensor custom_rms_norm(torch::Tensor hidden_states, torch::Tensor weight, double eps) { | ||
torch::Tensor output; | ||
torch::Tensor variance; | ||
bool keepdim; | ||
// double eps = 1e-5; | ||
variance = hidden_states.pow(2).mean(-1, keepdim=true); | ||
output = hidden_states * torch::rsqrt(variance + eps); | ||
output = output * weight; | ||
return output; | ||
} | ||
TORCH_LIBRARY(QAic, m) { | ||
m.def("QEffCustomRMSNorm", &custom_rms_norm); | ||
} | ||
""" | ||
|
||
# Compile and load the custom op | ||
torch.utils.cpp_extension.load_inline( | ||
name="custom_rms_norm", | ||
cpp_sources=op_source, | ||
is_python_module=False, | ||
verbose=True, | ||
) | ||
|
||
|
||
# Wrapper module for custom relu C++ op | ||
class QEffCustomRMSNorm(torch.nn.Module): | ||
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): | ||
super().__init__() | ||
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) | ||
self.eps = eps | ||
|
||
def forward(self, hidden_states): | ||
return torch.ops.QAic.QEffCustomRMSNorm(hidden_states, self.weight, self.eps) | ||
|
||
|
||
# ONNX export symbolic helper | ||
@parse_args("v", "v", "f") | ||
def custom_rms_norm(g, hidden_states, weight, eps): | ||
return g.op("QAic::QEffCustomRMSNorm", hidden_states, weight, eps_f=eps).setTypeAs(hidden_states) | ||
|
||
|
||
torch.onnx.register_custom_op_symbolic("QAic::QEffCustomRMSNorm", custom_rms_norm, 1) |
2 changes: 1 addition & 1 deletion
2
...icient/customop/custom_rms_op_config.yaml → QEfficient/customop/rms_norm_native.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters