forked from quic/efficient-transformers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added transforms API base and tests Signed-off-by: Ilango Rajagopal <[email protected]> * transforms test make generic Signed-off-by: Ilango Rajagopal <[email protected]> * Refactor transforms hierarchy Signed-off-by: Ilango Rajagopal <[email protected]> * Remove redundant "Transform" suffix Signed-off-by: Ilango Rajagopal <[email protected]> * Remove unimplemented placeholder transforms Signed-off-by: Ilango Rajagopal <[email protected]> --------- Signed-off-by: Ilango Rajagopal <[email protected]>
- Loading branch information
Showing
3 changed files
with
132 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ---------------------------------------------------------------------------- | ||
|
||
from onnx import ModelProto | ||
|
||
|
||
class OnnxTransform:
    """
    OnnxTransform is the base class for graph modifications on exported onnx.
    """

    def __init__(self):
        # Transforms are stateless namespaces of classmethods; instantiating one
        # is a usage error, so fail loudly instead of silently allowing it.
        raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")

    @classmethod
    def apply(cls, model: ModelProto) -> ModelProto:
        """
        Override this class method to apply a transformation.
        :param model: The model's ONNX graph to transform
        :returns: ONNX graph after applying the transform
        """
        raise NotImplementedError("Use subclasses for ONNX transform")
|
||
|
||
class FP16Clip(OnnxTransform):
    # NOTE(review): placeholder with no behavior yet — name suggests it will
    # clip constants to the FP16 representable range; confirm when implemented.
    pass
|
||
|
||
class SplitWeights(OnnxTransform):
    # NOTE(review): placeholder with no behavior yet — name suggests splitting
    # weight tensors out of the graph; confirm when implemented.
    pass
|
||
|
||
class LoraAdapters(OnnxTransform):
    # NOTE(review): placeholder with no behavior yet — name suggests handling
    # LoRA adapter weights; confirm when implemented.
    pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ---------------------------------------------------------------------------- | ||
|
||
from typing import Dict, Type | ||
|
||
from torch import nn | ||
|
||
|
||
class PytorchTransform:
    """
    PytorchTransform is the base class that can do any transformation to a given PyTorch module by overriding apply method.
    """

    def __init__(self):
        # Transforms are stateless namespaces of classmethods; instantiating one
        # is a usage error, so fail loudly instead of silently allowing it.
        raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")

    @classmethod
    def apply(cls, model: nn.Module) -> nn.Module:
        """
        Override this class method to apply a transformation.
        :param model: The torch module to transform, this module may be transformed in-place
        :returns: Torch module after applying the transform
        """
        raise NotImplementedError("Use subclasses for Pytorch transform")
|
||
|
||
class ModuleMapping(PytorchTransform):
    """
    Replaces the PyTorch modules based on the _module_mapping class variable.
    """

    # Maps an original module class to the replacement class assigned in apply().
    _module_mapping: Dict[Type[nn.Module], Type[nn.Module]]

    @classmethod
    def apply(cls, model: nn.Module) -> nn.Module:
        """
        Re-class every mapped submodule of `model` in-place.
        :param model: The torch module to transform (mutated in-place)
        :returns: The same module, with matching submodules re-classed
        """
        for module in model.modules():
            # Exact-type lookup: subclasses of a mapped type are NOT replaced.
            if repl_module := cls._module_mapping.get(type(module)):
                module.__class__ = repl_module
        return model

    @classmethod
    def register(cls, from_module: type, to_module: type):
        """
        Add a new module type in the module mapping for this transform. ::
            FlashAttention.register(LLamaAttention, LlamaFlashAttention)
        """
        # Copy-on-write: without this guard, registering on a subclass would
        # mutate the mapping dict inherited from (and shared with) its parent
        # class, and registering on a class with only the annotation (no dict)
        # would raise AttributeError.
        if "_module_mapping" not in cls.__dict__:
            cls._module_mapping = dict(getattr(cls, "_module_mapping", {}))
        cls._module_mapping[from_module] = to_module
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# ----------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# | ||
# ---------------------------------------------------------------------------- | ||
|
||
import pytest | ||
import torch | ||
from torch import nn | ||
|
||
from QEfficient.base.pytorch_transforms import ModuleMapping | ||
|
||
|
||
def test_module_mapping_transform():
    # The base transform class must refuse instantiation.
    with pytest.raises(TypeError):
        ModuleMapping()

    # A transform that swaps every Linear for an Identity.
    class SwapLinearForIdentity(ModuleMapping):
        _module_mapping = {nn.Linear: nn.Identity}

    class TwoLayerNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.a = nn.Linear(32, 64)
            self.b = nn.Linear(64, 32)

        def forward(self, x):
            return self.b(self.a(x))

    net = TwoLayerNet()
    inp = torch.rand(1, 32)

    # Before the transform the linear layers actually change the input.
    out_before = net(inp)
    assert torch.any(out_before != inp)

    # After the transform both layers are Identity, so output == input.
    net = SwapLinearForIdentity.apply(net)
    out_after = net(inp)
    assert torch.all(out_after == inp)