Skip to content

Commit

Permalink
transforms API base (quic#60)
Browse files Browse the repository at this point in the history
* Added transforms API base and tests

Signed-off-by: Ilango Rajagopal <[email protected]>

* transforms test make generic

Signed-off-by: Ilango Rajagopal <[email protected]>

* Refactor transforms hierarchy

Signed-off-by: Ilango Rajagopal <[email protected]>

* Remove redundant "Transform" suffix

Signed-off-by: Ilango Rajagopal <[email protected]>

* Remove unimplemented placeholder transforms

Signed-off-by: Ilango Rajagopal <[email protected]>

---------

Signed-off-by: Ilango Rajagopal <[email protected]>
  • Loading branch information
irajagop authored Jul 19, 2024
1 parent 2c44851 commit d7efe77
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
39 changes: 39 additions & 0 deletions QEfficient/base/onnx_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from onnx import ModelProto


class OnnxTransform:
    """
    Base class for graph-level modifications applied to an exported ONNX model.

    Concrete transforms subclass this and override :meth:`apply`; the class is
    deliberately not instantiable because transforms carry no per-instance state.
    """

    def __init__(self):
        # Transforms are used purely through the `apply` classmethod.
        raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")

    @classmethod
    def apply(cls, model: ModelProto) -> ModelProto:
        """
        Override this classmethod to implement a transformation.

        :param model: The model's ONNX graph to transform
        :returns: ONNX graph after applying the transform
        """
        raise NotImplementedError("Use subclasses for ONNX transform")


class FP16Clip(OnnxTransform):
    # Placeholder — no `apply` override yet, so using it raises NotImplementedError.
    # Presumably intended to clip out-of-range constants for FP16 — TODO confirm when implemented.
    pass


class SplitWeights(OnnxTransform):
    # Placeholder — no `apply` override yet, so using it raises NotImplementedError.
    # Presumably intended to split weight tensors in the exported graph — TODO confirm when implemented.
    pass


class LoraAdapters(OnnxTransform):
    # Placeholder — no `apply` override yet, so using it raises NotImplementedError.
    # Presumably intended to handle LoRA adapter weights in the graph — TODO confirm when implemented.
    pass
52 changes: 52 additions & 0 deletions QEfficient/base/pytorch_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from typing import Dict, Type

from torch import nn


class PytorchTransform:
    """
    Base class for transformations applied to a PyTorch module.

    Subclasses override :meth:`apply`; instances are never created because a
    transform is pure behavior with no per-instance state.
    """

    def __init__(self):
        # All usage goes through the `apply` classmethod.
        raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")

    @classmethod
    def apply(cls, model: nn.Module) -> nn.Module:
        """
        Override this classmethod to implement a transformation.

        :param model: The torch module to transform; it may be transformed in-place
        :returns: Torch module after applying the transform
        """
        raise NotImplementedError("Use subclasses for Pytorch transform")


class ModuleMapping(PytorchTransform):
    """
    Replaces the PyTorch modules based on the _module_mapping class variable.

    Subclasses declare ``_module_mapping`` as a dict from original module class
    to replacement module class; :meth:`apply` rebinds ``__class__`` on every
    matching submodule in-place.
    """

    # Declared (not assigned) on the base: each concrete transform supplies its own.
    _module_mapping: Dict[Type[nn.Module], Type[nn.Module]]

    @classmethod
    def apply(cls, model: nn.Module) -> nn.Module:
        """
        Swap the class of every submodule whose type appears in the mapping.

        :param model: The torch module to transform; modified in-place
        :returns: The same module, with matching submodules re-classed
        """
        for module in model.modules():
            if repl_module := cls._module_mapping.get(type(module)):
                # Rebind the class rather than rebuild the module, so parameters
                # and buffers are preserved as-is.
                module.__class__ = repl_module
        return model

    @classmethod
    def register(cls, from_module: type, to_module: type):
        """
        Add a new module type in the module mapping for this transform. ::
            FlashAttention.register(LLamaAttention, LlamaFlashAttention)

        :param from_module: Module class to be replaced when `apply` runs
        :param to_module: Replacement module class
        """
        # Copy-on-write: without this, mutating an inherited dict would leak the
        # registration into the parent/sibling transforms that share it, and
        # registering on a class with no mapping yet would raise AttributeError.
        if "_module_mapping" not in cls.__dict__:
            cls._module_mapping = dict(getattr(cls, "_module_mapping", {}))
        cls._module_mapping[from_module] = to_module
41 changes: 41 additions & 0 deletions tests/base/test_pytorch_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

import pytest
import torch
from torch import nn

from QEfficient.base.pytorch_transforms import ModuleMapping


def test_module_mapping_transform():
    """ModuleMapping: not instantiable; apply() swaps mapped modules in-place."""
    # The base transform must reject instantiation.
    with pytest.raises(TypeError):
        ModuleMapping()

    # Note: names avoid the Test* prefix so pytest does not try to collect them.
    class IdentityMapping(ModuleMapping):
        _module_mapping = {nn.Linear: nn.Identity}

    class TwoLinearModel(nn.Module):
        def __init__(self):
            super().__init__()

            self.a = nn.Linear(32, 64)
            self.b = nn.Linear(64, 32)

        def forward(self, x):
            return self.b(self.a(x))

    model = TwoLinearModel()
    sample = torch.rand(1, 32)

    # Before the transform the random linear layers alter the input.
    untransformed_out = model(sample)
    assert torch.any(untransformed_out != sample)

    # After the transform both layers act as identity, so output == input.
    model = IdentityMapping.apply(model)
    transformed_out = model(sample)
    assert torch.all(transformed_out == sample)

0 comments on commit d7efe77

Please sign in to comment.